diff --git a/src/agent/tool_rate.py b/src/agent/tool_rate.py index 9d57bbe..8b8c279 100644 --- a/src/agent/tool_rate.py +++ b/src/agent/tool_rate.py @@ -66,12 +66,12 @@ class BaseRateTool(BaseTool): class RegionRateArgs(BaseModel): """地区在线率查询参数""" - region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串") - start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)") - end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)") + region_name: str = Field("", description="地区名称,如果要查询全国数据,请输入空字符串") + start_time: str = Field("", description="开始时间 (YYYY-MM-DD)") + end_time: str = Field("", description="结束时间 (YYYY-MM-DD)") month_statistics: bool = Field(False, description="是否按月度查看一年内的在线率统计结果,默认不需要") - device_items: str = Field(None, description="监测项集合,多个逗号隔开,支持:滑坡仪、设备、传感器、雨量、地表位移、裂缝、倾角、加速度、土壤含水率、泥水位,默认为空") - manufacturer_name: str = Field(None, description="设备厂商,默认为空") + device_items: str = Field("", description="监测项集合,多个逗号隔开,支持:滑坡仪、设备、传感器、雨量、地表位移、裂缝、倾角、加速度、土壤含水率、泥水位,默认为空") + manufacturer_name: str = Field("", description="设备厂商,默认为空") class RegionRateTool(BaseRateTool): """查询全国或者特定地区设备在线率的工具""" @@ -85,13 +85,7 @@ class RegionRateTool(BaseRateTool): self.client = RateClient(base_url=base_url) self.logger.info(f"初始化 RegionRateTool,base_url: {base_url}") - def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str=None, manufacturer_name: str=None) -> Dict[str, Any]: - return self.get_region_online_rate(start_time, end_time, region_name, month_statistics, device_items, manufacturer_name) - - def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str=None, manufacturer_name: str=None) -> Dict[str, Any]: - agent_start = time.time() - self.logger.info(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}, 厂商: {manufacturer_name}, 监测项: {device_items}, 按月度查询统计: {month_statistics}") - + def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str="", manufacturer_name: str="") -> Dict[str, Any]: code = "" if region_name != "": self.logger.debug(f"查找区域代码: {region_name}") @@ -103,43 +97,196 @@ class RegionRateTool(BaseRateTool): code = codes[0][1] self.logger.debug(f"找到区域代码: {code}") - if month_statistics: # 按月度查询一年内的统计结果 - year = start_time.split('-')[0] - self.logger.debug(f"按月度查询,年份: {year}") - + if month_statistics: + return self.get_region_online_rate_of_month(region_name, code, start_time, end_time) else: - try: - start = time.time() - df = self.client.query_rates_sync(code, start_time, end_time) - query_time = time.time() - start - self.logger.debug(f"API调用耗时: {query_time:.2f}秒") + return self.get_region_online_rate(start_time, end_time, region_name, code, month_statistics, device_items, manufacturer_name) + + def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="",code: str="", month_statistics: bool=False, device_items: str=None, manufacturer_name: str=None) -> Dict[str, Any]: + agent_start = time.time() + self.logger.info(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}, 厂商: {manufacturer_name}, 监测项: {device_items}, 按月度查询统计: {month_statistics}") + + try: + start = time.time() + df = self.client.query_rates_sync(code, start_time, end_time, manufacturer_name, device_items) + query_time = time.time() - start + self.logger.debug(f"API调用耗时: {query_time:.2f}秒") - if df.type != 1 or df.resultdata is None or len(df.resultdata) == 0: - error_msg = f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限' - self.logger.warning(error_msg) + if df.type != 1 or df.resultdata is None or len(df.resultdata) == 0: + error_msg = f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限' + self.logger.warning(error_msg) return {'code': 400, 'message': error_msg} - self.logger.debug(f"查询结果: {df.resultdata}") - markdown = self.to_markdown(df.resultdata) + self.logger.debug(f"查询结果: {df.resultdata}") + # markdown = self.to_markdown(df.resultdata) + + result_data = [] + for item in df.resultdata: + rate_date = self._extract_rate_data(item) + result_data.append(rate_date) - data = { + data = { 'region': region_name, 'region_code': code, - 'rate_data': df.resultdata, - 'markdown': markdown, - 'date_range': { - 'start': start_time, - 'end': end_time - } - } + 'rate_data': result_data, + } - total_time = time.time() - agent_start - self.logger.info(f"查询完成 {region_name}(code: {code}) 在线率,总耗时: {total_time:.2f}秒") - return data + total_time = time.time() - agent_start + self.logger.info(f"查询完成 {region_name}(code: {code}) 在线率,总耗时: {total_time:.2f}秒") + return data - except Exception as e: - self.logger.error(f"查询失败: {str(e)}", exc_info=True) - raise + except Exception as e: + self.logger.error(f"查询失败: {str(e)}", exc_info=True) + raise + + def get_region_online_rate_of_month(self, region_name: str, code: str, start_time: str, end_time: str,device_items: str=None) -> Dict[str, Any]: + """查询指定地区在指定时间段内的月度在线率""" + year = start_time.split('-')[0] + self.logger.debug(f"按月度查询,年份: {year}") + try: + df = self.client.query_rates_month_sync(year, code, device_items) + + if df.type != 1 or df.resultdata is None or len(df.resultdata) == 0: + error_msg = f'未找到{region_name}(code: {code})在{year}年内的数据,请检查是否有相关数据权限' + self.logger.warning(error_msg) + return {'code': 400, 'message': error_msg} + + result_data = {} + for index, item in enumerate(df.resultdata): + month_data = self._extract_rate_data(item) + result_data[item['month']] = month_data + # 排序 + + self.logger.debug(f"查询结果: {df.resultdata}") + # markdown = self.to_markdown(df.resultdata) + + data = { + 'region': region_name, + 'region_code': code, + 'rate_data': result_data + } + + return data + + except Exception as e: + self.logger.error(f"查询失败: {str(e)}", exc_info=True) + raise + + def _extract_rate_data(self, item: Dict[str, Any]) -> Dict[str, Any]: + """提取月度数据""" + rate_date = {} + if item.get("name") is not None: + rate_date['名称'] = item.get("name") + if item.get("rate") is not None: + rate_date['在线率'] = item.get("rate") + if item.get("rtuCount") is not None: + rate_date['在线数量'] = item.get("rtuCount") + if item.get("monitorRate") is not None: + rate_date['滑坡仪在线率'] = item.get("monitorRate") + if item.get("sensorRate") is not None: + rate_date['传感器在线率'] = item.get("sensorRate") + if item.get("monitorCount") is not None: + rate_date['滑坡仪数量'] = item.get("monitorCount") + if item.get("sensorCount") is not None: + rate_date['传感器数量'] = item.get("sensorCount") + if item.get("lfRate") is not None: + rate_date['裂缝在线率'] = item.get("lfRate") + if item.get("lfCount") is not None: + rate_date['裂缝数量'] = item.get("lfCount") + if item.get("gpRate") is not None: + rate_date['地表位移在线率'] = item.get("gpRate") + if item.get("gpCount") is not None: + rate_date['地表位移数量'] = item.get("gpCount") + if item.get("swRate") is not None: + rate_date['深部位移在线率'] = item.get("swRate") + if item.get("swCount") is not None: + rate_date['深部位移数量'] = item.get("swCount") + if item.get("jsRate") is not None: + rate_date['加速度在线率'] = item.get("jsRate") + if item.get("jsCount") is not None: + rate_date['加速度数量'] = item.get("jsCount") + if item.get("qjRate") is not None: + rate_date['倾角在线率'] = item.get("qjRate") + if item.get("qjCount") is not None: + rate_date['倾角数量'] = item.get("qjCount") + if item.get("zdRate") is not None: + rate_date['振动在线率'] = item.get("zdRate") + if item.get("zdCount") is not None: + rate_date['振动数量'] = item.get("zdCount") + if item.get("ylRate") is not None: + rate_date['应力在线率'] = item.get("ylRate") + if item.get("ylCount") is not None: + rate_date['应力数量'] = item.get("ylCount") + if item.get("tyRate") is not None: + rate_date['土压力在线率'] = item.get("tyRate") + if item.get("tyCount") is not None: + rate_date['土压力数量'] = item.get("tyCount") + if item.get("csRate") is not None: + rate_date['次声在线率'] = item.get("csRate") + if item.get("csCount") is not None: + rate_date['次声数量'] = item.get("csCount") + if item.get("dsRate") is not None: + rate_date['地声在线率'] = item.get("dsRate") + if item.get("dsCount") is not None: + rate_date['地声数量'] = item.get("dsCount") + if item.get("ylRate") is not None: + rate_date['雨量在线率'] = item.get("ylRate") + if item.get("ylCount") is not None: + rate_date['雨量数量'] = item.get("ylCount") + if item.get("qwRate") is not None: + rate_date['气温在线率'] = item.get("qwRate") + if item.get("qwCount") is not None: + rate_date['气温数量'] = item.get("qwCount") + if item.get("twRate") is not None: + rate_date['土壤湿度在线率'] = item.get("twRate") + if item.get("twCount") is not None: + rate_date['土壤湿度数量'] = item.get("twCount") + if item.get("hsRate") is not None: + rate_date['土壤含水率在线率'] = item.get("hsRate") + if item.get("hsCount") is not None: + rate_date['土壤含水率数量'] = item.get("hsCount") + if item.get("dbRate") is not None: + rate_date['地表水温/水位在线率'] = item.get("dbRate") + if item.get("dbCount") is not None: + rate_date['地表水温/水位数量'] = item.get("dbCount") + if item.get("syRate") is not None: + rate_date['孔隙水温/水压在线率'] = item.get("syRate") + if item.get("syCount") is not None: + rate_date['孔隙水温/水压数量'] = item.get("syCount") + if item.get("stRate") is not None: + rate_date['渗透压力在线率'] = item.get("stRate") + if item.get("stCount") is not None: + rate_date['渗透压力数量'] = item.get("stCount") + if item.get("lsRate") is not None: + rate_date['流速在线率'] = item.get("lsRate") + if item.get("lsCount") is not None: + rate_date['流速数量'] = item.get("lsCount") + if item.get("cjRate") is not None: + rate_date['沉降在线率'] = item.get("cjRate") + if item.get("cjCount") is not None: + rate_date['沉降数量'] = item.get("cjCount") + if item.get("qyRate") is not None: + rate_date['气压在线率'] = item.get("qyRate") + if item.get("qyCount") is not None: + rate_date['气压数量'] = item.get("qyCount") + if item.get("spRate") is not None: + rate_date['视频在线率'] = item.get("spRate") + if item.get("spCount") is not None: + rate_date['视频数量'] = item.get("spCount") + if item.get("nwRate") is not None: + rate_date['泥水位在线率'] = item.get("nwRate") + if item.get("nwCount") is not None: + rate_date['泥水位数量'] = item.get("nwCount") + if item.get("ldRate") is not None: + rate_date['雷达在线率'] = item.get("ldRate") + if item.get("ldCount") is not None: + rate_date['雷达数量'] = item.get("ldCount") + if item.get("lbRate") is not None: + rate_date['预警喇叭在线率'] = item.get("lbRate") + if item.get("lbCount") is not None: + rate_date['预警喇叭数量'] = item.get("lbCount") + return rate_date + def to_llm(self, region_name: str,start_time: str, end_time: str, data: Dict[str, Any]) -> str: """将数据转换为 LLM 可理解的格式""" self.logger.debug("开始将数据转换为 LLM 可理解的格式") diff --git a/test/run_tool_picker_monitor.py b/test/run_tool_picker_monitor.py index e4071d7..e78c5db 100644 --- a/test/run_tool_picker_monitor.py +++ b/test/run_tool_picker_monitor.py @@ -75,7 +75,15 @@ def run_examples(): "disaster_type": "滑坡", } } - } + },{ + "query": "甘肃省监控点的状态如何?", + "expected": { + "tool": "monitor_points_query", + "params": { + "key": "甘肃省陇南市" + } + } + }, ] # 为每个测试案例创建一个表格 diff --git a/test/run_tool_picker_rate.py b/test/run_tool_picker_rate.py index 4525497..986f271 100644 --- a/test/run_tool_picker_rate.py +++ b/test/run_tool_picker_rate.py @@ -21,14 +21,15 @@ def run_examples(): model_name="Qwen2-7B", verbose=True ) - + base_url = "http://172.30.0.37:30007" # 初始化工具 tools = [ - RegionRateTool(), - RankingRateTool(), - MonitorPointTool(), + RegionRateTool(base_url=base_url), + RankingRateTool(base_url=base_url), + MonitorPointTool(base_url=base_url), ] + tool_dict = {tool.name: tool for tool in tools} # 初始化 ToolPicker picker = ToolPicker(llm, tools) @@ -42,7 +43,7 @@ def run_examples(): "start_time": "2024-11-19", "end_time": "2024-11-19", "region_name": "", - "month_required": False + "month_statistics": False } } }, @@ -54,7 +55,7 @@ def run_examples(): "start_time": "2024-11-19", "end_time": "2024-11-19", "region_name": "甘肃省", - "month_required": False + "month_statistics": False } } }, @@ -77,15 +78,6 @@ def run_examples(): } }, { - "query": "甘肃省监控点的状态如何?", - "expected": { - "tool": "monitor_points_query", - "params": { - "key": "甘肃省" - } - } - }, - { "query": "2023年甘肃省每月的设备在线率分别是多少?", "expected": { "tool": "region_online_rate", @@ -93,7 +85,7 @@ def run_examples(): "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "甘肃省", - "month_required": True + "month_statistics": True } } }, @@ -105,28 +97,28 @@ def run_examples(): "start_time": "2024-01-01", "end_time": "2024-12-31", "region_name": "甘肃省", - "month_required": True + "month_statistics": True } } }, { - "query": "2024年10月15日,成都市武侯区的设备在线率是多少?", + "query": "2024年10月15日,兰州市的设备在线率是多少?", "expected": { "tool": "region_online_rate", "params": { "start_time": "2024-10-15", "end_time": "2024-10-15", - "region_name": "成都市武侯区", - "month_required": False + "region_name": "兰州市", + "month_statistics": False } } }, { - "query": "2024年,成都市武侯区的设备在线率是多少?", + "query": "2024年,兰州市榆中县的设备在线率是多少?", "expected": { "tool": "region_online_rate", "params": { - "start_time": "2024-01-01", "region_name": "成都市武侯区", "month_required": False + "start_time": "2024-11-26", "region_name": "兰州市榆中县", "month_statistics": False } } }, @@ -135,7 +127,7 @@ def run_examples(): "expected": { "tool": "region_online_rate", "params": { - "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "甘肃省", "month_required": True + "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "甘肃省", "month_statistics": True } } }, @@ -144,7 +136,7 @@ def run_examples(): "expected": { "tool": "region_online_rate", "params": { - "start_time": "2023-01-01", "region_name": "甘肃省", "month_required": True + "start_time": "2023-01-01", "region_name": "甘肃省", "month_statistics": True } } }, @@ -153,25 +145,25 @@ def run_examples(): "expected": { "tool": "region_online_rate", "params": { - "start_time": "2023-01-01", "region_name": "", "month_required": True + "start_time": "2023-01-01", "region_name": "", "month_statistics": True } } }, { - "query": "2023年1月-2023年12月期间西藏实验点在线率是多少?", + "query": "2023年1月-2023年12月期间青海实验点在线率是多少?", "expected": { "tool": "region_online_rate", "params": { - "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": False + "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "青海", "month_statistics": False } } }, { - "query": "2023年1月-2023年12月期间西藏实验点各月在线率是多少?", + "query": "2023年1月-2023年12月期间青海实验点各月在线率是多少?", "expected": { "tool": "region_online_rate", "params": { - "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True + "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "青海", "month_statistics": True } } }, @@ -180,7 +172,7 @@ def run_examples(): "expected": { "tool": "region_online_rate", "params": { - "start_time": "2022-01-01", "end_time": "2022-12-31", "region_name": "", "month_required": True + "start_time": "2022-01-01", "end_time": "2022-12-31", "region_name": "", "month_statistics": True } } }, @@ -189,7 +181,7 @@ def run_examples(): "expected": { "tool": "region_online_rate", "params": { - "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True + "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_statistics": True } } }, @@ -198,7 +190,7 @@ def run_examples(): "expected": { "tool": "region_online_rate", "params": { - "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True + "start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_statistics": True } } } @@ -237,7 +229,13 @@ def run_examples(): actual_value, "✓" if expected_value == actual_value else "✗" ) - + + tool = tool_dict[result["tool"]] + params = result["params"] + + result = tool.invoke(params) + print(result) + except Exception as e: table.add_row("错误", "", str(e), "✗")