import time from typing import Dict, List, Tuple, Any, Optional from pydantic import BaseModel, Field from typing import Type from langchain_core.tools import BaseTool from .http_tools import RateClient, const_base_url from .code import AreaCodeTool code_tool = AreaCodeTool() class BaseRateTool(BaseTool): """设备在线率分析基础工具类""" def __init__(self, **data): super().__init__(**data) def format_response(self, data: Dict[str, Any], chart: Any) -> Dict[str, Any]: """格式化返回结果""" return { 'data': data, # 'chart': chart, 'summary': self._generate_summary(data) } def _generate_summary(self, data: Dict[str, Any]) -> str: """生成数据分析总结文本""" # 保持原有的summary生成逻辑 if 'trend_data' in data: # 全国趋势数据 return ( f"在{data['date_range']['start']}至{data['date_range']['end']}期间," f"全国平均在线率为{data['statistics']['average_rate']:.2%}," f"最高达到{data['statistics']['max_rate']:.2%}," f"最低为{data['statistics']['min_rate']:.2%}。" f"平均设备总数约{data['statistics']['average_devices']:,}台。" ) elif 'rankings' in data: # 排名数据 if 'total_provinces' in data: # 省份排名 return ( f"共分析了{data['total_provinces']}个省份的在线率数据," f"平均在线率为{data['average_rate']:.2%}。" f"{data['best_province']['name']}的表现最好," f"在线率达到{data['best_province']['rate']:.2%}。" ) else: # 厂商排名 return ( f"共分析了{data['total_manufacturers']}个厂商的在线率数据," f"平均在线率为{data['average_rate']:.2%}。" f"{data['best_manufacturer']['name']}的表现最好," f"在线率达到{data['best_manufacturer']['rate']:.2%}。" ) else: # 地区在线率数据 return ( f"{data['region']}在{data['date_range']['start']}至{data['date_range']['end']}期间," f"平均在线率为{data['average_rate']:.2%}," f"最高达到{data['max_rate']:.2%}," f"最低为{data['min_rate']:.2%}。" f"平均设备数约{int(data['total_devices']):,}台。" ) class RegionRateArgs(BaseModel): """地区在线率查询参数""" region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串") start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)") end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)") class RegionRateTool(BaseRateTool): """查询全国或者特定地区设备在线率的工具""" name:str = "region_online_rate" description:str = "查询指定地区在指定时间段内的设备在线率" args_schema: Type[BaseModel] = RegionRateArgs client: Any = Field(None, exclude=True) def __init__(self, base_url: str = const_base_url, **data): super().__init__(**data) self.client = RateClient(base_url=base_url) def _run(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]: return self.get_region_online_rate(start_time, end_time, region_name) def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]: # 查询数据 agent_start = time.time() print(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}") code = "" if region_name != "": codes = code_tool.find_code(region_name) if codes is None or len(codes) == 0: return { 'code': 400, 'message': f'未找到匹配的区域代码: {region_name}' } code = codes[0][1] print(code) start = time.time() df = self.client.query_rates_sync(code, start_time, end_time) end = time.time() print(f"query_rates_sync client spent time:{end-start}") print(f"地区在线率接口调用结果: {df}") # 准备数据 if df.type != 1 or len(df.resultdata) == 0: return { 'code': 400, 'message': f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限' } print(f"地区在线率查询结果: {df.resultdata}") data = { 'region': region_name, 'region_code': code, 'rate_data': df.resultdata, 'date_range': { 'start': start_time, 'end': end_time } } end = time.time() print(f"once agent spent time:{end - agent_start}") return data class RankingRateArgs(BaseModel): """排名查询参数""" rate_type: int = Field(..., description="排序类型,用于指定查询的排名类别。1表示省份排名,2表示厂商排名") class RankingRateTool(BaseRateTool): """查询在线率排名的工具""" name:str = "online_rate_ranking" description:str = "查询设备在线率的排名数据,可查询省份排名或厂商排名" args_schema: Type[BaseModel] = RankingRateArgs client: Any = Field(None, exclude=True) def __init__(self, base_url: str = const_base_url, **data): super().__init__(**data) self.client = RateClient(base_url=base_url) def _run(self, rate_type: int) -> Dict[str, Any]: return self.get_ranking_data(rate_type) def get_ranking_data(self, rate_type: int) -> Dict[str, Any]: if rate_type == 1: return self._get_province_ranking() else: return self._get_manufacturer_ranking() def _get_province_ranking(self) -> Dict[str, Any]: """获取省份在线率排名""" df = self.client.query_rates_ranking_sync(rank_type=1) # df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10] # 准备数据 if df.type != 1 or len(df.resultdata) == 0: return { 'code': 400, 'message': f'未找到省份在线率排名数据,请检查是否有相关数据权限' } print(f"省份在线率排名数据: {df.resultdata}") data = { 'rankings': df.resultdata, 'total_provinces': len(df.resultdata), 'best_province': { 'name': df.resultdata[0]['name'], 'rate': df.resultdata[0]['onlineRate'] } } return data def _get_manufacturer_ranking(self) -> Dict[str, Any]: """获取厂商在线率排名""" df = self.client.query_rates_ranking_sync(rank_type=2) # df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10] # 准备数据 if df.type != 1 or len(df.resultdata) == 0: return { 'code': 400, 'message': f'未找到厂商在线率排名数据,请检查是否有相关数据权限' } print(f"厂商在线率排名数据: {df.resultdata}") data = { 'rankings': df.resultdata, 'total_manufacturers': len(df.resultdata), 'best_manufacturer': { 'name': df.resultdata[0]['name'], 'rate': df.resultdata[0]['onlineRate'] } } return data