run_tool_picker_monitor.py 4.53 KB
Newer Older
1 2 3 4 5 6 7 8
from langchain_openai import ChatOpenAI
from rich.console import Console
from rich.table import Table

import sys,os
sys.path.append("../")


tinywell committed
9
from src.server.tool_picker import ToolPicker, ToolRunner
10 11 12 13 14 15 16 17 18 19 20 21 22 23
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool

def run_examples():
    # 初始化 rich console
    console = Console()
    
    # 初始化 LLM
    llm = ChatOpenAI(
        openai_api_key="xxxxxx",
        openai_api_base="http://192.168.10.14:8000/v1",
        model_name="Qwen2-7B",
        verbose=True
    )
tinywell committed
24
    base_url = "http://172.30.0.37:30007"
25 26
    # 初始化工具
    tools = [
tinywell committed
27 28 29
        RegionRateTool(base_url=base_url),
        RankingRateTool(base_url=base_url),
        MonitorPointTool(base_url=base_url),
30
    ]
tinywell committed
31 32

    tools_dict = {tool.name: tool for tool in tools}
33 34 35 36
    
    # 初始化 ToolPicker
    picker = ToolPicker(llm, tools)
    
tinywell committed
37
    
38 39 40 41 42 43 44 45
    # 测试案例和预期结果
    test_cases = [
        {
            "query": "请查询下甘肃省的监测点信息",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省",
tinywell committed
46 47
                    "three_d_model": "无",
                    "ortho_image": "无",
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
                }
            }
        },{
            "query": "查询甘肃省滑坡的监测点信息",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省",
                    "disaster_type": "滑坡",
                }
            }
        },{
            "query": "查询甘肃省有三维模型的监测点信息",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省",
tinywell committed
65
                    "three_d_model": "有",
66 67 68 69 70 71 72 73 74 75 76 77
                }
            }
        },{
            "query": "查询甘肃省威胁人口超过30人以上的滑坡的监测点信息",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "甘肃省",
                    "disaster_threat_people_range_start": 30,
                    "disaster_type": "滑坡",
                }
            }
78 79 80 81 82
        },{
            "query": "甘肃省监控点的状态如何?",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
83
                    "key": "甘肃省"
84 85
                }
            }
86 87 88 89 90 91 92 93 94 95 96
        },{
            "query": "贵阳市的雨量传感器有多少",
            "expected": {
                "tool": "monitor_points_query",
                "params": {
                    "key": "贵阳市",
                    "device_type": "雨量",
                    "query_type": "2",
                }
            }
        }
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
    ]
    
    # 为每个测试案例创建一个表格
    for i, case in enumerate(test_cases, 1):
        console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
        
        table = Table(title=f"查询: {case['query']}")
        table.add_column("项目", style="cyan")
        table.add_column("预期结果", style="green")
        table.add_column("实际结果", style="yellow")
        table.add_column("是否匹配", style="magenta")
        
        try:
            result = picker.pick(case["query"])
            
            # 添加工具比较行
            expected_tool = case["expected"]["tool"]
            actual_tool = result["tool"]
            table.add_row(
                "选择的工具",
                expected_tool,
                actual_tool,
                "✓" if expected_tool == actual_tool else "✗"
            )
            
            # 添加参数比较行
            for param_key in case["expected"]["params"]:
                expected_value = str(case["expected"]["params"][param_key])
                actual_value = str(result["params"].get(param_key, "未提供"))
                table.add_row(
                    f"参数: {param_key}",
                    expected_value,
                    actual_value,
                    "✓" if expected_value == actual_value else "✗"
                )
tinywell committed
132

133 134 135 136
            # # run tool
            # tool = tools_dict[actual_tool]
            # result = tool.invoke(result["params"])
            # print(result)
tinywell committed
137
            
138 139 140 141 142 143 144 145
        except Exception as e:
            table.add_row("错误", "", str(e), "✗")
        
        console.print(table)
        console.print("=" * 80)

if __name__ == "__main__":
    run_examples()