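"""tool_rate.py: device online-rate analysis tools.

LangChain tools that query device online rates for a region over a time range
(or month by month for one year) and online-rate rankings by province or
manufacturer.
"""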
import logging
import time
from typing import Dict, List, Tuple, Any, Optional
from typing import Type

from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field

from .http_tools import RateClient, const_base_url
from .code import AreaCodeTool
from ..utils.logger import get_logger

# Module-level area-code lookup; RegionRateTool._run uses it to resolve a
# region name to its area code.
code_tool = AreaCodeTool()


class BaseRateTool(BaseTool):
    """Base class for the device online-rate analysis tools.

    Gives every subclass a logger named after the concrete class, plus helpers
    that wrap a result dict together with a generated (Chinese) summary.
    """

    # Assigned in __init__; excluded from the model's serialized output.
    logger: logging.Logger = Field(None, exclude=True)

    def __init__(self, **data):
        super().__init__(**data)
        self.logger = get_logger(self.__class__.__name__)

    def format_response(self, data: Dict[str, Any], chart: Any) -> Dict[str, Any]:
        """Wrap ``data`` together with its generated summary.

        Args:
            data: Result dict produced by one of the rate tools.
            chart: Not used by this implementation; kept for interface
                compatibility.

        Returns:
            ``{'data': data, 'summary': <summary text>}``.
        """
        self.logger.debug("格式化返回结果")
        return {
            'data': data,
            'summary': self._generate_summary(data),
        }

    @staticmethod
    def _format_rate(value: Any) -> str:
        """Render a rate as a percentage with two decimals.

        Non-numeric values (e.g. a rate the API returned as a string) are
        rendered verbatim instead of making ``:.2%`` raise ``ValueError``.
        """
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return f"{value:.2%}"
        return str(value)

    @staticmethod
    def _date_range_clause(data: Dict[str, Any]) -> str:
        """Return the "在<start>至<end>期间," prefix, or "" if no date_range is given."""
        date_range = data.get('date_range')
        if not date_range:
            return ""
        return f"在{date_range['start']}至{date_range['end']}期间,"

    def _rate_clauses(self, stats: Dict[str, Any], average_label: str) -> str:
        """Build the "average, max, min" sentence from whichever rates are present.

        Args:
            stats: Mapping that may contain ``average_rate``, ``max_rate`` and
                ``min_rate``.
            average_label: Text placed before the average rate.

        Returns:
            The clauses joined by "," and terminated by "。", or "" when none
            of the three rates is present.
        """
        clauses = []
        if stats.get('average_rate') is not None:
            clauses.append(f"{average_label}{self._format_rate(stats['average_rate'])}")
        if stats.get('max_rate') is not None:
            clauses.append(f"最高达到{self._format_rate(stats['max_rate'])}")
        if stats.get('min_rate') is not None:
            clauses.append(f"最低为{self._format_rate(stats['min_rate'])}")
        return ",".join(clauses) + "。" if clauses else ""

    def _generate_summary(self, data: Dict[str, Any]) -> str:
        """Generate a short summary paragraph for a rate-tool result.

        The branch is chosen by the keys present in ``data``:

        * ``trend_data``: national trend (statistics under ``statistics``).
        * ``rankings``: province ranking when ``total_provinces`` is present,
          otherwise manufacturer ranking.
        * anything else: regional online-rate data.

        Every statistic is optional. A clause whose key is absent is left out
        instead of raising ``KeyError``, because the ranking and region results
        built by the tools in this module carry no ``average_rate`` /
        ``date_range``. When all keys are present, the text matches the full
        summary.
        """
        self.logger.debug("生成数据分析总结")
        if 'trend_data' in data:  # national trend data
            stats = data.get('statistics') or {}
            summary = self._date_range_clause(data) + self._rate_clauses(stats, "全国平均在线率为")
            if stats.get('average_devices') is not None:
                summary += f"平均设备总数约{stats['average_devices']:,}台。"
        elif 'rankings' in data:  # ranking data
            if 'total_provinces' in data:  # province ranking
                summary = f"共分析了{data['total_provinces']}个省份的在线率数据,"
                best = data.get('best_province')
            else:  # manufacturer ranking
                summary = f"共分析了{data['total_manufacturers']}个厂商的在线率数据,"
                best = data.get('best_manufacturer')
            if data.get('average_rate') is not None:
                summary += f"平均在线率为{self._format_rate(data['average_rate'])}。"
            if best:
                summary += f"{best['name']}的表现最好,在线率达到{self._format_rate(best['rate'])}。"
            # If nothing followed the count clause, close the sentence properly.
            if summary.endswith(","):
                summary = summary[:-1] + "。"
        else:  # regional online-rate data
            summary = (
                str(data.get('region', ''))
                + self._date_range_clause(data)
                + self._rate_clauses(data, "平均在线率为")
            )
            if data.get('total_devices') is not None:
                summary += f"平均设备数约{int(data['total_devices']):,}台。"

        self.logger.debug(f"生成的总结: {summary}")
        return summary


class RegionRateArgs(BaseModel):
    """Arguments for the regional online-rate query tool."""

    # Empty string means "nationwide".
    region_name: str = Field("", description="地区名称,如果要查询全国数据,请输入空字符串")
    start_time: str = Field("", description="开始时间 (YYYY-MM-DD)")
    end_time: str = Field("", description="结束时间 (YYYY-MM-DD)")
    # When true, RegionRateTool runs the month-by-month query for the year
    # taken from start_time.
    month_statistics: bool = Field(False, description="是否按月度查看一年内的在线率统计结果,默认不需要")
    # Comma-separated monitoring items used to filter the query.
    device_items: str = Field("", description="监测项集合,多个逗号隔开,支持:滑坡仪、设备、传感器、雨量、地表位移、裂缝、倾角、加速度、土壤含水率、泥水位,默认为空")
    manufacturer_name: str = Field("", description="设备厂商,默认为空")


class RegionRateTool(BaseRateTool):
    """Tool that queries device online rates nationwide or for one region.

    It supports two kinds of query:

    * a plain query over a date range, optionally filtered by manufacturer
      and monitoring items;
    * a month-by-month query for the year of ``start_time``.
    """

    name: str = "region_online_rate"
    description: str = "查询指定地区在指定时间段内的设备在线率"
    args_schema: Type[BaseModel] = RegionRateArgs

    # RateClient created in __init__; excluded from the model's serialized output.
    client: Any = Field(None, exclude=True)

    def __init__(self, base_url: str = const_base_url, **data):
        super().__init__(**data)
        self.client = RateClient(base_url=base_url)
        self.logger.info(f"初始化 RegionRateTool,base_url: {base_url}")

    def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False, device_items: str="", manufacturer_name: str="") -> Dict[str, Any]:
        """Resolve ``region_name`` to an area code and dispatch the query.

        An empty ``region_name`` queries nationwide data with an empty code.

        Returns:
            The result dict of the chosen query, or
            ``{'code': 400, 'message': ...}`` when the region cannot be resolved.
        """
        code = ""
        if region_name:
            self.logger.debug(f"查找区域代码: {region_name}")
            codes = code_tool.find_code(region_name)
            if not codes:
                error_msg = f'未找到匹配的区域代码: {region_name}'
                self.logger.warning(error_msg)
                return {'code': 400, 'message': error_msg}
            # Takes element [1] of the first match; presumably (name, code)
            # pairs ordered best match first. TODO confirm against AreaCodeTool.
            code = codes[0][1]
            self.logger.debug(f"找到区域代码: {code}")

        if month_statistics:
            # Forward the monitoring-item filter as well. The monthly client
            # call used below takes only (year, code, device_items), so
            # manufacturer_name is not forwarded.
            return self.get_region_online_rate_of_month(
                region_name, code, start_time, end_time, device_items=device_items
            )
        return self.get_region_online_rate(
            start_time, end_time, region_name, code, month_statistics, device_items, manufacturer_name
        )

    def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="", code: str="", month_statistics: bool=False, device_items: Optional[str]=None, manufacturer_name: Optional[str]=None) -> Dict[str, Any]:
        """Query online rates for area ``code`` between ``start_time`` and ``end_time``.

        Returns:
            ``{'region', 'region_code', 'rate_data'}``, where ``rate_data`` is a
            ``{'table_header', 'table_data'}`` table. Returns
            ``{'code': 400, 'message': ...}`` when the API returns no data.

        Raises:
            Any exception raised while querying or processing is logged and
            re-raised.
        """
        agent_start = time.time()
        self.logger.info(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}, 厂商: {manufacturer_name}, 监测项: {device_items}, 按月度查询统计: {month_statistics}")

        try:
            start = time.time()
            df = self.client.query_rates_sync(code, start_time, end_time, manufacturer_name, device_items)
            query_time = time.time() - start
            self.logger.debug(f"API调用耗时: {query_time:.2f}秒")

            if df.type != 1 or not df.resultdata:
                error_msg = f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
                self.logger.warning(error_msg)
                return {'code': 400, 'message': error_msg}

            self.logger.debug(f"查询结果: {df.resultdata}")

            rows = [self._extract_rate_data(item) for item in df.resultdata]
            # Keep the first row in place (presumably the aggregate row for the
            # queried region; TODO confirm) and order the remaining rows by
            # online rate, lowest first.
            rows = rows[:1] + sorted(rows[1:], key=self._rate_sort_key)

            data = {
                'region': region_name,
                'region_code': code,
                'rate_data': self._to_table(rows),
            }

            total_time = time.time() - agent_start
            self.logger.info(f"查询完成 {region_name}(code: {code}) 在线率,总耗时: {total_time:.2f}秒")
            return data

        except Exception as e:
            self.logger.error(f"查询失败: {str(e)}", exc_info=True)
            raise

    def get_region_online_rate_of_month(self, region_name: str, code: str, start_time: str, end_time: str, device_items: Optional[str]=None) -> Dict[str, Any]:
        """Query month-by-month online rates for one year.

        Only the year part of ``start_time`` is used, and ``end_time`` is
        ignored. When ``start_time`` is empty (the argument schema's default),
        the current local year is used.

        Returns:
            ``{'region', 'region_code', 'rate_data'}``, where ``rate_data`` is a
            table whose first column ``月份`` holds each record's ``month``
            value. Returns ``{'code': 400, 'message': ...}`` when the API
            returns no data.

        Raises:
            Any exception raised while querying or processing is logged and
            re-raised.
        """
        year = start_time.split('-')[0] if start_time else time.strftime('%Y')
        self.logger.debug(f"按月度查询,年份: {year}")
        try:
            df = self.client.query_rates_month_sync(year, code, device_items)

            if df.type != 1 or not df.resultdata:
                error_msg = f'未找到{region_name}(code: {code})在{year}年内的数据,请检查是否有相关数据权限'
                self.logger.warning(error_msg)
                return {'code': 400, 'message': error_msg}

            self.logger.debug(f"查询结果: {df.resultdata}")

            # One entry per month value; a later record for the same month
            # replaces an earlier one.
            by_month: Dict[Any, Dict[str, Any]] = {}
            for item in df.resultdata:
                by_month[item['month']] = self._extract_rate_data(item)

            # Sort by month and put the month into a leading column so every
            # table row stays identifiable.
            rows = [
                {'月份': month, **rate}
                for month, rate in sorted(by_month.items(), key=lambda pair: pair[0])
            ]

            return {
                'region': region_name,
                'region_code': code,
                'rate_data': self._to_table(rows),
            }

        except Exception as e:
            self.logger.error(f"查询失败: {str(e)}", exc_info=True)
            raise

    @staticmethod
    def _rate_sort_key(row: Dict[str, Any]) -> Tuple[bool, Any]:
        """Sort key on the '在线率' column; rows without a rate sort last."""
        rate = row.get('在线率')
        return (rate is None, rate)

    @staticmethod
    def _to_table(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Convert row dicts into ``{'table_header': [...], 'table_data': [[...], ...]}``.

        ``_extract_rate_data`` omits fields whose value is None, so rows can
        have different keys. The header is therefore the union of all row
        keys in first-seen order, and each row is aligned to it. Cells a row
        does not have become None.
        """
        header = list(dict.fromkeys(key for row in rows for key in row))
        table_data = [[row.get(key) for key in header] for row in rows]
        return {
            'table_header': header,
            'table_data': table_data,
        }

    def _extract_rate_data(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Map one API rate record to a dict keyed by Chinese column labels.

        Fields that are missing or None are omitted. Key order follows
        ``field_labels`` below, which also becomes the table column order.
        """
        # (API field, column label) pairs, in output order.
        # NOTE(review): 'ylRate'/'ylCount' are mapped to both 应力 (stress) and
        # 雨量 (rainfall), so both columns always show the same value. One of
        # them presumably has a different API field name; confirm against the
        # API before changing.
        field_labels = (
            ("name", '名称'),
            ("rate", '在线率'),
            ("rtuCount", '在线数量'),
            ("monitorRate", '滑坡仪在线率'),
            ("sensorRate", '传感器在线率'),
            ("monitorCount", '滑坡仪数量'),
            ("sensorCount", '传感器数量'),
            ("lfRate", '裂缝在线率'),
            ("lfCount", '裂缝数量'),
            ("gpRate", '地表位移在线率'),
            ("gpCount", '地表位移数量'),
            ("swRate", '深部位移在线率'),
            ("swCount", '深部位移数量'),
            ("jsRate", '加速度在线率'),
            ("jsCount", '加速度数量'),
            ("qjRate", '倾角在线率'),
            ("qjCount", '倾角数量'),
            ("zdRate", '振动在线率'),
            ("zdCount", '振动数量'),
            ("ylRate", '应力在线率'),
            ("ylCount", '应力数量'),
            ("tyRate", '土压力在线率'),
            ("tyCount", '土压力数量'),
            ("csRate", '次声在线率'),
            ("csCount", '次声数量'),
            ("dsRate", '地声在线率'),
            ("dsCount", '地声数量'),
            ("ylRate", '雨量在线率'),
            ("ylCount", '雨量数量'),
            ("qwRate", '气温在线率'),
            ("qwCount", '气温数量'),
            ("twRate", '土壤湿度在线率'),
            ("twCount", '土壤湿度数量'),
            ("hsRate", '土壤含水率在线率'),
            ("hsCount", '土壤含水率数量'),
            ("dbRate", '地表水温/水位在线率'),
            ("dbCount", '地表水温/水位数量'),
            ("syRate", '孔隙水温/水压在线率'),
            ("syCount", '孔隙水温/水压数量'),
            ("stRate", '渗透压力在线率'),
            ("stCount", '渗透压力数量'),
            ("lsRate", '流速在线率'),
            ("lsCount", '流速数量'),
            ("cjRate", '沉降在线率'),
            ("cjCount", '沉降数量'),
            ("qyRate", '气压在线率'),
            ("qyCount", '气压数量'),
            ("spRate", '视频在线率'),
            ("spCount", '视频数量'),
            ("nwRate", '泥水位在线率'),
            ("nwCount", '泥水位数量'),
            ("ldRate", '雷达在线率'),
            ("ldCount", '雷达数量'),
            ("lbRate", '预警喇叭在线率'),
            ("lbCount", '预警喇叭数量'),
        )
        rate_data: Dict[str, Any] = {}
        for api_key, label in field_labels:
            value = item.get(api_key)
            if value is not None:
                rate_data[label] = value
        return rate_data

    def to_llm(self, region_name: str, start_time: str, end_time: str, data: List[Dict[str, Any]]) -> str:
        """Render rate records as a numbered plain-text list for an LLM.

        Each record in ``data`` must provide ``name`` and ``rate`` keys.
        """
        self.logger.debug("开始将数据转换为 LLM 可理解的格式")
        header = f"共找到{len(data)}条数据,在{start_time}至{end_time}期间,{region_name}的在线率数据如下:\n"
        lines = [
            f"{index}. {item['name']} 的在线率为 {item['rate']}\n"
            for index, item in enumerate(data, start=1)
        ]
        return header + "".join(lines)

    def to_markdown(self, data: List[Dict[str, Any]]) -> str:
        """Render records with ``name``/``rate`` keys as a markdown table."""
        self.logger.debug("开始生成 markdown 表格")
        header = "\n| 序号 | 地区 | 在线率 |\n| --- | --- | --- | \n"
        rows = "".join(
            f"| {index} | {row['name']} | {row['rate']} | \n"
            for index, row in enumerate(data, start=1)
        )
        self.logger.debug("markdown 表格生成完成")
        return header + rows


class RankingRateArgs(BaseModel):
    """Arguments for the online-rate ranking tool."""

    # Required. 1 = province ranking, 2 = manufacturer ranking.
    rate_type: int = Field(..., description="排序类型,用于指定查询的排名类别。1表示省份排名,2表示厂商排名")


class RankingRateTool(BaseRateTool):
    """Tool that queries online-rate rankings by province or by manufacturer."""

    name: str = "online_rate_ranking"
    description: str = "查询设备在线率的排名数据,可查询省份排名或厂商排名"
    args_schema: Type[BaseModel] = RankingRateArgs

    # RateClient created in __init__; excluded from the model's serialized output.
    client: Any = Field(None, exclude=True)

    def __init__(self, base_url: str = const_base_url, **data):
        super().__init__(**data)
        self.client = RateClient(base_url=base_url)
        self.logger.info(f"初始化 RankingRateTool,base_url: {base_url}")

    def _run(self, rate_type: int) -> Dict[str, Any]:
        """Tool entry point; delegates to ``get_ranking_data``."""
        return self.get_ranking_data(rate_type)

    def get_ranking_data(self, rate_type: int) -> Dict[str, Any]:
        """Fetch a ranking by type.

        Args:
            rate_type: 1 for the province ranking, 2 for the manufacturer
                ranking.

        Returns:
            The ranking result, or ``{'code': 400, 'message': ...}`` for an
            unsupported ``rate_type``. Unsupported values no longer fall back
            silently to the manufacturer ranking.
        """
        if rate_type == 1:
            return self._get_province_ranking()
        if rate_type == 2:
            return self._get_manufacturer_ranking()
        error_msg = f'不支持的排名类型: {rate_type},请使用1(省份排名)或2(厂商排名)'
        self.logger.warning(error_msg)
        return {'code': 400, 'message': error_msg}

    def _get_province_ranking(self) -> Dict[str, Any]:
        """Fetch the province online-rate ranking."""
        return self._get_ranking(1, "省份", 'total_provinces', 'best_province')

    def _get_manufacturer_ranking(self) -> Dict[str, Any]:
        """Fetch the manufacturer online-rate ranking."""
        return self._get_ranking(2, "厂商", 'total_manufacturers', 'best_manufacturer')

    def _get_ranking(self, rank_type: int, label: str, total_key: str, best_key: str) -> Dict[str, Any]:
        """Fetch one ranking list and package it for the caller.

        Args:
            rank_type: Value passed to ``query_rates_ranking_sync(rank_type=...)``.
            label: Chinese noun ("省份" or "厂商") used in log and error messages.
            total_key: Result key holding the number of ranked entries.
            best_key: Result key holding ``{'name', 'rate'}`` of the first entry.

        Returns:
            ``{'rankings', total_key, best_key, 'markdown'}``, or
            ``{'code': 400, 'message': ...}`` when the API returns no data.

        Raises:
            Any exception raised while querying or processing is logged and
            re-raised.
        """
        self.logger.info(f"开始查询{label}在线率排名")
        try:
            df = self.client.query_rates_ranking_sync(rank_type=rank_type)

            if df.type != 1 or not df.resultdata:
                error_msg = f'未找到{label}在线率排名数据,请检查是否有相关数据权限'
                self.logger.warning(error_msg)
                return {'code': 400, 'message': error_msg}

            rankings = df.resultdata
            self.logger.debug(f"{label}排名数据: {rankings}")
            markdown = self.to_markdown(rankings)

            # The first entry is reported as the best one; this assumes the
            # API returns rankings ordered best-first (TODO confirm).
            top = rankings[0]
            data = {
                'rankings': rankings,
                total_key: len(rankings),
                best_key: {
                    'name': top['name'],
                    'rate': top['onlineRate'],
                },
                'markdown': markdown,
            }

            self.logger.info(f"查询完成,共 {len(rankings)} 个{label}")
            return data

        except Exception as e:
            self.logger.error(f"查询{label}排名失败: {str(e)}", exc_info=True)
            raise

    def to_markdown(self, data: List[Dict[str, Any]]) -> str:
        """Render ranking records (``name``, ``fullname``, ``onlineRate``) as a markdown table."""
        self.logger.debug("开始生成 markdown 表格")
        header = "\n| 序号 | 名称 | 全称 | 在线率 |\n| --- | --- | --- | --- | \n"
        rows = "".join(
            f"| {index} | {row['name']} | {row['fullname']} | {row['onlineRate']} | \n"
            for index, row in enumerate(data, start=1)
        )
        self.logger.debug("markdown 表格生成完成")
        return header + rows