from langchain_openai import ChatOpenAI from rich.console import Console from rich.table import Table import sys,os sys.path.append("../") from src.server.tool_picker import ToolPicker, ToolRunner from src.agent.tool_rate import RegionRateTool, RankingRateTool from src.agent.tool_monitor import MonitorPointTool def run_examples(): # 初始化 rich console console = Console() # 初始化 LLM llm = ChatOpenAI( openai_api_key="xxxxxx", openai_api_base="http://192.168.10.14:8000/v1", model_name="Qwen2-7B", verbose=True ) base_url = "http://172.30.0.37:30007" # 初始化工具 tools = [ RegionRateTool(base_url=base_url), RankingRateTool(base_url=base_url), MonitorPointTool(base_url=base_url), ] tools_dict = {tool.name: tool for tool in tools} # 初始化 ToolPicker picker = ToolPicker(llm, tools) # 测试案例和预期结果 test_cases = [ { "query": "请查询下甘肃省的监测点信息", "expected": { "tool": "monitor_points_query", "params": { "key": "甘肃省", "three_d_model": "无", "ortho_image": "无", } } },{ "query": "查询甘肃省滑坡的监测点信息", "expected": { "tool": "monitor_points_query", "params": { "key": "甘肃省", "disaster_type": "滑坡", } } },{ "query": "查询甘肃省有三维模型的监测点信息", "expected": { "tool": "monitor_points_query", "params": { "key": "甘肃省", "three_d_model": "有", } } },{ "query": "查询甘肃省威胁人口超过30人以上的滑坡的监测点信息", "expected": { "tool": "monitor_points_query", "params": { "key": "甘肃省", "disaster_threat_people_range_start": 30, "disaster_type": "滑坡", } } },{ "query": "甘肃省监控点的状态如何?", "expected": { "tool": "monitor_points_query", "params": { "key": "甘肃省" } } },{ "query": "贵阳市的雨量传感器有多少", "expected": { "tool": "monitor_points_query", "params": { "key": "贵阳市", "device_type": "雨量", "query_type": "2", } } } ] # 为每个测试案例创建一个表格 for i, case in enumerate(test_cases, 1): console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]") table = Table(title=f"查询: {case['query']}") table.add_column("项目", style="cyan") table.add_column("预期结果", style="green") table.add_column("实际结果", style="yellow") table.add_column("是否匹配", style="magenta") try: result = picker.pick(case["query"]) # 添加工具比较行 expected_tool = case["expected"]["tool"] actual_tool = result["tool"] table.add_row( "选择的工具", expected_tool, actual_tool, "✓" if expected_tool == actual_tool else "✗" ) # 添加参数比较行 for param_key in case["expected"]["params"]: expected_value = str(case["expected"]["params"][param_key]) actual_value = str(result["params"].get(param_key, "未提供")) table.add_row( f"参数: {param_key}", expected_value, actual_value, "✓" if expected_value == actual_value else "✗" ) # # run tool # tool = tools_dict[actual_tool] # result = tool.invoke(result["params"]) # print(result) except Exception as e: table.add_row("错误", "", str(e), "✗") console.print(table) console.print("=" * 80) if __name__ == "__main__": run_examples()