Commit 7ded04d9 by tinywell

工具选择及参数提取测试;补充预警信息获取工具;

parent d1bc8007
......@@ -11,13 +11,17 @@ const_base_url = "http://localhost:5001"
const_url_point = "/cigem/getMonitorPointAll"
const_url_rate = "/cigem/getAvgOnlineRate"
const_url_rate_ranking = "/cigem/getOnlineRateRank"
const_url_rate_month = "/cigem/getAvgOnlineRateOfMonth"
const_url_device_list = "/cigem/getMonitorDeviceList"
const_url_warning = "/cigem/getWarningStatistics"
class BaseResponse(BaseModel, Generic[T]):
"""通用响应模型"""
type: int
resultcode: int
message: str
resultdata: T
resultdata: Optional[T] = None
otherinfo: Optional[str] = None
class BaseHttpClient:
"""基础HTTP客户端"""
......@@ -52,35 +56,75 @@ class MonitorPoint(BaseModel):
ELEVATION: str
BUILDUNIT: str
MONITORUNIT: str
YWUNIT: str
YWUNIT: Optional[str] = None
SGDW: Optional[str] = None
MANUFACTURER: str = ""
MONITORTYPE: str
MANUFACTURER: Optional[str] = None
MONITORTYPE: Optional[str] = None
class Sensor(BaseModel):
"""传感器数据模型"""
SENSORCODE: str
LOCATION: Optional[str] = None
LATITUDE: Optional[str] = None
LONGITUDE: Optional[str] = None
DEVICETYPENAME: Optional[str] = None
MANUFACTURER: Optional[str] = None
class Device(BaseModel):
"""设备数据模型"""
DEVICECODE: str
SN: str
LOCATION: Optional[str] = None
LATITUDE: Optional[str] = None
LONGITUDE: Optional[str] = None
DEVICETYPENAME: Optional[str] = None
MANUFACTURER: Optional[str] = None
MONITORPOINTNAME: Optional[str] = None
SensorList: List[Sensor]
class MonitorClient(BaseHttpClient):
"""监测点查询客户端"""
async def query_points(self, key: str) -> BaseResponse[List[MonitorPoint]]:
async def query_points(self, key: str) -> BaseResponse[List]:
"""异步查询监测点信息"""
data = await self._request_async(
"POST",
const_url_point,
json={"key": key}
)
return BaseResponse[List[MonitorPoint]](**data)
return BaseResponse[List](**data)
def query_points_sync(self, key: str) -> BaseResponse[List[MonitorPoint]]:
def query_points_sync(self, key: str) -> BaseResponse[List]:
"""同步查询监测点信息"""
data = self._request_sync(
"POST",
const_url_point,
json={"key": key}
)
return BaseResponse[List[MonitorPoint]](**data)
return BaseResponse[List](**data)
def query_device_list(self, code: str) -> BaseResponse[List]:
"""同步查询设备列表"""
data = self._request_sync(
"POST",
const_url_device_list,
json={"monitorpointcode": code}
)
print(data)
return BaseResponse[List](**data)
async def query_device_list_async(self, code: str) -> BaseResponse[List]:
"""异步查询设备列表"""
data = await self._request_async(
"POST",
const_url_device_list,
json={"monitorpointcode": code}
)
return BaseResponse[List](**data)
# 示例:添加新的数据接口客户端
class RateClient(BaseHttpClient):
"""在线率查询客户端"""
async def query_rates(self, areacode: str, startDate: str, endDate: str) -> BaseResponse[Dict]:
async def query_rates(self, areacode: str, startDate: str, endDate: str) -> BaseResponse[List]:
"""异步查询在线率信息"""
data = await self._request_async(
"POST",
......@@ -91,7 +135,7 @@ class RateClient(BaseHttpClient):
'endDate': endDate
}
)
return BaseResponse[Dict](**data)
return BaseResponse[List](**data)
def query_rates_sync(self, areacode: str, startDate: str, endDate: str) -> BaseResponse[List]:
"""同步查询在线率信息"""
......@@ -124,28 +168,43 @@ class RateClient(BaseHttpClient):
)
return BaseResponse[List](**data)
# 使用示例
async def example_async_usage():
# 监测点查询示例
monitor_client = MonitorClient()
try:
response = await monitor_client.query_points("湖南")
if response.resultcode == 1 and response.resultdata:
for point in response.resultdata:
print(f"监测点名称: {point.MONITORPOINTNAME}")
print(f"位置: {point.LOCATION}")
print("---")
except httpx.HTTPError as e:
print(f"HTTP请求错误: {e}")
# 在线率查询示例
rate_client = RateClient()
try:
response = await rate_client.query_rates("520100", "2024-01-01", "2024-01-31")
if response.resultcode == 1:
print(f"在线率数据: {response.resultdata}")
except httpx.HTTPError as e:
print(f"HTTP请求错误: {e}")
if __name__ == "__main__":
asyncio.run(example_async_usage())
def query_rates_month_sync(self, year: str, areaCode: str, typeArr: List[int]) -> BaseResponse[List]:
"""同步查询按月度统计的在线率信息"""
data = self._request_sync(
"POST",
const_url_rate_month,
json={
'year': year,
'areaCode': areaCode,
'typeArr': typeArr
}
)
return BaseResponse[List](**data)
class WarningClient(BaseHttpClient):
"""预警查询客户端"""
def query_warning_statistics(self, startDate: str, endDate: str, areaCode: str) -> BaseResponse[List]:
"""同步查询预警统计信息"""
data = self._request_sync(
"POST",
const_url_warning,
json={
'startDate': startDate,
'endDate': endDate,
'areaCode': areaCode
}
)
return BaseResponse[List](**data)
def query_warning_month_statistics(self, year: str, areaCode: str) -> BaseResponse[List]:
"""同步查询按月度统计的预警统计信息"""
data = self._request_sync(
"POST",
const_url_warning,
json={
'year': year,
'areaCode': areaCode
}
)
return BaseResponse[List](**data)
\ No newline at end of file
......@@ -14,15 +14,24 @@ class MonitorPointResponse():
class MonitorPointArgs(BaseModel):
"""监测点查询参数"""
key: str = Field(..., description="行政区划名称(省/市/区县级别均可,只需要最后一级,如岳麓区,不需要长沙市)")
key: str = Field(..., description="行政区划名称(省/市级别均可,只需要最后一级,如长沙市,不需要湖南省)")
year: int = Field(None, description="年份,未提及则为今年")
disaster_type: str = Field(None, description="灾害类型,如崩塌、滑坡、泥石流、地面塌陷、地面沉降、地裂缝等,未提及则为空")
three_d_model: bool = Field(False, description="是否需要三维模型,默认不需要")
ortho_image: bool = Field(False, description="是否需要正射影像,默认不需要")
disaster_threat_people_range_start: int = Field(None, description="灾害威胁人数范围起始值,如100,未提及则为空")
disaster_threat_people_range_end: int = Field(None, description="灾害威胁人数范围结束值,如200,未提及则为空")
disaster_scale: str = Field(None, description="灾害规模,灾害为崩塌、滑坡、泥石流时表示体积,灾害为地面塌陷、地面沉降时表示面积,为地裂缝时表示长度,未提及则为空")
device_type: str = Field(None, description="设备类型(例如 加速度、位移、温度、湿度、裂缝计等),默认为空")
class MonitorPointTool(BaseTool):
"""查询监测点信息的工具"""
name:str = "monitor_points_query"
description:str = """查询指定行政区划的监测点信息。
可以查询任意省/市/区县级别的监测点数据。
可以查询任意省/市/区县级别的监测点数据,也可以通过灾害类型、灾害规模、灾害威胁人数范围、设备类型等条件查询
输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。
返回该区域内的监测点列表,包含位置、经纬度等详细信息。
还可以查询监测点下相关监测设备信息,比如设备数量等。
"""
args_schema: Type[BaseModel] = MonitorPointArgs
client: Any = Field(None, exclude=True)
......@@ -40,18 +49,37 @@ class MonitorPointTool(BaseTool):
self.client = MonitorClient(base_url=base_url)
self.logger = get_logger("MonitorPointTool")
def _run(self, key: str) -> Dict[str, Any]:
def _run(self, key: str, device_required: bool = False, device_type: str = None) -> Dict[str, Any]:
"""
执行监测点查询
Args:
key: 行政区划名称
year: 年份
disaster_type: 灾害类型
three_d_model: 是否需要三维模型
ortho_image: 是否需要正射影像
disaster_threat_people_range_start: 灾害威胁人数范围起始值
disaster_threat_people_range_end: 灾害威胁人数范围结束值
disaster_scale: 灾害规模
device_required: 是否需要设备相关信息
device_type: 设备类型
Returns:
Dict: 包含查询结果的字典
"""
try:
self.logger.info(f"开始查询监测点信息,区域: {key}")
code = ""
if key != "":
self.logger.debug(f"查找区域代码: {key}")
codes = code_tool.find_code(key)
if codes is None or len(codes) == 0:
error_msg = f'未找到匹配的区域代码: {key}'
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
code = codes[0][1]
self.logger.debug(f"找到区域代码: {code}")
response = self.client.query_points_sync(key)
self.logger.debug(f"API响应: {response}")
......@@ -62,7 +90,7 @@ class MonitorPointTool(BaseTool):
'code': 400,
'message': error_msg
}
# 提取关键信息并格式化
points_info = []
for point in response.resultdata:
......@@ -99,19 +127,6 @@ class MonitorPointTool(BaseTool):
'message': error_msg
}
def _arun(self, key: str) -> Dict[str, Any]:
"""
异步执行监测点查询
Args:
key: 行政区划名称
Returns:
Dict: 包含查询结果的字典
"""
self.logger.warning("异步查询方法未实现")
raise NotImplementedError("异步查询暂未实现")
def to_markdown(self, data: List[Dict[str, Any]]) -> str:
"""将数据转换为 markdown 表格"""
self.logger.debug("开始生成 markdown 表格")
......
......@@ -69,6 +69,9 @@ class RegionRateArgs(BaseModel):
region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串")
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
month_statistics: bool = Field(False, description="是否按月度查看一年内的在线率统计结果,默认不需要")
device_items: str = Field(None, description="监测项集合,多个逗号隔开,支持:滑坡仪、设备、传感器、雨量、地表位移、裂缝、倾角、加速度、土壤含水率、泥水位,默认为空")
manufacturer_name: str = Field(None, description="设备厂商,默认为空")
class RegionRateTool(BaseRateTool):
"""查询全国或者特定地区设备在线率的工具"""
......@@ -82,12 +85,12 @@ class RegionRateTool(BaseRateTool):
self.client = RateClient(base_url=base_url)
self.logger.info(f"初始化 RegionRateTool,base_url: {base_url}")
def _run(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]:
return self.get_region_online_rate(start_time, end_time, region_name)
def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str=None, manufacturer_name: str=None) -> Dict[str, Any]:
return self.get_region_online_rate(start_time, end_time, region_name, month_statistics, device_items, manufacturer_name)
def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]:
def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str=None, manufacturer_name: str=None) -> Dict[str, Any]:
agent_start = time.time()
self.logger.info(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}")
self.logger.info(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}, 厂商: {manufacturer_name}, 监测项: {device_items}, 按月度查询统计: {month_statistics}")
code = ""
if region_name != "":
......@@ -100,44 +103,58 @@ class RegionRateTool(BaseRateTool):
code = codes[0][1]
self.logger.debug(f"找到区域代码: {code}")
try:
start = time.time()
df = self.client.query_rates_sync(code, start_time, end_time)
query_time = time.time() - start
self.logger.debug(f"API调用耗时: {query_time:.2f}秒")
if df.type != 1 or df.resultdata is None or len(df.resultdata) == 0:
error_msg = f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
self.logger.warning(error_msg)
if month_statistics: # 按月度查询一年内的统计结果
year = start_time.split('-')[0]
self.logger.debug(f"按月度查询,年份: {year}")
else:
try:
start = time.time()
df = self.client.query_rates_sync(code, start_time, end_time)
query_time = time.time() - start
self.logger.debug(f"API调用耗时: {query_time:.2f}秒")
if df.type != 1 or df.resultdata is None or len(df.resultdata) == 0:
error_msg = f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
self.logger.debug(f"查询结果: {df.resultdata}")
markdown = self.to_markdown(df.resultdata)
data = {
'region': region_name,
'region_code': code,
'rate_data': df.resultdata,
'markdown': markdown,
'date_range': {
'start': start_time,
'end': end_time
self.logger.debug(f"查询结果: {df.resultdata}")
markdown = self.to_markdown(df.resultdata)
data = {
'region': region_name,
'region_code': code,
'rate_data': df.resultdata,
'markdown': markdown,
'date_range': {
'start': start_time,
'end': end_time
}
}
}
total_time = time.time() - agent_start
self.logger.info(f"查询完成,总耗时: {total_time:.2f}秒")
return data
total_time = time.time() - agent_start
self.logger.info(f"查询完成 {region_name}(code: {code}) 在线率,总耗时: {total_time:.2f}秒")
return data
except Exception as e:
self.logger.error(f"查询失败: {str(e)}", exc_info=True)
raise
except Exception as e:
self.logger.error(f"查询失败: {str(e)}", exc_info=True)
raise
def to_llm(self, region_name: str,start_time: str, end_time: str, data: Dict[str, Any]) -> str:
"""将数据转换为 LLM 可理解的格式"""
self.logger.debug("开始将数据转换为 LLM 可理解的格式")
llm_data = f"共找到{len(data)}条数据,在{start_time}至{end_time}期间,{region_name}的在线率数据如下:\n"
for index, item in enumerate(data):
llm_data += f"{index+1}. {item['name']} 的在线率为 {item['rate']}\n"
return llm_data
def to_markdown(self, data: List[Dict[str, Any]]) -> str:
"""将数据转换为 markdown 表格"""
self.logger.debug("开始生成 markdown 表格")
markdown = """
| 序号 | 日期 | 在线率 |
| 序号 | 地区 | 在线率 |
| --- | --- | --- |
"""
for index, row in enumerate(data):
......@@ -145,8 +162,6 @@ class RegionRateTool(BaseRateTool):
self.logger.debug("markdown 表格生成完成")
return markdown
class RankingRateArgs(BaseModel):
"""排名查询参数"""
rate_type: int = Field(..., description="排序类型,用于指定查询的排名类别。1表示省份排名,2表示厂商排名")
......
"""
不同时间段不同地区预警处置和虚警情况
1.地区处置率数量和占比统计
2.地区虚警率数量和占比统计
3.地区蓝黄橙红数数量和占比统计
"""
from pydantic import BaseModel, Field
from typing import Any, Dict, Type
from .http_tools import WarningClient, const_base_url
class WarningArgs(BaseModel):
"""预警查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串")
month_statistics: bool = Field(False, description="是否按月度查询一年内的统计结果,默认不需要")
class WarningTool(BaseTool):
"""查询预警处置和虚警情况"""
name: str = "warning_statistics"
description: str = "查询不同时间段不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持按月度统计一年内的虚警率,处置率。"
args_schema: Type[BaseModel] = WarningArgs
client: Any = Field(None, exclude=True)
def __init__(self, base_url: str = const_base_url, **data):
super().__init__(**data)
self.client = WarningClient(base_url=base_url)
self.logger.info(f"初始化 WarningTool,base_url: {base_url}")
def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False) -> Dict[str, Any]:
return self.get_warning_statistics(start_time, end_time, region_name, month_statistics)
def get_warning_statistics(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False) -> Dict[str, Any]:
pass
\ No newline at end of file
from langchain_openai import ChatOpenAI
from rich.console import Console
from rich.table import Table
import sys,os
sys.path.append("../")
from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
def run_examples():
# 初始化 rich console
console = Console()
# 初始化 LLM
llm = ChatOpenAI(
openai_api_key="xxxxxx",
openai_api_base="http://192.168.10.14:8000/v1",
model_name="Qwen2-7B",
verbose=True
)
# 初始化工具
tools = [
RegionRateTool(),
RankingRateTool(),
MonitorPointTool(),
]
# 初始化 ToolPicker
picker = ToolPicker(llm, tools)
# 测试案例和预期结果
test_cases = [
{
"query": "请查询下甘肃省的监测点信息",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省",
"three_d_model": False,
"ortho_image": False,
}
}
},{
"query": "查询甘肃省滑坡的监测点信息",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省",
"disaster_type": "滑坡",
}
}
},{
"query": "查询甘肃省有三维模型的监测点信息",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省",
"three_d_model": True,
}
}
},{
"query": "查询甘肃省威胁人口超过30人以上的滑坡的监测点信息",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省",
"disaster_threat_people_range_start": 30,
"disaster_type": "滑坡",
}
}
}
]
# 为每个测试案例创建一个表格
for i, case in enumerate(test_cases, 1):
console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
table = Table(title=f"查询: {case['query']}")
table.add_column("项目", style="cyan")
table.add_column("预期结果", style="green")
table.add_column("实际结果", style="yellow")
table.add_column("是否匹配", style="magenta")
try:
result = picker.pick(case["query"])
# 添加工具比较行
expected_tool = case["expected"]["tool"]
actual_tool = result["tool"]
table.add_row(
"选择的工具",
expected_tool,
actual_tool,
"✓" if expected_tool == actual_tool else "✗"
)
# 添加参数比较行
for param_key in case["expected"]["params"]:
expected_value = str(case["expected"]["params"][param_key])
actual_value = str(result["params"].get(param_key, "未提供"))
table.add_row(
f"参数: {param_key}",
expected_value,
actual_value,
"✓" if expected_value == actual_value else "✗"
)
except Exception as e:
table.add_row("错误", "", str(e), "✗")
console.print(table)
console.print("=" * 80)
if __name__ == "__main__":
run_examples()
\ No newline at end of file
from langchain_openai import ChatOpenAI
from rich.console import Console
from rich.table import Table
import sys,os
sys.path.append("../")
from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
def run_examples():
# 初始化 rich console
console = Console()
# 初始化 LLM
llm = ChatOpenAI(
openai_api_key="xxxxxx",
openai_api_base="http://192.168.10.14:8000/v1",
model_name="Qwen2-7B",
verbose=True
)
# 初始化工具
tools = [
RegionRateTool(),
RankingRateTool(),
MonitorPointTool(),
]
# 初始化 ToolPicker
picker = ToolPicker(llm, tools)
# 测试案例和预期结果
test_cases = [
{
"query": "请分析下今天全国各地区在线率情况",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-11-19",
"end_time": "2024-11-19",
"region_name": "",
"month_required": False
}
}
},
{
"query": "请分析下今天甘肃省设备在线率情况",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-11-19",
"end_time": "2024-11-19",
"region_name": "甘肃省",
"month_required": False
}
}
},
{
"query": "查询2024年11月13日各地区排名情况",
"expected": {
"tool": "online_rate_ranking",
"params": {
"rate_type": 1
}
}
},
{
"query": "查询各厂商在线率排名情况",
"expected": {
"tool": "online_rate_ranking",
"params": {
"rate_type": 2
}
}
},
{
"query": "甘肃省监控点的状态如何?",
"expected": {
"tool": "monitor_points_query",
"params": {
"key": "甘肃省"
}
}
},
{
"query": "2023年甘肃省每月的设备在线率分别是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01",
"end_time": "2023-12-31",
"region_name": "甘肃省",
"month_required": True
}
}
},
{
"query": "查询2024年甘肃省各月在线率",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-01-01",
"end_time": "2024-12-31",
"region_name": "甘肃省",
"month_required": True
}
}
},
{
"query": "2024年10月15日,成都市武侯区的设备在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-10-15",
"end_time": "2024-10-15",
"region_name": "成都市武侯区",
"month_required": False
}
}
},
{
"query": "2024年,成都市武侯区的设备在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2024-01-01", "region_name": "成都市武侯区", "month_required": False
}
}
},
{
"query": "2023年甘肃省每月的设备在线率分别是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "甘肃省", "month_required": True
}
}
},
{
"query": "2023年甘肃省按月统计设备在线率?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "region_name": "甘肃省", "month_required": True
}
}
},
{
"query": "2023年全国每个月的设备在线率",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "region_name": "", "month_required": True
}
}
},
{
"query": "2023年1月-2023年12月期间西藏实验点在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": False
}
}
},
{
"query": "2023年1月-2023年12月期间西藏实验点各月在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True
}
}
},
{
"query": "2022年各个月设备在线率统计;",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2022-01-01", "end_time": "2022-12-31", "region_name": "", "month_required": True
}
}
},
{
"query": "2023年1月-2023年12月期间西藏实验点每个月在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True
}
}
},
{
"query": "2023年1月-2023年12月期间西藏实验点每个月在线率是多少?",
"expected": {
"tool": "region_online_rate",
"params": {
"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True
}
}
}
]
# 为每个测试案例创建一个表格
for i, case in enumerate(test_cases, 1):
console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
table = Table(title=f"查询: {case['query']}")
table.add_column("项目", style="cyan")
table.add_column("预期结果", style="green")
table.add_column("实际结果", style="yellow")
table.add_column("是否匹配", style="magenta")
try:
result = picker.pick(case["query"])
# 添加工具比较行
expected_tool = case["expected"]["tool"]
actual_tool = result["tool"]
table.add_row(
"选择的工具",
expected_tool,
actual_tool,
"✓" if expected_tool == actual_tool else "✗"
)
# 添加参数比较行
for param_key in case["expected"]["params"]:
expected_value = str(case["expected"]["params"][param_key])
actual_value = str(result["params"].get(param_key, "未提供"))
table.add_row(
f"参数: {param_key}",
expected_value,
actual_value,
"✓" if expected_value == actual_value else "✗"
)
except Exception as e:
table.add_row("错误", "", str(e), "✗")
console.print(table)
console.print("=" * 80)
if __name__ == "__main__":
run_examples()
\ No newline at end of file
......@@ -7,7 +7,7 @@ from unittest.mock import Mock
from langchain_openai import ChatOpenAI
from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool,RankingRateTool
from src.agent.tool_rate import RegionRateTool,RankingRateTool,MonthRateTool
from src.agent.tool_monitor import MonitorPointTool
@pytest.fixture
......@@ -29,27 +29,27 @@ def mock_llm():
@pytest.mark.parametrize("query, expected_response, expected_tool", [
(
"请分析下今天全国各地区在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": ""}},
{"tool": "region_online_rate", "params": {"start_time": "2024-11-19", "end_time": "2024-11-19", "region_name": "", "month_required": False}},
"region_online_rate"
),
(
"请分析下今天甘肃省设备在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": "甘肃省"}},
{"tool": "region_online_rate", "params": {"start_time": "2024-11-19", "end_time": "2024-11-19", "region_name": "甘肃省","month_required":False }},
"region_online_rate"
),
(
"查询今年三季度甘肃省设备在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-07-01", "end_time": "2024-09-30", "region_name": "甘肃省"}},
{"tool": "region_online_rate", "params": {"start_time": "2024-07-01", "end_time": "2024-09-30", "region_name": "甘肃省","month_required":False }},
"region_online_rate"
),
(
"查询2024年11月13日各地区排名情况",
{"tool": "online_rate_ranking", "params": {"rate_type": "1"}},
{"tool": "online_rate_ranking", "params": {"rate_type": 1}},
"online_rate_ranking"
),
(
"查询各厂商在线率排名情况",
{"tool": "online_rate_ranking", "params": {"rate_type": "2"}},
{"tool": "online_rate_ranking", "params": {"rate_type": 2}},
"online_rate_ranking"
),
(
......@@ -57,6 +57,56 @@ def mock_llm():
{"tool": "monitor_points_query", "params": {"key": "甘肃省"}},
"monitor_points_query"
),
(
"查询2024年甘肃省各月在线率",
{"tool": "month_online_rate", "params": {"start_time": "2024-01-01", "end_time": "2024-12-31", "region_name": "甘肃省", "month_required": True}},
"region_online_rate"
),
(
"2024年10月15日,成都市武侯区的设备在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2024-10-15", "end_time": "2024-10-15", "region_name": "成都市武侯区", "month_required": False}},
"region_online_rate"
),
(
"2024年,成都市武侯区的设备在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2024-01-01", "region_name": "成都市武侯区", "month_required": False}},
"region_online_rate"
),
(
"2023年甘肃省每月的设备在线率分别是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "甘肃省", "month_required": True}},
"region_online_rate"
),
(
"2023年甘肃省按月统计设备在线率?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "region_name": "甘肃省", "month_required": True}},
"region_online_rate"
),
(
"2023年全国每个月的设备在线率",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "region_name": "", "month_required": True}},
"region_online_rate"
),
(
"2023年1月-2023年12月期间西藏实验点在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": False}},
"region_online_rate"
),
(
"2023年1月-2023年12月期间西藏实验点各月在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True}},
"region_online_rate"
),
(
"2022年各个月设备在线率统计;",
{"tool": "region_online_rate", "params": {"start_time": "2022-01-01", "end_time": "2022-12-31", "region_name": "", "month_required": True}},
"region_online_rate"
),
(
"2023年1月-2023年12月期间西藏实验点每个月在线率是多少?",
{"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True}},
"region_online_rate"
)
])
def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool):
......@@ -65,13 +115,17 @@ def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool
test_tools = [
RegionRateTool(),
RankingRateTool(),
MonitorPointTool()
MonitorPointTool(),
]
picker = ToolPicker(mock_llm, test_tools)
result = picker.pick(query)
# print(f"query: {query}, result: {result}")
# 验证结果
assert isinstance(result, dict)
assert result["tool"] == expected_tool
assert "params" in result
for key, value in expected_response["params"].items():
print(f"key: {key}, value: {value} , result_value: {result['params'][key]}")
if key == "month_required" and value == True:
assert result["params"][key] == value
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment