tool_rate.py 20.2 KB
Newer Older
文靖昊 committed
1
import time
tinywell committed
2
import logging
3 4 5 6 7
from typing import Dict, List, Tuple, Any, Optional
from pydantic import BaseModel, Field
from typing import Type
from langchain_core.tools import BaseTool

tinywell committed
8
from .http_tools import RateClient, const_base_url
9
from .code import AreaCodeTool
10
from ..utils.logger import get_logger
11 12

# Shared area-code lookup helper; RegionRateTool._run calls its find_code()
# to turn a region name into the code passed to the rate API.
code_tool = AreaCodeTool()


class BaseRateTool(BaseTool):
    """Base class for the device online-rate analysis tools.

    Creates a per-class logger and provides helpers that pair a tool result
    with a human-readable (Chinese) summary.
    """

    # Per-class logger assigned in __init__; excluded from serialization.
    logger: Optional[logging.Logger] = Field(None, exclude=True)

    def __init__(self, **data):
        """Initialise the pydantic/BaseTool fields, then attach a logger named after the concrete class."""
        super().__init__(**data)
        self.logger = get_logger(self.__class__.__name__)

    def format_response(self, data: Dict[str, Any], chart: Any) -> Dict[str, Any]:
        """Wrap a result dict together with its generated summary.

        Args:
            data: Result dict produced by one of the rate tools.
            chart: Currently unused; kept for interface compatibility.

        Returns:
            ``{'data': data, 'summary': <summary text>}``.
        """
        self.logger.debug("格式化返回结果")
        return {
            'data': data,
            'summary': self._generate_summary(data),
        }

    def _generate_summary(self, data: Dict[str, Any]) -> str:
        """Build a short Chinese summary of a tool result.

        The result shape is detected from its keys:

        - ``'trend_data'``: national trend; reads ``date_range`` and ``statistics``.
        - ``'rankings'``: province ranking (has ``'total_provinces'``) or
          manufacturer ranking. The average-rate sentence is only emitted when
          ``'average_rate'`` is present, because ranking payloads may lack it.
        - ``'code'`` and ``'message'``: an error payload; the message is returned.
        - anything else: regional data; reads ``region``, ``date_range``,
          ``average_rate``, ``max_rate``, ``min_rate`` and ``total_devices``.

        Rates are formatted with ``:.2%`` (multiplied by 100), so they are
        expected as fractions, e.g. 0.95 -> "95.00%".

        Raises:
            KeyError: if a field required by the detected shape is missing.
        """
        self.logger.debug("生成数据分析总结")
        if 'trend_data' in data:  # national trend data
            date_range = data['date_range']
            stats = data['statistics']
            summary = (
                f"在{date_range['start']}至{date_range['end']}期间,"
                f"全国平均在线率为{stats['average_rate']:.2%},"
                f"最高达到{stats['max_rate']:.2%},"
                f"最低为{stats['min_rate']:.2%}。"
                f"平均设备总数约{stats['average_devices']:,}台。"
            )
        elif 'rankings' in data:
            if 'total_provinces' in data:  # province ranking
                summary = self._ranking_summary(
                    data['total_provinces'], '省份', data.get('average_rate'), data['best_province'])
            else:  # manufacturer ranking
                summary = self._ranking_summary(
                    data['total_manufacturers'], '厂商', data.get('average_rate'), data['best_manufacturer'])
        elif 'code' in data and 'message' in data:  # error payload returned by a tool
            summary = str(data['message'])
        else:  # regional online-rate data
            date_range = data['date_range']
            summary = (
                f"{data['region']}在{date_range['start']}至{date_range['end']}期间,"
                f"平均在线率为{data['average_rate']:.2%},"
                f"最高达到{data['max_rate']:.2%},"
                f"最低为{data['min_rate']:.2%}。"
                f"平均设备数约{int(data['total_devices']):,}台。"
            )
        self.logger.debug(f"生成的总结: {summary}")
        return summary

    @staticmethod
    def _ranking_summary(total: Any, unit: str, average: Any, best: Dict[str, Any]) -> str:
        """Summarise a ranking; the average sentence is skipped when ``average`` is None."""
        parts = [f"共分析了{total}个{unit}的在线率数据,"]
        if average is not None:
            parts.append(f"平均在线率为{average:.2%}。")
        parts.append(f"{best['name']}的表现最好,")
        parts.append(f"在线率达到{best['rate']:.2%}。")
        return "".join(parts)


class RegionRateArgs(BaseModel):
    """地区在线率查询参数"""

    # Input schema of RegionRateTool ("region online-rate query parameters").
    # The docstring above and the Field descriptions are runtime text (pydantic
    # puts them into the model's JSON schema), so they are left unchanged.

    # Empty string means nationwide data.
    region_name: str = Field("", description="地区名称,如果要查询全国数据,请输入空字符串")
    start_time: str = Field("", description="开始时间 (YYYY-MM-DD)")
    end_time: str = Field("", description="结束时间 (YYYY-MM-DD)")
    # When True the tool returns a per-month breakdown for the year of the
    # requested period instead of the date-range query.
    month_statistics: bool = Field(False, description="是否按月度查看一年内的在线率统计结果,默认不需要")
    # Comma-separated monitoring item names.
    device_items: str = Field("", description="监测项集合,多个逗号隔开,支持:滑坡仪、设备、传感器、雨量、地表位移、裂缝、倾角、加速度、土壤含水率、泥水位,默认为空")
    # Device manufacturer filter; not used by the monthly query.
    manufacturer_name: str = Field("", description="设备厂商,默认为空")


class RegionRateTool(BaseRateTool):
    """Tool that queries device online rates nationwide or for a specific region."""

    name: str = "region_online_rate"
    description: str = "查询指定地区在指定时间段内的设备在线率"
    args_schema: Type[BaseModel] = RegionRateArgs

    # HTTP client for the rate API; created in __init__ and excluded from serialization.
    client: Any = Field(None, exclude=True)

    def __init__(self, base_url: str = const_base_url, **data):
        """Create the tool and its RateClient bound to ``base_url``."""
        super().__init__(**data)
        self.client = RateClient(base_url=base_url)
        self.logger.info(f"初始化 RegionRateTool,base_url: {base_url}")

    def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str="", manufacturer_name: str="") -> Dict[str, Any]:
        """Tool entry point.

        Resolves ``region_name`` to a region code (an empty name means
        nationwide, i.e. an empty code) and dispatches either to the monthly
        query or to the date-range query.

        Returns:
            The result of the dispatched query, or ``{'code': 400, 'message': ...}``
            when the region name cannot be resolved.
        """
        code = ""
        if region_name:
            self.logger.debug(f"查找区域代码: {region_name}")
            codes = code_tool.find_code(region_name)
            if not codes:
                error_msg = f'未找到匹配的区域代码: {region_name}'
                self.logger.warning(error_msg)
                return {'code': 400, 'message': error_msg}
            # Uses element [1] of the first match -- presumably a (name, code)
            # pair; confirm against AreaCodeTool.find_code.
            code = codes[0][1]
            self.logger.debug(f"找到区域代码: {code}")

        if month_statistics:
            # Forward the monitoring-item filter as well (it used to be dropped);
            # an empty string still becomes None, as before.
            return self.get_region_online_rate_of_month(
                region_name, code, start_time, end_time, device_items or None)
        return self.get_region_online_rate(
            start_time, end_time, region_name, code, month_statistics, device_items, manufacturer_name)

    def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="",code: str="", month_statistics: bool=False, device_items: str=None, manufacturer_name: str=None) -> Dict[str, Any]:
        """Query online rates for a region over a date range.

        Args:
            start_time, end_time: Date range passed to the API.
            region_name: Display name, used in messages and the result.
            code: Region code sent to the API (empty for nationwide).
            month_statistics: Only logged here.
            device_items, manufacturer_name: Optional API filters.

        Returns:
            ``{'region', 'region_code', 'rate_data'}`` where ``rate_data`` is
            ``{'table_header': [...], 'table_data': [[...], ...]}``. The first
            record keeps its position and the rest are sorted by ascending
            online rate. Returns ``{'code': 400, 'message': ...}`` when the API
            reports no data.

        Raises:
            Any exception from the client or from processing is logged and re-raised.
        """
        agent_start = time.time()
        self.logger.info(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}, 厂商: {manufacturer_name}, 监测项: {device_items}, 按月度查询统计: {month_statistics}")

        try:
            api_start = time.time()
            df = self.client.query_rates_sync(code, start_time, end_time, manufacturer_name, device_items)
            self.logger.debug(f"API调用耗时: {time.time() - api_start:.2f}秒")

            if df.type != 1 or not df.resultdata:
                error_msg = f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
                self.logger.warning(error_msg)
                return {'code': 400, 'message': error_msg}

            self.logger.debug(f"查询结果: {df.resultdata}")

            rows = [self._extract_rate_data(item) for item in df.resultdata]
            # Keep the first record in place (presumably the queried region's
            # own aggregate -- confirm with the API) and sort the rest by rate.
            rows = rows[:1] + self._sort_by_rate(rows[1:])

            data = {
                'region': region_name,
                'region_code': code,
                'rate_data': self._build_table(rows),
            }

            total_time = time.time() - agent_start
            self.logger.info(f"查询完成 {region_name}(code: {code}) 在线率,总耗时: {total_time:.2f}秒")
            return data

        except Exception as e:
            self.logger.error(f"查询失败: {str(e)}", exc_info=True)
            raise

    def get_region_online_rate_of_month(self, region_name: str, code: str, start_time: str, end_time: str,device_items: str=None) -> Dict[str, Any]:
        """Query the per-month online rates of a region for one year.

        The year is taken from ``start_time`` (``YYYY-...``). If that is empty,
        it comes from ``end_time``, and if both are empty, from the current
        local year.

        Returns:
            ``{'region', 'region_code', 'rate_data'}`` where ``rate_data`` is a
            table whose first column '月份' holds the API's ``month`` value, with
            rows in month order. Returns ``{'code': 400, 'message': ...}`` when
            the API reports no data.

        Raises:
            Any exception from the client or from processing is logged and re-raised.
        """
        year = (start_time or end_time or time.strftime('%Y')).split('-')[0]
        self.logger.debug(f"按月度查询,年份: {year}")
        try:
            df = self.client.query_rates_month_sync(year, code, device_items)

            if df.type != 1 or not df.resultdata:
                error_msg = f'未找到{region_name}(code: {code})在{year}年内的数据,请检查是否有相关数据权限'
                self.logger.warning(error_msg)
                return {'code': 400, 'message': error_msg}

            self.logger.debug(f"查询结果: {df.resultdata}")

            # A later record for the same month overwrites an earlier one.
            by_month = {item['month']: self._extract_rate_data(item) for item in df.resultdata}
            rows = [
                {'月份': month, **by_month[month]}
                for month in sorted(by_month, key=self._month_sort_key)
            ]

            return {
                'region': region_name,
                'region_code': code,
                'rate_data': self._build_table(rows),
            }

        except Exception as e:
            self.logger.error(f"查询失败: {str(e)}", exc_info=True)
            raise

    @staticmethod
    def _sort_by_rate(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Sort rows by ascending '在线率'; rows without a rate keep their order at the end."""
        rated = [row for row in rows if '在线率' in row]
        unrated = [row for row in rows if '在线率' not in row]
        return sorted(rated, key=lambda row: row['在线率']) + unrated

    @staticmethod
    def _month_sort_key(month: Any) -> Tuple[int, Any]:
        """Order digit-only months numerically (so 2 < 10); any other value follows as text."""
        text = str(month).strip()
        if text.isdigit():
            return (0, int(text))
        return (1, text)

    @staticmethod
    def _build_table(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Convert row dicts into ``{'table_header': [...], 'table_data': [[...], ...]}``.

        _extract_rate_data omits fields whose API value is None, so rows may
        have different keys. The header is the union of all keys in first-seen
        order, and each row is laid out by header with None for missing cells.
        """
        header = list(dict.fromkeys(key for row in rows for key in row))
        return {
            'table_header': header,
            'table_data': [[row.get(key) for key in header] for row in rows],
        }

    def _extract_rate_data(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Map one API record to a dict keyed by Chinese column labels.

        Only fields whose value is not None are copied. Key order follows the
        mapping below. Used for both date-range and monthly records.
        """
        # (API field, output label). The order defines the table column order.
        # NOTE(review): "ylRate"/"ylCount" appear twice, for 应力 (stress) and
        # 雨量 (rainfall), so both columns always carry the same value. One of
        # the two API keys is probably wrong -- confirm against the rate API.
        fields = (
            ("name", "名称"),
            ("rate", "在线率"),
            ("rtuCount", "在线数量"),
            ("monitorRate", "滑坡仪在线率"),
            ("sensorRate", "传感器在线率"),
            ("monitorCount", "滑坡仪数量"),
            ("sensorCount", "传感器数量"),
            ("lfRate", "裂缝在线率"),
            ("lfCount", "裂缝数量"),
            ("gpRate", "地表位移在线率"),
            ("gpCount", "地表位移数量"),
            ("swRate", "深部位移在线率"),
            ("swCount", "深部位移数量"),
            ("jsRate", "加速度在线率"),
            ("jsCount", "加速度数量"),
            ("qjRate", "倾角在线率"),
            ("qjCount", "倾角数量"),
            ("zdRate", "振动在线率"),
            ("zdCount", "振动数量"),
            ("ylRate", "应力在线率"),
            ("ylCount", "应力数量"),
            ("tyRate", "土压力在线率"),
            ("tyCount", "土压力数量"),
            ("csRate", "次声在线率"),
            ("csCount", "次声数量"),
            ("dsRate", "地声在线率"),
            ("dsCount", "地声数量"),
            ("ylRate", "雨量在线率"),
            ("ylCount", "雨量数量"),
            ("qwRate", "气温在线率"),
            ("qwCount", "气温数量"),
            ("twRate", "土壤湿度在线率"),
            ("twCount", "土壤湿度数量"),
            ("hsRate", "土壤含水率在线率"),
            ("hsCount", "土壤含水率数量"),
            ("dbRate", "地表水温/水位在线率"),
            ("dbCount", "地表水温/水位数量"),
            ("syRate", "孔隙水温/水压在线率"),
            ("syCount", "孔隙水温/水压数量"),
            ("stRate", "渗透压力在线率"),
            ("stCount", "渗透压力数量"),
            ("lsRate", "流速在线率"),
            ("lsCount", "流速数量"),
            ("cjRate", "沉降在线率"),
            ("cjCount", "沉降数量"),
            ("qyRate", "气压在线率"),
            ("qyCount", "气压数量"),
            ("spRate", "视频在线率"),
            ("spCount", "视频数量"),
            ("nwRate", "泥水位在线率"),
            ("nwCount", "泥水位数量"),
            ("ldRate", "雷达在线率"),
            ("ldCount", "雷达数量"),
            ("lbRate", "预警喇叭在线率"),
            ("lbCount", "预警喇叭数量"),
        )
        record = {}
        for api_key, label in fields:
            value = item.get(api_key)
            if value is not None:
                record[label] = value
        return record

    def to_llm(self, region_name: str,start_time: str, end_time: str, data: List[Dict[str, Any]]) -> str:
        """Render rate records (each with 'name' and 'rate') as a numbered plain-text list for an LLM."""
        self.logger.debug("开始将数据转换为 LLM 可理解的格式")
        lines = [f"共找到{len(data)}条数据,在{start_time}至{end_time}期间,{region_name}的在线率数据如下:\n"]
        lines.extend(
            f"{index}. {item['name']} 的在线率为 {item['rate']}\n"
            for index, item in enumerate(data, start=1)
        )
        return "".join(lines)

    def to_markdown(self, data: List[Dict[str, Any]]) -> str:
        """Render rate records (each with 'name' and 'rate') as a markdown table."""
        self.logger.debug("开始生成 markdown 表格")
        header = "\n| 序号 | 地区 | 在线率 |\n| --- | --- | --- | \n"
        body = "".join(
            f"| {index} | {row['name']} | {row['rate']} | \n"
            for index, row in enumerate(data, start=1)
        )
        self.logger.debug("markdown 表格生成完成")
        return header + body


class RankingRateArgs(BaseModel):
    """排名查询参数"""

    # Input schema of RankingRateTool ("ranking query parameters"). The
    # docstring and Field description are runtime text (part of the pydantic
    # JSON schema) and are left unchanged.

    # 1 selects the province ranking; RankingRateTool treats any other value
    # as the manufacturer ranking.
    rate_type: int = Field(..., description="排序类型,用于指定查询的排名类别。1表示省份排名,2表示厂商排名")


class RankingRateTool(BaseRateTool):
    """Tool that queries online-rate rankings by province or by manufacturer."""

    name: str = "online_rate_ranking"
    description: str = "查询设备在线率的排名数据,可查询省份排名或厂商排名"
    args_schema: Type[BaseModel] = RankingRateArgs

    # HTTP client for the rate API; created in __init__ and excluded from serialization.
    client: Any = Field(None, exclude=True)

    def __init__(self, base_url: str = const_base_url, **data):
        """Create the tool and its RateClient bound to ``base_url``."""
        super().__init__(**data)
        self.client = RateClient(base_url=base_url)
        self.logger.info(f"初始化 RankingRateTool,base_url: {base_url}")

    def _run(self, rate_type: int) -> Dict[str, Any]:
        """Tool entry point; delegates to get_ranking_data."""
        return self.get_ranking_data(rate_type)

    def get_ranking_data(self, rate_type: int) -> Dict[str, Any]:
        """Return the province ranking when ``rate_type == 1``, otherwise the manufacturer ranking."""
        if rate_type == 1:
            return self._get_province_ranking()
        return self._get_manufacturer_ranking()

    def _get_province_ranking(self) -> Dict[str, Any]:
        """Fetch the province online-rate ranking (API rank_type 1)."""
        return self._get_ranking(
            rank_type=1, label="省份", total_key='total_provinces', best_key='best_province')

    def _get_manufacturer_ranking(self) -> Dict[str, Any]:
        """Fetch the manufacturer online-rate ranking (API rank_type 2)."""
        return self._get_ranking(
            rank_type=2, label="厂商", total_key='total_manufacturers', best_key='best_manufacturer')

    def _get_ranking(self, rank_type: int, label: str, total_key: str, best_key: str) -> Dict[str, Any]:
        """Shared implementation of the province/manufacturer ranking queries.

        Args:
            rank_type: Value passed to ``query_rates_ranking_sync``.
            label: Chinese noun used in log and error messages.
            total_key: Result key that holds the number of ranked entries.
            best_key: Result key that holds the first entry's name and rate.

        Returns:
            ``{'rankings', total_key, best_key, 'markdown'}``, or
            ``{'code': 400, 'message': ...}`` when the API reports no data.

        Raises:
            Any exception from the client or from processing is logged and re-raised.
        """
        self.logger.info(f"开始查询{label}在线率排名")
        try:
            df = self.client.query_rates_ranking_sync(rank_type=rank_type)

            if df.type != 1 or not df.resultdata:
                error_msg = f'未找到{label}在线率排名数据,请检查是否有相关数据权限'
                self.logger.warning(error_msg)
                return {'code': 400, 'message': error_msg}

            rankings = df.resultdata
            self.logger.debug(f"{label}排名数据: {rankings}")
            # The first entry is reported as the best one -- this assumes the
            # API returns rankings best-first; confirm with the API.
            top = rankings[0]
            data = {
                'rankings': rankings,
                total_key: len(rankings),
                best_key: {
                    'name': top['name'],
                    'rate': top['onlineRate'],
                },
                'markdown': self.to_markdown(rankings),
            }

            self.logger.info(f"查询完成,共 {len(rankings)} 个{label}")
            return data

        except Exception as e:
            self.logger.error(f"查询{label}排名失败: {str(e)}", exc_info=True)
            raise

    def to_markdown(self, data: List[Dict[str, Any]]) -> str:
        """Render ranking records as a markdown table (index, name, full name, online rate).

        A record without 'fullname' gets an empty cell instead of aborting the query.
        """
        self.logger.debug("开始生成 markdown 表格")
        header = "\n| 序号 | 名称 | 全称 | 在线率 |\n| --- | --- | --- | --- | \n"
        body = "".join(
            f"| {index} | {row['name']} | {row.get('fullname', '')} | {row['onlineRate']} | \n"
            for index, row in enumerate(data, start=1)
        )
        self.logger.debug("markdown 表格生成完成")
        return header + body