api.py 6.71 KB
Newer Older
文靖昊 committed
1
import json
tinywell committed
2
import os
3 4
import sys
sys.path.append('../')
tinywell committed
5 6
import argparse
from typing import Optional
陈正乐 committed
7 8
from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware
tinywell committed
9
import logging
陈正乐 committed
10
import uvicorn
tinywell committed
11 12


13
from src.server.agent_rate import new_rate_agent, RateAgentV3
文靖昊 committed
14
from src.server.classify import new_router_llm
文靖昊 committed
15
from src.server.extend_classify import new_extend_classify_llm
文靖昊 committed
16
from src.server.rewrite import new_re_rewriter_llm
17
from src.controller.request import GeoAgentRateRequest
tinywell committed
18
from src.utils.logger import setup_logging
19
from langchain_openai import ChatOpenAI
陈正乐 committed
20

# Built-in fallback settings. get_config() only returns one of these when neither
# the GEO_AGENT_<KEY> environment variable nor the matching CLI flag supplies a value.
DEFAULT_CONFIG = dict(
    PORT=8088,
    HOST="0.0.0.0",
    LLM_MODEL="Qwen2-7B",
    API_BASE="http://192.168.10.14:8000/v1",
    TOOL_BASE_URL="http://localhost:5001",
    # NOTE(review): placeholder credential; provide the real key through the
    # GEO_AGENT_API_KEY environment variable or the --api_key flag.
    API_KEY="xxxxxxxxxxxxx",
    LOG_LEVEL="DEBUG",
)

def get_config(key: str, args: Optional[argparse.Namespace] = None) -> str:
    """
    Resolve a configuration value.

    Precedence: environment variable > command-line argument > DEFAULT_CONFIG.

    Args:
        key: Configuration key such as "PORT" (case-insensitive). The environment
            variable consulted is ``GEO_AGENT_<KEY>`` and the CLI attribute read
            is ``key.lower()``.
        args: Parsed command-line arguments, or None to skip the CLI lookup.

    Returns:
        str: The resolved value. CLI and default values are converted with ``str()``.

    Raises:
        KeyError: If neither the environment nor ``args`` provides a value and
            ``key.upper()`` is not a key of DEFAULT_CONFIG.
    """
    # Environment variables are namespaced, e.g. "PORT" -> "GEO_AGENT_PORT".
    env_key = f"GEO_AGENT_{key.upper()}"

    # 1. Environment variable (an empty string still counts as "set").
    value = os.getenv(env_key)
    if value is not None:
        return value
    # Go through logging instead of print() so this diagnostic respects the
    # configured log level instead of always writing to stdout.
    logging.getLogger(__name__).debug("%s is not set; falling back to CLI/default", env_key)

    # 2. Command-line argument; argparse stores unset options as None.
    if args is not None:
        arg_value = getattr(args, key.lower(), None)
        if arg_value is not None:
            return str(arg_value)

    # 3. Built-in default.
    return str(DEFAULT_CONFIG[key.upper()])

# The FastAPI application serving the endpoints registered below.
app = FastAPI()

# Cross-origin settings: every origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin calls (Starlette reflects the request
# Origin in that case; verify for the installed version). Consider an explicit
# origin list in production.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

class AgentManager:
    """Holds the shared LLM client and the agents/chains built on top of it.

    Every attribute is None until :meth:`initialize` is called.
    """

    def __init__(self):
        # Declare every attribute that initialize() assigns, so all getters
        # return None (instead of raising AttributeError) before initialization.
        self.llm = None
        self.agent = None
        self.rate_agent = None
        self.router_llm = None
        self.re_rewriter_llm = None
        self.extend_classify_llm = None

    def initialize(self, api_key: str, api_base: str, model_name: str, tool_base_url: str):
        """Create the ChatOpenAI client and every component that wraps it.

        Args:
            api_key: API key for the OpenAI-compatible endpoint.
            api_base: Base URL of the OpenAI-compatible endpoint.
            model_name: Model name passed to ChatOpenAI.
            tool_base_url: Base URL of the tool service used by the rate agents.
        """
        self.llm = ChatOpenAI(
            openai_api_key=api_key,
            openai_api_base=api_base,
            model_name=model_name,
            verbose=True
        )
        # Older agent; still exposed through get_agent().
        self.agent = new_rate_agent(self.llm, verbose=True, tool_base_url=tool_base_url)
        self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url, version="v0")
        self.router_llm = new_router_llm(self.llm)
        self.re_rewriter_llm = new_re_rewriter_llm(self.llm)
        self.extend_classify_llm = new_extend_classify_llm(self.llm)

    def get_llm(self):
        """Return the shared ChatOpenAI client, or None before initialize()."""
        return self.llm

    def get_agent(self):
        """Return the agent built by new_rate_agent(), or None before initialize()."""
        return self.agent

    def get_rate_agent(self):
        """Return the RateAgentV3 instance, or None before initialize()."""
        return self.rate_agent

    def get_router_llm(self):
        """Return the router chain from new_router_llm(), or None before initialize()."""
        return self.router_llm

    def get_re_rewriter_llm(self):
        """Return the query rewriter from new_re_rewriter_llm(), or None before initialize()."""
        return self.re_rewriter_llm

    def get_extend_classify_llm(self):
        """Return the classifier from new_extend_classify_llm(), or None before initialize()."""
        return self.extend_classify_llm


# Module-level singleton shared by the request handlers; main() initializes it.
agent_manager = AgentManager()

@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
    """Run the RateAgentV3 agent on ``chat_request.query``.

    Args:
        chat_request: Request body; only ``query`` is read.
        token: Optional ``token`` header. NOTE(review): accepted but not used.

    Returns:
        dict: ``{'code': 200, 'data': <agent result>}`` on success, or
        ``{'code': 500, 'data': <error message>}`` if the agent raises. Failures
        are reported through the ``code`` field of the returned dict.
    """
    rate_agent = agent_manager.get_rate_agent()
    try:
        res = rate_agent.run(chat_request.query)
    except Exception as e:
        # Request boundary: record the full traceback, then report the failure to the client.
        logging.getLogger(__name__).exception("处理请求失败, 错误信息: %s,请重新提问", e)
        return {
            'code': 500,
            'data': str(e)
        }
    return {
        'code': 200,
        'data': res
    }

@app.post('/api/classify')
def classify(chat_request: GeoAgentRateRequest):
    """Rewrite the user query and decide which data source should answer it.

    Steps:
      1. Rewrite ``query`` using ``history`` (expanded by the rewriter) when given.
      2. If the extend classifier's result has ``classify == "yes"``, reply with
         datasource ``"none"``.
      3. Otherwise ask the router LLM for the datasource.

    Returns:
        dict: ``{'code': 200, 'data': {'datasource': ..., 'rewrite': ...}}`` on
        success, or ``{'code': 500, 'data': <message>}`` when ``query`` is
        missing/empty or any step raises.
    """
    llm = agent_manager.get_router_llm()
    re_llm = agent_manager.get_re_rewriter_llm()
    extend_llm = agent_manager.get_extend_classify_llm()

    # An empty query is treated like a missing one: there is nothing to classify.
    if not chat_request.query:
        return {
            'code': 500,
            'data': "缺少必要的参数 query"
        }

    try:
        if chat_request.history is None:
            history = ""
        else:
            history = re_llm.extend_history(history=chat_request.history)
        rewrite = re_llm.invoke(chat_request.query, history)

        extend = extend_llm.invoke(rewrite)
        if extend.classify == "yes":
            return {
                'code': 200,
                'data': {
                    "datasource": "none",
                    "rewrite": rewrite
                }
            }

        answer = llm.invoke(rewrite)
        res = {
            "datasource": answer.datasource,
            "rewrite": rewrite
        }
    except Exception as e:
        # Request boundary: record the full traceback, then report the failure to the client.
        logging.getLogger(__name__).exception("分类失败, 错误信息: %s,请重新提问", e)
        return {
            'code': 500,
            'data': str(e)
        }
    return {
        'code': 200,
        'data': res
    }
def main():
    """Parse CLI flags, configure logging, initialize the agents and start uvicorn.

    Each setting is resolved with get_config(): the GEO_AGENT_<KEY> environment
    variable first, then the CLI flag, then DEFAULT_CONFIG.
    """
    parser = argparse.ArgumentParser(description="启动API服务")
    parser.add_argument("--port", type=int, help="API服务端口")
    parser.add_argument("--host", type=str, help="API服务地址")
    # get_config("LLM_MODEL", args) reads args.llm_model, so --llm must be stored
    # under that name. Otherwise the value lands in args.llm and is never read.
    parser.add_argument("--llm", dest="llm_model", type=str, help="LLM模型名称")
    parser.add_argument("--api_base", type=str, help="OpenAI API基础地址")
    parser.add_argument("--tool_base_url", type=str, help="工具服务基础地址")
    parser.add_argument("--api_key", type=str, help="OpenAI API密钥")
    parser.add_argument("--log_level", type=str, help="日志级别,DEBUG、INFO、WARNING、ERROR、CRITICAL")

    args = parser.parse_args()

    # Resolve configuration.
    port = int(get_config("PORT", args))
    host = get_config("HOST", args)
    llm_model = get_config("LLM_MODEL", args)
    api_base = get_config("API_BASE", args)
    tool_base_url = get_config("TOOL_BASE_URL", args)
    api_key = get_config("API_KEY", args)

    # Map the level name case-insensitively; unrecognized names fall back to INFO.
    # get_config() always returns a str, so there is no None case to handle.
    level_by_name = {
        "DEBUG": logging.DEBUG,
        "INFO": logging.INFO,
        "WARNING": logging.WARNING,
        "ERROR": logging.ERROR,
        "CRITICAL": logging.CRITICAL,
    }
    log_level_name = get_config("LOG_LEVEL", args)
    setup_logging(log_level=level_by_name.get(log_level_name.strip().upper(), logging.INFO))

    # Initialize the shared agents before serving any request.
    agent_manager.initialize(
        api_key=api_key,
        api_base=api_base,
        model_name=llm_model,
        tool_base_url=tool_base_url
    )

    # Start the server. The API key is deliberately not printed.
    print("Starting server with configuration:")
    print(f"Host: {host}")
    print(f"Port: {port}")
    print(f"LLM Model: {llm_model}")
    print(f"API Base: {api_base}")
    print(f"Tool Base URL: {tool_base_url}")

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()