import json
import os
import sys
sys.path.append('../')
import argparse
from typing import Optional
from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware
import logging
import uvicorn


from src.server.agent_rate import new_rate_agent, RateAgentV3
from src.server.classify import new_router_llm
from src.server.extend_classify import new_extend_classify_llm
from src.server.rewrite import new_re_rewriter_llm
from src.controller.request import GeoAgentRateRequest
from src.utils.logger import setup_logging
from langchain_openai import ChatOpenAI

# Default configuration values.
# Each key can be overridden by a GEO_AGENT_<KEY> environment variable or a
# matching CLI flag — see get_config() for the resolution order.
DEFAULT_CONFIG = {
    "PORT": 8088,
    "HOST": "0.0.0.0",
    "LLM_MODEL": "Qwen2-7B",
    "API_BASE": "http://192.168.10.14:8000/v1",
    "TOOL_BASE_URL": "http://localhost:5001",
    "API_KEY": "xxxxxxxxxxxxx",
    "LOG_LEVEL": "INFO"
}

def get_config(key: str, args: Optional[argparse.Namespace] = None) -> str:
    """
    Resolve a configuration value.

    Precedence: environment variable > command-line argument > default.

    Args:
        key: Configuration key (case-insensitive). Looked up as
            ``GEO_AGENT_<KEY>`` in the environment and as ``<key>`` (lowercase)
            on *args*.
        args: Parsed command-line arguments, if any.

    Returns:
        str: The resolved configuration value.

    Raises:
        KeyError: If *key* is not in DEFAULT_CONFIG and no environment
            variable or CLI value supplies it.
    """
    env_key = f"GEO_AGENT_{key.upper()}"

    # 1. Environment variable takes highest priority.
    value = os.getenv(env_key)
    if value is not None:
        return value

    # 2. Fall back to the command-line argument, when one was provided.
    if args is not None:
        arg_value = getattr(args, key.lower(), None)
        if arg_value is not None:
            return str(arg_value)

    # 3. Finally, the built-in default.
    return str(DEFAULT_CONFIG[key.upper()])

# FastAPI application instance. CORS is fully open (any origin, method, and
# header, with credentials) — convenient for development, but should be
# restricted before production deployment.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class AgentManager:
    """Holds the shared LLM client and every agent/chain built from it.

    All attributes are ``None`` until :meth:`initialize` is called, so the
    getters return ``None`` instead of raising ``AttributeError`` on an
    uninitialized manager.
    """

    def __init__(self):
        self.llm = None
        self.agent = None
        self.rate_agent = None
        self.router_llm = None
        self.re_rewriter_llm = None
        self.extend_classify_llm = None

    def initialize(self, api_key: str, api_base: str, model_name: str, tool_base_url: str):
        """Create the LLM client and build all agents/chains from it.

        Args:
            api_key: Key for the OpenAI-compatible API.
            api_base: Base URL of the OpenAI-compatible API.
            model_name: Model identifier to request.
            tool_base_url: Base URL of the tool service used by the agents.
        """
        self.llm = ChatOpenAI(
            openai_api_key=api_key,
            openai_api_base=api_base,
            model_name=model_name,
            verbose=True
        )
        self.agent = new_rate_agent(self.llm, verbose=True, tool_base_url=tool_base_url)
        # NOTE(review): pinned to version "v0" — confirm this is intentional.
        self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url, version="v0")
        self.router_llm = new_router_llm(self.llm)
        self.re_rewriter_llm = new_re_rewriter_llm(self.llm)
        self.extend_classify_llm = new_extend_classify_llm(self.llm)

    def get_llm(self):
        return self.llm

    def get_agent(self):
        return self.agent

    def get_rate_agent(self):
        return self.rate_agent

    def get_router_llm(self):
        return self.router_llm

    def get_re_rewriter_llm(self):
        return self.re_rewriter_llm

    def get_extend_classify_llm(self):
        return self.extend_classify_llm

# Module-level singleton shared by all request handlers; populated in main().
agent_manager = AgentManager()

@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
    """Run the rate agent on the request query.

    Returns:
        dict: ``{'code': 200, 'data': <agent result>}`` on success, or
        ``{'code': 500, 'data': <error text>}`` on failure.
    """
    rate_agent = agent_manager.get_rate_agent()
    try:
        res = rate_agent.run(chat_request.query)
    except Exception as e:
        # Boundary handler: report the failure in the response envelope
        # instead of letting the exception propagate to the framework.
        logging.exception("处理请求失败, 错误信息: %s", e)
        return {
            'code': 500,
            'data': str(e)
        }
    return {
        'code': 200,
        'data': res
    }

@app.post('/api/classify')
def classify(chat_request: GeoAgentRateRequest):
    """Rewrite the query using chat history, then classify it to a datasource.

    Pipeline: extend history (if any) -> rewrite query -> extend-classify.
    If the extend-classifier answers "yes", routing is skipped and the
    datasource is "none"; otherwise the router LLM picks the datasource.

    Returns:
        dict: ``{'code': 200, 'data': {'datasource': ..., 'rewrite': ...}}``
        on success, or ``{'code': 500, 'data': <error text>}`` on a missing
        query or any pipeline failure.
    """
    llm = agent_manager.get_router_llm()
    re_llm = agent_manager.get_re_rewriter_llm()
    extend_llm = agent_manager.get_extend_classify_llm()
    # Validate input before entering the pipeline try-block.
    if chat_request.query is None:
        return {
            'code': 500,
            'data': "缺少必要的参数 query"
        }
    try:
        if chat_request.history is None:
            history = ""
        else:
            history = re_llm.extend_history(history=chat_request.history)
        rewrite = re_llm.invoke(chat_request.query, history)
        extend = extend_llm.invoke(rewrite)
        if extend.classify == "yes":
            # Query needs no datasource routing.
            return {
                'code': 200,
                'data': {
                    "datasource": "none",
                    "rewrite": rewrite
                }
            }
        answer = llm.invoke(rewrite)
        res = {
            "datasource": answer.datasource,
            "rewrite": rewrite
        }
    except Exception as e:
        # Boundary handler: surface the error in the response envelope.
        logging.exception("分类失败, 错误信息: %s", e)
        return {
            'code': 500,
            'data': str(e)
        }
    return {
        'code': 200,
        'data': res
    }

def main():
    """Parse CLI arguments, configure logging, build the agents, and serve the API."""
    parser = argparse.ArgumentParser(description="启动API服务")
    parser.add_argument("--port", type=int, help="API服务端口")
    parser.add_argument("--host", type=str, help="API服务地址")
    # BUGFIX: get_config("LLM_MODEL", args) reads args.llm_model, but argparse
    # stored "--llm" as args.llm, so the flag was silently ignored. dest=
    # keeps the "--llm" flag while storing it under the expected name.
    parser.add_argument("--llm", dest="llm_model", type=str, help="LLM模型名称")
    parser.add_argument("--api_base", type=str, help="OpenAI API基础地址")
    parser.add_argument("--tool_base_url", type=str, help="工具服务基础地址")
    parser.add_argument("--api_key", type=str, help="OpenAI API密钥")
    parser.add_argument("--log_level", type=str, help="日志级别,DEBUG、INFO、WARNING、ERROR、CRITICAL")

    args = parser.parse_args()

    # Resolve configuration (env var > CLI arg > default).
    port = int(get_config("PORT", args))
    host = get_config("HOST", args)
    llm_model = get_config("LLM_MODEL", args)
    api_base = get_config("API_BASE", args)
    tool_base_url = get_config("TOOL_BASE_URL", args)
    api_key = get_config("API_KEY", args)
    log_level_name = get_config("LOG_LEVEL", args)

    # Map the level name to a logging constant; unknown names fall back to
    # INFO, matching the original if/elif chain. get_config never returns
    # None, so logging is always configured.
    level = getattr(logging, log_level_name.upper(), None)
    if not isinstance(level, int):
        level = logging.INFO
    setup_logging(log_level=level)

    # Build the shared LLM/agent instances used by the endpoints.
    agent_manager.initialize(
        api_key=api_key,
        api_base=api_base,
        model_name=llm_model,
        tool_base_url=tool_base_url
    )

    print("Starting server with configuration:")
    print(f"Host: {host}")
    print(f"Port: {port}")
    print(f"LLM Model: {llm_model}")
    print(f"API Base: {api_base}")
    print(f"Tool Base URL: {tool_base_url}")

    uvicorn.run(app, host=host, port=port)

# Script entry point: start the API server.
if __name__ == "__main__":
    main()