Commit 625c5230 by tinywell

优化 warn 工具描述,降低与在线率工具混淆

parent abcc91a0
...@@ -22,12 +22,12 @@ class WarningArgs(BaseModel): ...@@ -22,12 +22,12 @@ class WarningArgs(BaseModel):
start_time: str = Field("", description="开始时间 (YYYY-MM-DD HH:mm:ss)") start_time: str = Field("", description="开始时间 (YYYY-MM-DD HH:mm:ss)")
end_time: str = Field("", description="结束时间 (YYYY-MM-DD HH:mm:ss)") end_time: str = Field("", description="结束时间 (YYYY-MM-DD HH:mm:ss)")
region_name: str = Field("", description="地区名称,如果要查询全国数据,请输入空字符串") region_name: str = Field("", description="地区名称,如果要查询全国数据,请输入空字符串")
query_type: str = Field("1", description="查询类型,1 表示查询一段时间内的综合数据,2 表示查询指定年份全年按月度统计的虚警率,处置率等信息") query_type: str = Field("1", description="查询类型,1 表示查询一段时间内的综合数据,2 表示查询指定年份全年按月度统计的虚警、处置等信息")
class WarningTool(BaseTool): class WarningTool(BaseTool):
"""查询预警处置和虚警情况""" """查询预警处置和虚警情况"""
name: str = "warning_statistics" name: str = "warning_statistics"
description: str = "查询一定时间范围内不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持查询指定年份全年按月度统计的虚警率,处置率等信息。" description: str = "查询一定时间范围内不同地区预警处置和虚警情况,包括处置情况、虚警情况、蓝黄橙红数数量等统计。也支持查询指定年份全年按月度统计的虚警情况,处置情况等信息。"
args_schema: Type[BaseModel] = WarningArgs args_schema: Type[BaseModel] = WarningArgs
client: Any = Field(None, exclude=True) client: Any = Field(None, exclude=True)
logger: logging.Logger = Field(None, exclude=True) logger: logging.Logger = Field(None, exclude=True)
......
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from rich.console import Console from rich.console import Console
from rich.table import Table from rich.table import Table
import time
import sys,os import sys,os
sys.path.append("../") sys.path.append("../")
from src.server.tool_picker import ToolPicker from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool, RankingRateTool from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool from src.agent.tool_monitor import MonitorPointTool
from src.agent.tool_warn import WarningTool
def run_examples(): def create_tool_picker():
# 初始化 rich console
console = Console()
# 初始化 LLM
llm = ChatOpenAI( llm = ChatOpenAI(
openai_api_key="xxxxxx", openai_api_key="xxxxxx",
openai_api_base="http://192.168.10.14:8000/v1", openai_api_base="http://192.168.10.14:8000/v1",
model_name="Qwen2-7B", model_name="Qwen2-7B",
verbose=True verbose=True
) )
base_url = "http://172.30.0.37:30007"
# 初始化工具
tools = [
RegionRateTool(base_url=base_url),
RankingRateTool(base_url=base_url),
MonitorPointTool(base_url=base_url),
WarningTool(base_url=base_url),
]
return ToolPicker(llm, tools)
def run_examples():
# 初始化 rich console
console = Console()
# 初始化 LLM
base_url = "http://172.30.0.37:30007" base_url = "http://172.30.0.37:30007"
# 初始化工具 # 初始化工具
tools = [ tools = [
RegionRateTool(base_url=base_url), RegionRateTool(base_url=base_url),
RankingRateTool(base_url=base_url), RankingRateTool(base_url=base_url),
MonitorPointTool(base_url=base_url), MonitorPointTool(base_url=base_url),
WarningTool(base_url=base_url),
] ]
tool_dict = {tool.name: tool for tool in tools} tool_dict = {tool.name: tool for tool in tools}
...@@ -193,9 +210,20 @@ def run_examples(): ...@@ -193,9 +210,20 @@ def run_examples():
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_statistics": True "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_statistics": True
} }
} }
},
{
"query": "贵阳市各县区设备在线率是多少",
"expected": {
"tool": "region_online_rate",
"params": {
"region_name": "贵阳市"
}
}
} }
] ]
tool_selected_success = 0
# 为每个测试案例创建一个表格 # 为每个测试案例创建一个表格
for i, case in enumerate(test_cases, 1): for i, case in enumerate(test_cases, 1):
console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]") console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
...@@ -219,6 +247,9 @@ def run_examples(): ...@@ -219,6 +247,9 @@ def run_examples():
"✓" if expected_tool == actual_tool else "✗" "✓" if expected_tool == actual_tool else "✗"
) )
if expected_tool == actual_tool:
tool_selected_success += 1
# 添加参数比较行 # 添加参数比较行
for param_key in case["expected"]["params"]: for param_key in case["expected"]["params"]:
expected_value = str(case["expected"]["params"][param_key]) expected_value = str(case["expected"]["params"][param_key])
...@@ -233,8 +264,8 @@ def run_examples(): ...@@ -233,8 +264,8 @@ def run_examples():
tool = tool_dict[result["tool"]] tool = tool_dict[result["tool"]]
params = result["params"] params = result["params"]
result = tool.invoke(params) # result = tool.invoke(params)
print(result) # print(result)
except Exception as e: except Exception as e:
table.add_row("错误", "", str(e), "✗") table.add_row("错误", "", str(e), "✗")
...@@ -242,5 +273,44 @@ def run_examples(): ...@@ -242,5 +273,44 @@ def run_examples():
console.print(table) console.print(table)
console.print("=" * 80) console.print("=" * 80)
print(f"工具选择成功率: {tool_selected_success} / {len(test_cases)} = {tool_selected_success / len(test_cases) * 100:.2f}%")
def run_case_17():
    """Run one fixed tool-selection query repeatedly and report the results.

    The same query is sent ``test_cases_count`` times to the picker returned
    by ``create_tool_picker()``. A rich table is printed for every attempt,
    showing the expected tool, the tool actually picked, whether they match,
    and the elapsed time. After the loop the success rate, the average
    latency and the total latency are printed.
    """
    console = Console()
    picker = create_tool_picker()

    test_cases_count = 50

    # The query and its expected outcome are identical for every attempt, so
    # they are built once rather than on each iteration.
    query = "贵阳市各县区设备在线率是多少"
    expected = {
        "tool": "region_online_rate",
        "params": {
            "region_name": "贵阳市"
        }
    }

    success_count = 0
    total_time = 0.0

    for i in range(test_cases_count):
        # perf_counter() is monotonic: a wall-clock adjustment during the run
        # cannot skew the measured duration or make it negative.
        start_time = time.perf_counter()
        result = picker.pick(query)
        exec_time = time.perf_counter() - start_time
        total_time += exec_time

        # Compare once and reuse the outcome for both the table and the tally.
        matched = result["tool"] == expected["tool"]
        if matched:
            success_count += 1

        table = Table(title=f"第 {i+1:02d} 次测试结果 【{query}】")
        table.add_column("查询", style="cyan")
        table.add_column("预期工具", style="green")
        table.add_column("实际工具", style="yellow")
        table.add_column("是否匹配", style="magenta")
        table.add_column("耗时", style="magenta")
        table.add_row(
            query,
            expected["tool"],
            result["tool"],
            "✓" if matched else "✗",
            f"{exec_time:.2f}秒",
        )
        console.print(table)

    print(f"工具选择成功率: {success_count} / {test_cases_count} = {success_count / test_cases_count * 100:.2f}%")
    print(f"平均耗时: {total_time / test_cases_count:.2f}秒 总共耗时: {total_time:.2f}秒")
if __name__ == "__main__": if __name__ == "__main__":
run_examples() # run_examples()
\ No newline at end of file run_case_17()
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from rich.console import Console from rich.console import Console
from rich.table import Table from rich.table import Table
import time
import sys,os import sys,os
sys.path.append("../") sys.path.append("../")
...@@ -79,9 +79,23 @@ def run_examples(): ...@@ -79,9 +79,23 @@ def run_examples():
"query_type": "1", "query_type": "1",
} }
} }
},
{
"query": "2023 年,贵阳市每月处置率是多少",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2023-01-01 00:00:00",
"end_time": "2023-12-31 23:59:59",
"region_name": "贵阳市",
"query_type": "2",
}
}
} }
] ]
success_count = 0
total_time = 0
# 为每个测试案例创建一个表格 # 为每个测试案例创建一个表格
for i, case in enumerate(test_cases, 1): for i, case in enumerate(test_cases, 1):
console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]") console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
...@@ -91,9 +105,13 @@ def run_examples(): ...@@ -91,9 +105,13 @@ def run_examples():
table.add_column("预期结果", style="green") table.add_column("预期结果", style="green")
table.add_column("实际结果", style="yellow") table.add_column("实际结果", style="yellow")
table.add_column("是否匹配", style="magenta") table.add_column("是否匹配", style="magenta")
table.add_column("耗时", style="magenta")
try: try:
start_time = time.time()
result = picker.pick(case["query"]) result = picker.pick(case["query"])
end_time = time.time()
exec_time = end_time - start_time
total_time += exec_time
# 添加工具比较行 # 添加工具比较行
expected_tool = case["expected"]["tool"] expected_tool = case["expected"]["tool"]
...@@ -115,6 +133,10 @@ def run_examples(): ...@@ -115,6 +133,10 @@ def run_examples():
actual_value, actual_value,
"✓" if expected_value == actual_value else "✗" "✓" if expected_value == actual_value else "✗"
) )
if expected_tool == actual_tool:
success_count += 1
table.add_row("耗时", "", f"{exec_time:.2f}秒", "")
# tool = tool_dict[result["tool"]] # tool = tool_dict[result["tool"]]
# params = result["params"] # params = result["params"]
...@@ -128,5 +150,7 @@ def run_examples(): ...@@ -128,5 +150,7 @@ def run_examples():
console.print(table) console.print(table)
console.print("=" * 80) console.print("=" * 80)
print(f"工具选择成功率: {success_count} / {len(test_cases)} = {success_count / len(test_cases) * 100:.2f}%")
print(f"平均耗时: {total_time / len(test_cases):.2f}秒 总共耗时: {total_time:.2f}秒")
if __name__ == "__main__": if __name__ == "__main__":
run_examples() run_examples()
\ No newline at end of file
import requests
# Natural-language questions that are sent, in order, to the endpoint below.
questions = [
    "贵阳市各县区设备在线率是多少",
    "贵阳市各个县区 2024 年 11 月 2 日设备在线率是多少",
    "各省实时设备在线率排名",
    "各省实时设备在线率排名中,前 5名是哪些省",
    "各省实时设备在线率排名中,最后 5 名是哪些省",
    "各厂商设备在线率排名中,排名前 10的厂商",
    "在厂商设备在线率排名中,上海华测设备在线率排名是多少",
    "贵阳市的雨量传感器数量有多少 ",
    "贵阳市的监测点数量是多少",
    "贵阳市有三维模型的监测点数量是多少",
    "今年 ,贵阳市每月的设备在线率是多少?(注:如果需要按“每月”维度展示,问题中需添加 “每月” 字段,增加问题区分度)",
    "2024 年 贵阳市每月的滑坡仪在线率是多少",
    "贵阳市威胁人数在 200 人以上的滑坡监测点数量有多少",
    "2024 年 5 月 20 日到 2024 年 11 月 27 日,贵阳市黄色预警消息占比是多少;",
    "2024 年 10 月 1日到 2024 年 10 月31日,贵阳市处置率是多少",
    "2024 年5 月到 2024 年11 月,贵阳市每月处置率是多少",
    "2023 年,贵阳市每月处置率是多少",
    "贵阳市预警等级为红色的滑坡监测点数量有多少"
]

# Local HTTP endpoint that every question is POSTed to.
url = "http://localhost:8088/api/agent/rate"

# Visual divider printed after each response.
separator = "--------------------------------" * 5

for question in questions:
    print(question)
    # Each question travels as the "query" field of a JSON request body.
    payload = {"query": question}
    http_response = requests.post(url, json=payload)
    print(http_response.json())
    print(separator)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment