import json
import os
import sys

# Make the project root importable so the `src.*` packages below resolve.
sys.path.append('../')

import argparse
from typing import Optional

from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware
import logging
import uvicorn

from src.server.agent_rate import new_rate_agent, RateAgentV3
from src.server.classify import new_router_llm
from src.server.extend_classify import new_extend_classify_llm
from src.server.rewrite import new_re_rewriter_llm
from src.controller.request import GeoAgentRateRequest
from src.utils.logger import setup_logging
from langchain_openai import ChatOpenAI

logger = logging.getLogger(__name__)

# Default configuration. Each key can be overridden by the environment
# variable GEO_AGENT_<KEY> or by the matching command-line flag (see get_config).
DEFAULT_CONFIG = {
    "PORT": 8088,
    "HOST": "0.0.0.0",
    "LLM_MODEL": "Qwen2-7B",
    "API_BASE": "http://192.168.10.14:8000/v1",
    "TOOL_BASE_URL": "http://localhost:5001",
    "API_KEY": "xxxxxxxxxxxxx",
    "LOG_LEVEL": "INFO",
}

# Accepted log-level names (matched case-insensitively); anything else -> INFO.
_LOG_LEVELS = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
}


def get_config(key: str, args: Optional[argparse.Namespace] = None) -> str:
    """Return a configuration value as a string.

    Priority: environment variable > command-line argument > default value.

    Args:
        key: Configuration key name, e.g. ``"PORT"``. The environment variable
            consulted is ``GEO_AGENT_<KEY>`` (upper-cased) and the argparse
            attribute consulted is ``key.lower()``.
        args: Parsed command-line arguments, or ``None`` to skip that source.

    Returns:
        The configuration value converted to ``str``.

    Raises:
        KeyError: If neither the environment nor ``args`` provides a value and
            ``key`` is not present in ``DEFAULT_CONFIG``.
    """
    env_key = f"GEO_AGENT_{key.upper()}"

    # 1. Environment variable.
    value = os.getenv(env_key)
    if value is not None:
        return value
    logger.debug("Environment variable %s not set", env_key)

    # 2. Command-line argument (None means the flag was not given).
    if args is not None:
        arg_value = getattr(args, key.lower(), None)
        if arg_value is not None:
            return str(arg_value)

    # 3. Built-in default.
    return str(DEFAULT_CONFIG[key.upper()])


app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AgentManager:
    """Hold the shared LLM client and the agents/chains built on top of it.

    Every attribute is ``None`` until :meth:`initialize` has been called.
    """

    def __init__(self):
        # Declare every attribute up front so the getters return None (instead
        # of raising AttributeError) before initialize() runs.
        self.llm = None
        self.agent = None
        self.rate_agent = None
        self.router_llm = None
        self.re_rewriter_llm = None
        self.extend_classify_llm = None

    def initialize(self, api_key: str, api_base: str, model_name: str, tool_base_url: str):
        """Create the ChatOpenAI client and every helper built from it.

        Args:
            api_key: API key passed to ``ChatOpenAI``.
            api_base: Base URL of the OpenAI-compatible endpoint.
            model_name: Model name passed to ``ChatOpenAI``.
            tool_base_url: Base URL of the tool service used by the rate agents.
        """
        self.llm = ChatOpenAI(
            openai_api_key=api_key,
            openai_api_base=api_base,
            model_name=model_name,
            verbose=True
        )
        self.agent = new_rate_agent(self.llm, verbose=True, tool_base_url=tool_base_url)
        self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url, version="v0")
        self.router_llm = new_router_llm(self.llm)
        self.re_rewriter_llm = new_re_rewriter_llm(self.llm)
        self.extend_classify_llm = new_extend_classify_llm(self.llm)

    def get_llm(self):
        return self.llm

    def get_agent(self):
        return self.agent

    def get_rate_agent(self):
        return self.rate_agent

    def get_router_llm(self):
        return self.router_llm

    def get_re_rewriter_llm(self):
        return self.re_rewriter_llm

    def get_extend_classify_llm(self):
        return self.extend_classify_llm


agent_manager = AgentManager()


@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
    """Answer ``chat_request.query`` with the RateAgentV3 agent.

    Returns ``{'code': 200, 'data': <agent result>}`` on success, or
    ``{'code': 500, 'data': <error message>}`` on failure. The ``token`` header
    is accepted but not used by this handler.
    """
    rate_agent = agent_manager.get_rate_agent()
    if rate_agent is None:
        return {'code': 500, 'data': "Agent is not initialized"}
    try:
        res = rate_agent.run(chat_request.query)
    except Exception as e:
        # Request boundary: log with traceback, report the error in the body.
        logger.exception("处理请求失败, 错误信息: %s,请重新提问", e)
        return {'code': 500, 'data': str(e)}
    return {'code': 200, 'data': res}


@app.post('/api/classify')
def classify(chat_request: GeoAgentRateRequest):
    """Rewrite the query with chat history, then route it to a datasource.

    The rewritten query is first checked by the extend-classify chain; if it
    answers ``"yes"`` the datasource is reported as ``"none"``. Otherwise the
    router chain picks the datasource.

    Returns ``{'code': 200, 'data': {'datasource': ..., 'rewrite': ...}}`` on
    success, or ``{'code': 500, 'data': <error message>}`` on failure.
    """
    llm = agent_manager.get_router_llm()
    re_llm = agent_manager.get_re_rewriter_llm()
    extend_llm = agent_manager.get_extend_classify_llm()
    if any(chain is None for chain in (llm, re_llm, extend_llm)):
        return {'code': 500, 'data': "Agent is not initialized"}

    if chat_request.query is None:
        return {'code': 500, 'data': "缺少必要的参数 query"}

    try:
        if chat_request.history is None:
            history = ""
        else:
            history = re_llm.extend_history(history=chat_request.history)
        rewrite = re_llm.invoke(chat_request.query, history)

        extend = extend_llm.invoke(rewrite)
        if extend.classify == "yes":
            return {
                'code': 200,
                'data': {
                    "datasource": "none",
                    "rewrite": rewrite
                }
            }

        answer = llm.invoke(rewrite)
        res = {
            "datasource": answer.datasource,
            "rewrite": rewrite
        }
    except Exception as e:
        # Request boundary: log with traceback, report the error in the body.
        logger.exception("分类失败, 错误信息: %s,请重新提问", e)
        return {'code': 500, 'data': str(e)}
    return {'code': 200, 'data': res}


def main():
    """Parse CLI options, configure logging, build the agents and start uvicorn."""
    parser = argparse.ArgumentParser(description="启动API服务")
    parser.add_argument("--port", type=int, help="API服务端口")
    parser.add_argument("--host", type=str, help="API服务地址")
    # get_config("LLM_MODEL") reads args.llm_model, so store the value there;
    # "--llm" remains the primary spelling for backward compatibility.
    parser.add_argument("--llm", "--llm_model", dest="llm_model", type=str, help="LLM模型名称")
    parser.add_argument("--api_base", type=str, help="OpenAI API基础地址")
    parser.add_argument("--tool_base_url", type=str, help="工具服务基础地址")
    parser.add_argument("--api_key", type=str, help="OpenAI API密钥")
    parser.add_argument("--log_level", type=str, help="日志级别,DEBUG、INFO、WARNING、ERROR、CRITICAL")
    args = parser.parse_args()

    # Resolve configuration (env > CLI > defaults).
    port = int(get_config("PORT", args))
    host = get_config("HOST", args)
    llm_model = get_config("LLM_MODEL", args)
    api_base = get_config("API_BASE", args)
    tool_base_url = get_config("TOOL_BASE_URL", args)
    api_key = get_config("API_KEY", args)
    log_level_name = get_config("LOG_LEVEL", args)

    # Unknown level names fall back to INFO.
    log_level = _LOG_LEVELS.get(log_level_name.strip().upper(), logging.INFO)
    setup_logging(log_level=log_level)

    agent_manager.initialize(
        api_key=api_key,
        api_base=api_base,
        model_name=llm_model,
        tool_base_url=tool_base_url
    )

    # The API key is deliberately not printed.
    print("Starting server with configuration:")
    print(f"Host: {host}")
    print(f"Port: {port}")
    print(f"LLM Model: {llm_model}")
    print(f"API Base: {api_base}")
    print(f"Tool Base URL: {tool_base_url}")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()