Commit 7ded04d9 by tinywell

工具选择及参数提取测试;补充预警信息获取工具;

parent d1bc8007
...@@ -11,13 +11,17 @@ const_base_url = "http://localhost:5001" ...@@ -11,13 +11,17 @@ const_base_url = "http://localhost:5001"
const_url_point = "/cigem/getMonitorPointAll" const_url_point = "/cigem/getMonitorPointAll"
const_url_rate = "/cigem/getAvgOnlineRate" const_url_rate = "/cigem/getAvgOnlineRate"
const_url_rate_ranking = "/cigem/getOnlineRateRank" const_url_rate_ranking = "/cigem/getOnlineRateRank"
const_url_rate_month = "/cigem/getAvgOnlineRateOfMonth"
const_url_device_list = "/cigem/getMonitorDeviceList"
const_url_warning = "/cigem/getWarningStatistics"
class BaseResponse(BaseModel, Generic[T]): class BaseResponse(BaseModel, Generic[T]):
"""通用响应模型""" """通用响应模型"""
type: int type: int
resultcode: int resultcode: int
message: str message: str
resultdata: T resultdata: Optional[T] = None
otherinfo: Optional[str] = None
class BaseHttpClient: class BaseHttpClient:
"""基础HTTP客户端""" """基础HTTP客户端"""
...@@ -52,35 +56,75 @@ class MonitorPoint(BaseModel): ...@@ -52,35 +56,75 @@ class MonitorPoint(BaseModel):
ELEVATION: str ELEVATION: str
BUILDUNIT: str BUILDUNIT: str
MONITORUNIT: str MONITORUNIT: str
YWUNIT: str YWUNIT: Optional[str] = None
SGDW: Optional[str] = None SGDW: Optional[str] = None
MANUFACTURER: str = "" MANUFACTURER: Optional[str] = None
MONITORTYPE: str MONITORTYPE: Optional[str] = None
class Sensor(BaseModel):
"""传感器数据模型"""
SENSORCODE: str
LOCATION: Optional[str] = None
LATITUDE: Optional[str] = None
LONGITUDE: Optional[str] = None
DEVICETYPENAME: Optional[str] = None
MANUFACTURER: Optional[str] = None
class Device(BaseModel):
"""设备数据模型"""
DEVICECODE: str
SN: str
LOCATION: Optional[str] = None
LATITUDE: Optional[str] = None
LONGITUDE: Optional[str] = None
DEVICETYPENAME: Optional[str] = None
MANUFACTURER: Optional[str] = None
MONITORPOINTNAME: Optional[str] = None
SensorList: List[Sensor]
class MonitorClient(BaseHttpClient): class MonitorClient(BaseHttpClient):
"""监测点查询客户端""" """监测点查询客户端"""
async def query_points(self, key: str) -> BaseResponse[List[MonitorPoint]]: async def query_points(self, key: str) -> BaseResponse[List]:
"""异步查询监测点信息""" """异步查询监测点信息"""
data = await self._request_async( data = await self._request_async(
"POST", "POST",
const_url_point, const_url_point,
json={"key": key} json={"key": key}
) )
return BaseResponse[List[MonitorPoint]](**data) return BaseResponse[List](**data)
def query_points_sync(self, key: str) -> BaseResponse[List[MonitorPoint]]: def query_points_sync(self, key: str) -> BaseResponse[List]:
"""同步查询监测点信息""" """同步查询监测点信息"""
data = self._request_sync( data = self._request_sync(
"POST", "POST",
const_url_point, const_url_point,
json={"key": key} json={"key": key}
) )
return BaseResponse[List[MonitorPoint]](**data) return BaseResponse[List](**data)
def query_device_list(self, code: str) -> BaseResponse[List]:
"""同步查询设备列表"""
data = self._request_sync(
"POST",
const_url_device_list,
json={"monitorpointcode": code}
)
print(data)
return BaseResponse[List](**data)
async def query_device_list_async(self, code: str) -> BaseResponse[List]:
"""异步查询设备列表"""
data = await self._request_async(
"POST",
const_url_device_list,
json={"monitorpointcode": code}
)
return BaseResponse[List](**data)
# 示例:添加新的数据接口客户端 # 示例:添加新的数据接口客户端
class RateClient(BaseHttpClient): class RateClient(BaseHttpClient):
"""在线率查询客户端""" """在线率查询客户端"""
async def query_rates(self, areacode: str, startDate: str, endDate: str) -> BaseResponse[Dict]: async def query_rates(self, areacode: str, startDate: str, endDate: str) -> BaseResponse[List]:
"""异步查询在线率信息""" """异步查询在线率信息"""
data = await self._request_async( data = await self._request_async(
"POST", "POST",
...@@ -91,7 +135,7 @@ class RateClient(BaseHttpClient): ...@@ -91,7 +135,7 @@ class RateClient(BaseHttpClient):
'endDate': endDate 'endDate': endDate
} }
) )
return BaseResponse[Dict](**data) return BaseResponse[List](**data)
def query_rates_sync(self, areacode: str, startDate: str, endDate: str) -> BaseResponse[List]: def query_rates_sync(self, areacode: str, startDate: str, endDate: str) -> BaseResponse[List]:
"""同步查询在线率信息""" """同步查询在线率信息"""
...@@ -124,28 +168,43 @@ class RateClient(BaseHttpClient): ...@@ -124,28 +168,43 @@ class RateClient(BaseHttpClient):
) )
return BaseResponse[List](**data) return BaseResponse[List](**data)
# 使用示例 def query_rates_month_sync(self, year: str, areaCode: str, typeArr: List[int]) -> BaseResponse[List]:
async def example_async_usage(): """同步查询按月度统计的在线率信息"""
# 监测点查询示例 data = self._request_sync(
monitor_client = MonitorClient() "POST",
try: const_url_rate_month,
response = await monitor_client.query_points("湖南") json={
if response.resultcode == 1 and response.resultdata: 'year': year,
for point in response.resultdata: 'areaCode': areaCode,
print(f"监测点名称: {point.MONITORPOINTNAME}") 'typeArr': typeArr
print(f"位置: {point.LOCATION}") }
print("---") )
except httpx.HTTPError as e: return BaseResponse[List](**data)
print(f"HTTP请求错误: {e}")
# 在线率查询示例 class WarningClient(BaseHttpClient):
rate_client = RateClient() """预警查询客户端"""
try: def query_warning_statistics(self, startDate: str, endDate: str, areaCode: str) -> BaseResponse[List]:
response = await rate_client.query_rates("520100", "2024-01-01", "2024-01-31") """同步查询预警统计信息"""
if response.resultcode == 1: data = self._request_sync(
print(f"在线率数据: {response.resultdata}") "POST",
except httpx.HTTPError as e: const_url_warning,
print(f"HTTP请求错误: {e}") json={
'startDate': startDate,
if __name__ == "__main__": 'endDate': endDate,
asyncio.run(example_async_usage()) 'areaCode': areaCode
}
)
return BaseResponse[List](**data)
def query_warning_month_statistics(self, year: str, areaCode: str) -> BaseResponse[List]:
"""同步查询按月度统计的预警统计信息"""
data = self._request_sync(
"POST",
const_url_warning,
json={
'year': year,
'areaCode': areaCode
}
)
return BaseResponse[List](**data)
\ No newline at end of file
...@@ -14,15 +14,24 @@ class MonitorPointResponse(): ...@@ -14,15 +14,24 @@ class MonitorPointResponse():
class MonitorPointArgs(BaseModel): class MonitorPointArgs(BaseModel):
"""监测点查询参数""" """监测点查询参数"""
key: str = Field(..., description="行政区划名称(省/市/区县级别均可,只需要最后一级,如岳麓区,不需要长沙市)") key: str = Field(..., description="行政区划名称(省/市级别均可,只需要最后一级,如长沙市,不需要湖南省)")
year: int = Field(None, description="年份,未提及则为今年")
disaster_type: str = Field(None, description="灾害类型,如崩塌、滑坡、泥石流、地面塌陷、地面沉降、地裂缝等,未提及则为空")
three_d_model: bool = Field(False, description="是否需要三维模型,默认不需要")
ortho_image: bool = Field(False, description="是否需要正射影像,默认不需要")
disaster_threat_people_range_start: int = Field(None, description="灾害威胁人数范围起始值,如100,未提及则为空")
disaster_threat_people_range_end: int = Field(None, description="灾害威胁人数范围结束值,如200,未提及则为空")
disaster_scale: str = Field(None, description="灾害规模,灾害为崩塌、滑坡、泥石流时表示体积,灾害为地面塌陷、地面沉降时表示面积,为地裂缝时表示长度,未提及则为空")
device_type: str = Field(None, description="设备类型(例如 加速度、位移、温度、湿度、裂缝计等),默认为空")
class MonitorPointTool(BaseTool): class MonitorPointTool(BaseTool):
"""查询监测点信息的工具""" """查询监测点信息的工具"""
name:str = "monitor_points_query" name:str = "monitor_points_query"
description:str = """查询指定行政区划的监测点信息。 description:str = """查询指定行政区划的监测点信息。
可以查询任意省/市/区县级别的监测点数据。 可以查询任意省/市/区县级别的监测点数据,也可以通过灾害类型、灾害规模、灾害威胁人数范围、设备类型等条件查询
输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。 输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。
返回该区域内的监测点列表,包含位置、经纬度等详细信息。 返回该区域内的监测点列表,包含位置、经纬度等详细信息。
还可以查询监测点下相关监测设备信息,比如设备数量等。
""" """
args_schema: Type[BaseModel] = MonitorPointArgs args_schema: Type[BaseModel] = MonitorPointArgs
client: Any = Field(None, exclude=True) client: Any = Field(None, exclude=True)
...@@ -40,18 +49,37 @@ class MonitorPointTool(BaseTool): ...@@ -40,18 +49,37 @@ class MonitorPointTool(BaseTool):
self.client = MonitorClient(base_url=base_url) self.client = MonitorClient(base_url=base_url)
self.logger = get_logger("MonitorPointTool") self.logger = get_logger("MonitorPointTool")
def _run(self, key: str) -> Dict[str, Any]: def _run(self, key: str, device_required: bool = False, device_type: str = None) -> Dict[str, Any]:
""" """
执行监测点查询 执行监测点查询
Args: Args:
key: 行政区划名称 key: 行政区划名称
year: 年份
disaster_type: 灾害类型
three_d_model: 是否需要三维模型
ortho_image: 是否需要正射影像
disaster_threat_people_range_start: 灾害威胁人数范围起始值
disaster_threat_people_range_end: 灾害威胁人数范围结束值
disaster_scale: 灾害规模
device_required: 是否需要设备相关信息
device_type: 设备类型
Returns: Returns:
Dict: 包含查询结果的字典 Dict: 包含查询结果的字典
""" """
try: try:
self.logger.info(f"开始查询监测点信息,区域: {key}") self.logger.info(f"开始查询监测点信息,区域: {key}")
code = ""
if key != "":
self.logger.debug(f"查找区域代码: {key}")
codes = code_tool.find_code(key)
if codes is None or len(codes) == 0:
error_msg = f'未找到匹配的区域代码: {key}'
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
code = codes[0][1]
self.logger.debug(f"找到区域代码: {code}")
response = self.client.query_points_sync(key) response = self.client.query_points_sync(key)
self.logger.debug(f"API响应: {response}") self.logger.debug(f"API响应: {response}")
...@@ -62,7 +90,7 @@ class MonitorPointTool(BaseTool): ...@@ -62,7 +90,7 @@ class MonitorPointTool(BaseTool):
'code': 400, 'code': 400,
'message': error_msg 'message': error_msg
} }
# 提取关键信息并格式化 # 提取关键信息并格式化
points_info = [] points_info = []
for point in response.resultdata: for point in response.resultdata:
...@@ -99,19 +127,6 @@ class MonitorPointTool(BaseTool): ...@@ -99,19 +127,6 @@ class MonitorPointTool(BaseTool):
'message': error_msg 'message': error_msg
} }
def _arun(self, key: str) -> Dict[str, Any]:
"""
异步执行监测点查询
Args:
key: 行政区划名称
Returns:
Dict: 包含查询结果的字典
"""
self.logger.warning("异步查询方法未实现")
raise NotImplementedError("异步查询暂未实现")
def to_markdown(self, data: List[Dict[str, Any]]) -> str: def to_markdown(self, data: List[Dict[str, Any]]) -> str:
"""将数据转换为 markdown 表格""" """将数据转换为 markdown 表格"""
self.logger.debug("开始生成 markdown 表格") self.logger.debug("开始生成 markdown 表格")
......
...@@ -69,6 +69,9 @@ class RegionRateArgs(BaseModel): ...@@ -69,6 +69,9 @@ class RegionRateArgs(BaseModel):
region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串") region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串")
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)") start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)") end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
month_statistics: bool = Field(False, description="是否按月度查看一年内的在线率统计结果,默认不需要")
device_items: str = Field(None, description="监测项集合,多个逗号隔开,支持:滑坡仪、设备、传感器、雨量、地表位移、裂缝、倾角、加速度、土壤含水率、泥水位,默认为空")
manufacturer_name: str = Field(None, description="设备厂商,默认为空")
class RegionRateTool(BaseRateTool): class RegionRateTool(BaseRateTool):
"""查询全国或者特定地区设备在线率的工具""" """查询全国或者特定地区设备在线率的工具"""
...@@ -82,12 +85,12 @@ class RegionRateTool(BaseRateTool): ...@@ -82,12 +85,12 @@ class RegionRateTool(BaseRateTool):
self.client = RateClient(base_url=base_url) self.client = RateClient(base_url=base_url)
self.logger.info(f"初始化 RegionRateTool,base_url: {base_url}") self.logger.info(f"初始化 RegionRateTool,base_url: {base_url}")
def _run(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]: def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str=None, manufacturer_name: str=None) -> Dict[str, Any]:
return self.get_region_online_rate(start_time, end_time, region_name) return self.get_region_online_rate(start_time, end_time, region_name, month_statistics, device_items, manufacturer_name)
def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]: def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str=None, manufacturer_name: str=None) -> Dict[str, Any]:
agent_start = time.time() agent_start = time.time()
self.logger.info(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}") self.logger.info(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}, 厂商: {manufacturer_name}, 监测项: {device_items}, 按月度查询统计: {month_statistics}")
code = "" code = ""
if region_name != "": if region_name != "":
...@@ -100,44 +103,58 @@ class RegionRateTool(BaseRateTool): ...@@ -100,44 +103,58 @@ class RegionRateTool(BaseRateTool):
code = codes[0][1] code = codes[0][1]
self.logger.debug(f"找到区域代码: {code}") self.logger.debug(f"找到区域代码: {code}")
try: if month_statistics: # 按月度查询一年内的统计结果
start = time.time() year = start_time.split('-')[0]
df = self.client.query_rates_sync(code, start_time, end_time) self.logger.debug(f"按月度查询,年份: {year}")
query_time = time.time() - start
self.logger.debug(f"API调用耗时: {query_time:.2f}秒") else:
try:
if df.type != 1 or df.resultdata is None or len(df.resultdata) == 0: start = time.time()
error_msg = f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限' df = self.client.query_rates_sync(code, start_time, end_time)
self.logger.warning(error_msg) query_time = time.time() - start
self.logger.debug(f"API调用耗时: {query_time:.2f}秒")
if df.type != 1 or df.resultdata is None or len(df.resultdata) == 0:
error_msg = f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg} return {'code': 400, 'message': error_msg}
self.logger.debug(f"查询结果: {df.resultdata}") self.logger.debug(f"查询结果: {df.resultdata}")
markdown = self.to_markdown(df.resultdata) markdown = self.to_markdown(df.resultdata)
data = { data = {
'region': region_name, 'region': region_name,
'region_code': code, 'region_code': code,
'rate_data': df.resultdata, 'rate_data': df.resultdata,
'markdown': markdown, 'markdown': markdown,
'date_range': { 'date_range': {
'start': start_time, 'start': start_time,
'end': end_time 'end': end_time
}
} }
}
total_time = time.time() - agent_start total_time = time.time() - agent_start
self.logger.info(f"查询完成,总耗时: {total_time:.2f}秒") self.logger.info(f"查询完成 {region_name}(code: {code}) 在线率,总耗时: {total_time:.2f}秒")
return data return data
except Exception as e: except Exception as e:
self.logger.error(f"查询失败: {str(e)}", exc_info=True) self.logger.error(f"查询失败: {str(e)}", exc_info=True)
raise raise
def to_llm(self, region_name: str,start_time: str, end_time: str, data: Dict[str, Any]) -> str:
"""将数据转换为 LLM 可理解的格式"""
self.logger.debug("开始将数据转换为 LLM 可理解的格式")
llm_data = f"共找到{len(data)}条数据,在{start_time}至{end_time}期间,{region_name}的在线率数据如下:\n"
for index, item in enumerate(data):
llm_data += f"{index+1}. {item['name']} 的在线率为 {item['rate']}\n"
return llm_data
def to_markdown(self, data: List[Dict[str, Any]]) -> str: def to_markdown(self, data: List[Dict[str, Any]]) -> str:
"""将数据转换为 markdown 表格""" """将数据转换为 markdown 表格"""
self.logger.debug("开始生成 markdown 表格") self.logger.debug("开始生成 markdown 表格")
markdown = """ markdown = """
| 序号 | 日期 | 在线率 | | 序号 | 地区 | 在线率 |
| --- | --- | --- | | --- | --- | --- |
""" """
for index, row in enumerate(data): for index, row in enumerate(data):
...@@ -145,8 +162,6 @@ class RegionRateTool(BaseRateTool): ...@@ -145,8 +162,6 @@ class RegionRateTool(BaseRateTool):
self.logger.debug("markdown 表格生成完成") self.logger.debug("markdown 表格生成完成")
return markdown return markdown
class RankingRateArgs(BaseModel): class RankingRateArgs(BaseModel):
"""排名查询参数""" """排名查询参数"""
rate_type: int = Field(..., description="排序类型,用于指定查询的排名类别。1表示省份排名,2表示厂商排名") rate_type: int = Field(..., description="排序类型,用于指定查询的排名类别。1表示省份排名,2表示厂商排名")
......
"""
不同时间段不同地区预警处置和虚警情况
1.地区处置率数量和占比统计
2.地区虚警率数量和占比统计
3.地区蓝黄橙红数数量和占比统计
"""
from pydantic import BaseModel, Field
from typing import Any, Dict, Type
from .http_tools import WarningClient, const_base_url
class WarningArgs(BaseModel):
"""预警查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串")
month_statistics: bool = Field(False, description="是否按月度查询一年内的统计结果,默认不需要")
class WarningTool(BaseTool):
"""查询预警处置和虚警情况"""
name: str = "warning_statistics"
description: str = "查询不同时间段不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持按月度统计一年内的虚警率,处置率。"
args_schema: Type[BaseModel] = WarningArgs
client: Any = Field(None, exclude=True)
def __init__(self, base_url: str = const_base_url, **data):
super().__init__(**data)
self.client = WarningClient(base_url=base_url)
self.logger.info(f"初始化 WarningTool,base_url: {base_url}")
def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False) -> Dict[str, Any]:
return self.get_warning_statistics(start_time, end_time, region_name, month_statistics)
def get_warning_statistics(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False) -> Dict[str, Any]:
pass
\ No newline at end of file
from langchain_openai import ChatOpenAI
from rich.console import Console
from rich.table import Table
import sys,os
sys.path.append("../")
from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
def run_examples():
# 初始化 rich console
console = Console()
# 初始化 LLM
llm = ChatOpenAI(
openai_api_key="xxxxxx",
openai_api_base="http://192.168.10.14:8000/v1",
model_name="Qwen2-7B",
verbose=True
)
# 初始化工具
tools = [
RegionRateTool(),
RankingRateTool(),
MonitorPointTool(),
]
# 初始化 ToolPicker
picker = ToolPicker(llm, tools)
# 测试案例和预期结果
test_cases = [
{
"query": "请查询下甘肃省的监测点信息",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省",
"three_d_model": False,
"ortho_image": False,
}
}
},{
"query": "查询甘肃省滑坡的监测点信息",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省",
"disaster_type": "滑坡",
}
}
},{
"query": "查询甘肃省有三维模型的监测点信息",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省",
"three_d_model": True,
}
}
},{
"query": "查询甘肃省威胁人口超过30人以上的滑坡的监测点信息",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省",
"disaster_threat_people_range_start": 30,
"disaster_type": "滑坡",
}
}
}
]
# 为每个测试案例创建一个表格
for i, case in enumerate(test_cases, 1):
console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
table = Table(title=f"查询: {case['query']}")
table.add_column("项目", style="cyan")
table.add_column("预期结果", style="green")
table.add_column("实际结果", style="yellow")
table.add_column("是否匹配", style="magenta")
try:
result = picker.pick(case["query"])
# 添加工具比较行
expected_tool = case["expected"]["tool"]
actual_tool = result["tool"]
table.add_row(
"选择的工具",
expected_tool,
actual_tool,
"✓" if expected_tool == actual_tool else "✗"
)
# 添加参数比较行
for param_key in case["expected"]["params"]:
expected_value = str(case["expected"]["params"][param_key])
actual_value = str(result["params"].get(param_key, "未提供"))
table.add_row(
f"参数: {param_key}",
expected_value,
actual_value,
"✓" if expected_value == actual_value else "✗"
)
except Exception as e:
table.add_row("错误", "", str(e), "✗")
console.print(table)
console.print("=" * 80)
if __name__ == "__main__":
run_examples()
\ No newline at end of file
from langchain_openai import ChatOpenAI
from rich.console import Console
from rich.table import Table
import sys,os
sys.path.append("../")
from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
def run_examples():
# 初始化 rich console
console = Console()
# 初始化 LLM
llm = ChatOpenAI(
openai_api_key="xxxxxx",
openai_api_base="http://192.168.10.14:8000/v1",
model_name="Qwen2-7B",
verbose=True
)
# 初始化工具
tools = [
RegionRateTool(),
RankingRateTool(),
MonitorPointTool(),
]
# 初始化 ToolPicker
picker = ToolPicker(llm, tools)
# 测试案例和预期结果
test_cases = [
{
"query": "请分析下今天全国各地区在线率情况",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-11-19",
"end_time": "2024-11-19",
"region_name": "",
"month_required": False
}
}
},
{
"query": "请分析下今天甘肃省设备在线率情况",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-11-19",
"end_time": "2024-11-19",
"region_name": "甘肃省",
"month_required": False
}
}
},
{
"query": "查询2024年11月13日各地区排名情况",
"expected": {
"tool": "online_rate_ranking",
"params": {
"rate_type": 1
}
}
},
{
"query": "查询各厂商在线率排名情况",
"expected": {
"tool": "online_rate_ranking",
"params": {
"rate_type": 2
}
}
},
{
"query": "甘肃省监控点的状态如何?",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省"
}
}
},
{
"query": "2023年甘肃省每月的设备在线率分别是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01",
"end_time": "2023-12-31",
"region_name": "甘肃省",
"month_required": True
}
}
},
{
"query": "查询2024年甘肃省各月在线率",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-01-01",
"end_time": "2024-12-31",
"region_name": "甘肃省",
"month_required": True
}
}
},
{
"query": "2024年10月15日,成都市武侯区的设备在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-10-15",
"end_time": "2024-10-15",
"region_name": "成都市武侯区",
"month_required": False
}
}
},
{
"query": "2024年,成都市武侯区的设备在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-01-01", "region_name": "成都市武侯区", "month_required": False
}
}
},
{
"query": "2023年甘肃省每月的设备在线率分别是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "甘肃省", "month_required": True
}
}
},
{
"query": "2023年甘肃省按月统计设备在线率?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "region_name": "甘肃省", "month_required": True
}
}
},
{
"query": "2023年全国每个月的设备在线率",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "region_name": "", "month_required": True
}
}
},
{
"query": "2023年1月-2023年12月期间西藏实验点在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": False
}
}
},
{
"query": "2023年1月-2023年12月期间西藏实验点各月在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True
}
}
},
{
"query": "2022年各个月设备在线率统计;",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2022-01-01", "end_time": "2022-12-31", "region_name": "", "month_required": True
}
}
},
{
"query": "2023年1月-2023年12月期间西藏实验点每个月在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True
}
}
},
{
"query": "2023年1月-2023年12月期间西藏实验点每个月在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True
}
}
}
]
# 为每个测试案例创建一个表格
for i, case in enumerate(test_cases, 1):
console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
table = Table(title=f"查询: {case['query']}")
table.add_column("项目", style="cyan")
table.add_column("预期结果", style="green")
table.add_column("实际结果", style="yellow")
table.add_column("是否匹配", style="magenta")
try:
result = picker.pick(case["query"])
# 添加工具比较行
expected_tool = case["expected"]["tool"]
actual_tool = result["tool"]
table.add_row(
"选择的工具",
expected_tool,
actual_tool,
"✓" if expected_tool == actual_tool else "✗"
)
# 添加参数比较行
for param_key in case["expected"]["params"]:
expected_value = str(case["expected"]["params"][param_key])
actual_value = str(result["params"].get(param_key, "未提供"))
table.add_row(
f"参数: {param_key}",
expected_value,
actual_value,
"✓" if expected_value == actual_value else "✗"
)
except Exception as e:
table.add_row("错误", "", str(e), "✗")
console.print(table)
console.print("=" * 80)
if __name__ == "__main__":
run_examples()
\ No newline at end of file
...@@ -7,7 +7,7 @@ from unittest.mock import Mock ...@@ -7,7 +7,7 @@ from unittest.mock import Mock
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from src.server.tool_picker import ToolPicker from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool,RankingRateTool from src.agent.tool_rate import RegionRateTool,RankingRateTool,MonthRateTool
from src.agent.tool_monitor import MonitorPointTool from src.agent.tool_monitor import MonitorPointTool
@pytest.fixture @pytest.fixture
...@@ -29,27 +29,27 @@ def mock_llm(): ...@@ -29,27 +29,27 @@ def mock_llm():
@pytest.mark.parametrize("query, expected_response, expected_tool", [ @pytest.mark.parametrize("query, expected_response, expected_tool", [
( (
"请分析下今天全国各地区在线率情况", "请分析下今天全国各地区在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": ""}}, {"tool": "region_online_rate", "params": {"start_time": "2024-11-19", "end_time": "2024-11-19", "region_name": "", "month_required": False}},
"region_online_rate" "region_online_rate"
), ),
( (
"请分析下今天甘肃省设备在线率情况", "请分析下今天甘肃省设备在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": "甘肃省"}}, {"tool": "region_online_rate", "params": {"start_time": "2024-11-19", "end_time": "2024-11-19", "region_name": "甘肃省","month_required":False }},
"region_online_rate" "region_online_rate"
), ),
( (
"查询今年三季度甘肃省设备在线率情况", "查询今年三季度甘肃省设备在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-07-01", "end_time": "2024-09-30", "region_name": "甘肃省"}}, {"tool": "region_online_rate", "params": {"start_time": "2024-07-01", "end_time": "2024-09-30", "region_name": "甘肃省","month_required":False }},
"region_online_rate" "region_online_rate"
), ),
( (
"查询2024年11月13日各地区排名情况", "查询2024年11月13日各地区排名情况",
{"tool": "online_rate_ranking", "params": {"rate_type": "1"}}, {"tool": "online_rate_ranking", "params": {"rate_type": 1}},
"online_rate_ranking" "online_rate_ranking"
), ),
( (
"查询各厂商在线率排名情况", "查询各厂商在线率排名情况",
{"tool": "online_rate_ranking", "params": {"rate_type": "2"}}, {"tool": "online_rate_ranking", "params": {"rate_type": 2}},
"online_rate_ranking" "online_rate_ranking"
), ),
( (
...@@ -57,6 +57,56 @@ def mock_llm(): ...@@ -57,6 +57,56 @@ def mock_llm():
{"tool": "monitor_points_query", "params": {"key": "甘肃省"}}, {"tool": "monitor_points_query", "params": {"key": "甘肃省"}},
"monitor_points_query" "monitor_points_query"
), ),
(
"查询2024年甘肃省各月在线率",
{"tool": "month_online_rate", "params": {"start_time": "2024-01-01", "end_time": "2024-12-31", "region_name": "甘肃省", "month_required": True}},
"region_online_rate"
),
(
"2024年10月15日,成都市武侯区的设备在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2024-10-15", "end_time": "2024-10-15", "region_name": "成都市武侯区", "month_required": False}},
"region_online_rate"
),
(
"2024年,成都市武侯区的设备在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2024-01-01", "region_name": "成都市武侯区", "month_required": False}},
"region_online_rate"
),
(
"2023年甘肃省每月的设备在线率分别是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "甘肃省", "month_required": True}},
"region_online_rate"
),
(
"2023年甘肃省按月统计设备在线率?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "region_name": "甘肃省", "month_required": True}},
"region_online_rate"
),
(
"2023年全国每个月的设备在线率",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "region_name": "", "month_required": True}},
"region_online_rate"
),
(
"2023年1月-2023年12月期间西藏实验点在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": False}},
"region_online_rate"
),
(
"2023年1月-2023年12月期间西藏实验点各月在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True}},
"region_online_rate"
),
(
"2022年各个月设备在线率统计;",
{"tool": "region_online_rate", "params": {"start_time": "2022-01-01", "end_time": "2022-12-31", "region_name": "", "month_required": True}},
"region_online_rate"
),
(
"2023年1月-2023年12月期间西藏实验点每个月在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True}},
"region_online_rate"
)
]) ])
def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool): def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool):
...@@ -65,13 +115,17 @@ def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool ...@@ -65,13 +115,17 @@ def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool
test_tools = [ test_tools = [
RegionRateTool(), RegionRateTool(),
RankingRateTool(), RankingRateTool(),
MonitorPointTool() MonitorPointTool(),
] ]
picker = ToolPicker(mock_llm, test_tools) picker = ToolPicker(mock_llm, test_tools)
result = picker.pick(query) result = picker.pick(query)
# print(f"query: {query}, result: {result}")
# 验证结果 # 验证结果
assert isinstance(result, dict) assert isinstance(result, dict)
assert result["tool"] == expected_tool assert result["tool"] == expected_tool
assert "params" in result assert "params" in result
for key, value in expected_response["params"].items():
print(f"key: {key}, value: {value} , result_value: {result['params'][key]}")
if key == "month_required" and value == True:
assert result["params"][key] == value
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment