from langchain_openai import ChatOpenAI
from rich.console import Console
from rich.table import Table

import sys,os
sys.path.append("../")


from src.server.tool_picker import ToolPicker, ToolRunner
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool

def run_examples():
    # 初始化 rich console
    console = Console()
    
    # 初始化 LLM
    llm = ChatOpenAI(
        openai_api_key="xxxxxx",
        openai_api_base="http://192.168.10.14:8000/v1",
        model_name="Qwen2-7B",
        verbose=True
    )
    base_url = "http://172.30.0.37:30007"
    # 初始化工具
    tools = [
        RegionRateTool(base_url=base_url),
        RankingRateTool(base_url=base_url),
        MonitorPointTool(base_url=base_url),
    ]

    tools_dict = {tool.name: tool for tool in tools}
    
    # 初始化 ToolPicker
    picker = ToolPicker(llm, tools)
    
    
    # 测试案例和预期结果
    test_cases = [
        {
            "query": "请查询下甘肃省的监测点信息",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省",
                    "three_d_model": "无",
                    "ortho_image": "无",
                }
            }
        },{
            "query": "查询甘肃省滑坡的监测点信息",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省",
                    "disaster_type": "滑坡",
                }
            }
        },{
            "query": "查询甘肃省有三维模型的监测点信息",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省",
                    "three_d_model": "有",
                }
            }
        },{
            "query": "查询甘肃省威胁人口超过30人以上的滑坡的监测点信息",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省",
                    "disaster_threat_people_range_start": 30,
                    "disaster_type": "滑坡",
                }
            }
        },{
            "query": "甘肃省监控点的状态如何?",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省"
                }
            }
        },{
            "query": "贵阳市的雨量传感器有多少",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "贵阳市",
                    "device_type": "雨量",
                    "query_type": "2",
                }
            }
        }
    ]
    
    # 为每个测试案例创建一个表格
    for i, case in enumerate(test_cases, 1):
        console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
        
        table = Table(title=f"查询: {case['query']}")
        table.add_column("项目", style="cyan")
        table.add_column("预期结果", style="green")
        table.add_column("实际结果", style="yellow")
        table.add_column("是否匹配", style="magenta")
        
        try:
            result = picker.pick(case["query"])
            
            # 添加工具比较行
            expected_tool = case["expected"]["tool"]
            actual_tool = result["tool"]
            table.add_row(
                "选择的工具",
                expected_tool,
                actual_tool,
                "✓" if expected_tool == actual_tool else "✗"
            )
            
            # 添加参数比较行
            for param_key in case["expected"]["params"]:
                expected_value = str(case["expected"]["params"][param_key])
                actual_value = str(result["params"].get(param_key, "未提供"))
                table.add_row(
                    f"参数: {param_key}",
                    expected_value,
                    actual_value,
                    "✓" if expected_value == actual_value else "✗"
                )

            # # run tool
            # tool = tools_dict[actual_tool]
            # result = tool.invoke(result["params"])
            # print(result)
            
        except Exception as e:
            table.add_row("错误", "", str(e), "✗")
        
        console.print(table)
        console.print("=" * 80)

if __name__ == "__main__":
    run_examples()