api.py 5.48 KB
Newer Older
文靖昊 committed
1
import json
tinywell committed
2
import os
3 4
import sys
sys.path.append('../')
tinywell committed
5 6
import argparse
from typing import Optional
陈正乐 committed
7 8
from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware
tinywell committed
9

陈正乐 committed
10
import uvicorn
tinywell committed
11 12


13
from src.server.agent_rate import new_rate_agent, RateAgentV3
文靖昊 committed
14
from src.server.classify import new_router_llm
文靖昊 committed
15
from src.server.rewrite import new_re_rewriter_llm
16
from src.controller.request import GeoAgentRateRequest
tinywell committed
17

18
from langchain_openai import ChatOpenAI
陈正乐 committed
19

tinywell committed
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
# Default configuration. get_config() falls back to these values when a key
# is supplied by neither the environment nor the command line.
DEFAULT_CONFIG = {
    "PORT": 8088,  # port handed to uvicorn.run() in main()
    "HOST": "0.0.0.0",  # host handed to uvicorn.run() in main()
    "LLM_MODEL": "Qwen2-7B",  # model name passed to ChatOpenAI
    "API_BASE": "http://192.168.10.14:8000/v1",  # base URL passed to ChatOpenAI
    "TOOL_BASE_URL": "http://localhost:5001",  # base URL passed to the rate agents
    "API_KEY": "xxxxxxxxxxxxx"  # NOTE(review): looks like a placeholder key - confirm
}

def get_config(key: str, args: Optional[argparse.Namespace] = None) -> str:
    """
    Look up a single configuration value and return it as a string.

    Sources are tried in this order and the first one that yields a value
    wins: environment variable > command-line argument > built-in default.

    Args:
        key: Configuration key name, e.g. "PORT". The environment variable
            consulted is "GEO_AGENT_" + key.upper(); the command-line
            attribute consulted is key.lower().
        args: Parsed command-line arguments, or None to skip that source.

    Returns:
        str: The configuration value.

    Raises:
        KeyError: If no source supplies a value and key.upper() is not a
            key of DEFAULT_CONFIG.
    """
    # 1. Environment variable. Any value that is set wins, even an empty one.
    env_value = os.getenv(f"GEO_AGENT_{key.upper()}")
    if env_value is not None:
        return env_value

    # 2. Command-line argument. Options the user did not pass are None and
    #    fall through to the default.
    if args is not None:
        arg_value = getattr(args, key.lower(), None)
        if arg_value is not None:
            return str(arg_value)

    # 3. Built-in default.
    return str(DEFAULT_CONFIG[key.upper()])

陈正乐 committed
59 60 61
# Application instance; the route handlers in this module register on it.
app = FastAPI()

# CORS is left fully open: every origin, method and header is accepted.
# NOTE(review): the FastAPI CORS docs say wildcards should not be combined
# with allow_credentials=True - confirm whether credentials are really needed
# or list the allowed origins explicitly.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

tinywell committed
68 69 70 71
class AgentManager:
    """
    Container for the shared LLM client and the agents/chains built on it.

    Every attribute starts as None and is filled in by initialize(). The
    getters return whatever is currently stored, so before initialize() has
    run they all return None.
    """

    def __init__(self):
        # All attributes are declared up front so that every getter behaves
        # the same way (returns None) before initialize() is called.
        self.llm = None
        self.agent = None
        self.rate_agent = None
        self.router_llm = None
        self.re_rewriter_llm = None

    def initialize(self, api_key: str, api_base: str, model_name: str, tool_base_url: str):
        """
        Create the LLM client and build every agent/chain from it.

        Args:
            api_key: API key passed to ChatOpenAI.
            api_base: API base URL passed to ChatOpenAI.
            model_name: Model name passed to ChatOpenAI.
            tool_base_url: Base URL passed to both rate agents.
        """
        self.llm = ChatOpenAI(
            openai_api_key=api_key,
            openai_api_base=api_base,
            model_name=model_name,
            verbose=True,
        )
        self.agent = new_rate_agent(self.llm, verbose=True, tool_base_url=tool_base_url)
        # A version="v0" variant of RateAgentV3 was previously tried here and
        # is left disabled.
        self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url)
        self.router_llm = new_router_llm(self.llm)
        self.re_rewriter_llm = new_re_rewriter_llm(self.llm)

    def get_llm(self):
        """Return the ChatOpenAI client, or None before initialize()."""
        return self.llm

    def get_agent(self):
        """Return the object built by new_rate_agent, or None before initialize()."""
        return self.agent

    def get_rate_agent(self):
        """Return the RateAgentV3 instance, or None before initialize()."""
        return self.rate_agent

    def get_router_llm(self):
        """Return the object built by new_router_llm, or None before initialize()."""
        return self.router_llm

    def get_re_rewriter_llm(self):
        """Return the object built by new_re_rewriter_llm, or None before initialize()."""
        return self.re_rewriter_llm

tinywell committed
102 103
agent_manager = AgentManager()

104 105
@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
    """
    Run the rate agent on the request's query and wrap the result.

    Args:
        chat_request: Request body; only its `query` field is read here.
        token: Value of the `token` request header, None when absent.
            It is accepted but not used by this handler.

    Returns:
        dict: {'code': 200, 'data': <agent result>} on success, or
        {'code': 500, 'data': <error text>} when anything raises.
    """
    try:
        # The lookup lives inside the try so that a failure to obtain the
        # agent is reported through the same JSON error payload as a failure
        # while running it.
        rate_agent = agent_manager.get_rate_agent()
        res = rate_agent.run(chat_request.query)
    except Exception as e:
        print(f"处理请求失败, 错误信息: {str(e)},请重新提问")
        return {
            'code': 500,
            'data': str(e)
        }
    return {
        'code': 200,
        'data': res
    }

文靖昊 committed
122 123 124
@app.post('/api/classify')
def classify(chat_request: GeoAgentRateRequest):
    """
    Rewrite the query using the chat history, then classify the rewrite.

    Args:
        chat_request: Request body; its `query` and `history` fields are read.

    Returns:
        dict: {'code': 200, 'data': {"datasource": ..., "rewrite": ...}} on
        success, or {'code': 500, 'data': <error text>} when `query` is
        missing or anything raises.
    """
    try:
        if chat_request.query is None:
            return {
                'code': 500,
                'data': "缺少必要的参数 query"
            }

        # The lookups live inside the try so that a failure to obtain either
        # chain is reported through the same JSON error payload.
        llm = agent_manager.get_router_llm()
        re_llm = agent_manager.get_re_rewriter_llm()

        # With no history the rewriter receives an empty string; otherwise it
        # receives whatever extend_history() returns.
        if chat_request.history is None:
            history = ""
        else:
            history = re_llm.extend_history(history=chat_request.history)

        rewrite = re_llm.invoke(chat_request.query, history)
        answer = llm.invoke(rewrite)
        res = {
            "datasource": answer.datasource,
            "rewrite": rewrite
        }
    except Exception as e:
        print(f"分类失败, 错误信息: {str(e)},请重新提问")
        return {
            'code': 500,
            'data': str(e)
        }
    return {
        'code': 200,
        'data': res
    }
tinywell committed
152

tinywell committed
153 154
def main():
    """
    Parse command-line options, initialise the agents and start the server.

    Each setting is resolved through get_config(), the agent manager is
    initialised with the LLM settings, the effective configuration is
    printed (the API key is deliberately left out), and uvicorn serves
    `app` on the resolved host and port.
    """
    # Command-line options. None of them has a default, so an option that is
    # not passed stays None and get_config() moves on to the next source.
    parser = argparse.ArgumentParser(description="启动API服务")
    parser.add_argument("--port", type=int, help="API服务端口")
    parser.add_argument("--host", type=str, help="API服务地址")
    # get_config("LLM_MODEL", args) reads args.llm_model, so the value must be
    # stored under that name or the flag is silently ignored. --llm remains
    # the original spelling; --llm_model is accepted as an alias.
    parser.add_argument("--llm", "--llm_model", dest="llm_model", type=str, help="LLM模型名称")
    parser.add_argument("--api_base", type=str, help="OpenAI API基础地址")
    parser.add_argument("--tool_base_url", type=str, help="工具服务基础地址")
    parser.add_argument("--api_key", type=str, help="OpenAI API密钥")

    args = parser.parse_args()

    # Resolve the configuration.
    port = int(get_config("PORT", args))
    host = get_config("HOST", args)
    llm_model = get_config("LLM_MODEL", args)
    api_base = get_config("API_BASE", args)
    tool_base_url = get_config("TOOL_BASE_URL", args)
    api_key = get_config("API_KEY", args)

    # Initialise the agents before the server starts accepting requests.
    agent_manager.initialize(
        api_key=api_key,
        api_base=api_base,
        model_name=llm_model,
        tool_base_url=tool_base_url,
    )

    # Start the server.
    print("Starting server with configuration:")
    print(f"Host: {host}")
    print(f"Port: {port}")
    print(f"LLM Model: {llm_model}")
    print(f"API Base: {api_base}")
    print(f"Tool Base URL: {tool_base_url}")

    uvicorn.run(app, host=host, port=port)

# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()