Commit 991ebf04 by tinywell

预警接口及工具

parent d7261e1c
...@@ -3,6 +3,9 @@ from typing import TypeVar, Generic, Any, Optional, List, Dict ...@@ -3,6 +3,9 @@ from typing import TypeVar, Generic, Any, Optional, List, Dict
from pydantic import BaseModel from pydantic import BaseModel
from urllib.parse import urljoin from urllib.parse import urljoin
import time
from ..utils.logger import get_logger
# 泛型类型定义 # 泛型类型定义
T = TypeVar('T') T = TypeVar('T')
...@@ -30,7 +33,8 @@ class BaseHttpClient: ...@@ -30,7 +33,8 @@ class BaseHttpClient:
"""基础HTTP客户端""" """基础HTTP客户端"""
def __init__(self, base_url: str = const_base_url): def __init__(self, base_url: str = const_base_url):
self.base_url = base_url.rstrip('/') self.base_url = base_url.rstrip('/')
self.timeout = 30.0 self.timeout = 60.0
self.logger = get_logger(self.__class__.__name__)
async def _request_async(self, method: str, endpoint: str, **kwargs) -> Any: async def _request_async(self, method: str, endpoint: str, **kwargs) -> Any:
"""通用异步请求方法""" """通用异步请求方法"""
...@@ -42,11 +46,17 @@ class BaseHttpClient: ...@@ -42,11 +46,17 @@ class BaseHttpClient:
def _request_sync(self, method: str, endpoint: str, **kwargs) -> Any: def _request_sync(self, method: str, endpoint: str, **kwargs) -> Any:
"""通用同步请求方法""" """通用同步请求方法"""
self.logger.info(f"请求URL: {urljoin(self.base_url, endpoint)},请求参数: {kwargs}")
start_time = time.time()
result = None
with httpx.Client(timeout=self.timeout) as client: with httpx.Client(timeout=self.timeout) as client:
url = urljoin(self.base_url, endpoint) url = urljoin(self.base_url, endpoint)
response = client.request(method, url, **kwargs) response = client.request(method, url, **kwargs)
response.raise_for_status() response.raise_for_status()
return response.json() result = response.json()
end_time = time.time()
self.logger.info(f"请求耗时: {end_time - start_time}秒")
return result
class MonitorPoint(BaseModel): class MonitorPoint(BaseModel):
...@@ -195,14 +205,14 @@ class RateClient(BaseHttpClient): ...@@ -195,14 +205,14 @@ class RateClient(BaseHttpClient):
) )
return BaseResponse[List](**data) return BaseResponse[List](**data)
async def query_rates_ranking(self, rank_type: int) -> BaseResponse[List]: async def query_rates_ranking(self, rank_type: int) -> BaseResponse[Dict]:
"""异步查询在线率排名信息""" """异步查询在线率排名信息"""
data = await self._request_async( data = await self._request_async(
"POST", "POST",
const_url_rate_ranking, const_url_rate_ranking,
json={'type': rank_type} json={'type': rank_type}
) )
return BaseResponse[List](**data) return BaseResponse[Dict](**data)
def query_rates_month_sync(self, year: str, areaCode: str, typeArr: str) -> BaseResponse[List]: def query_rates_month_sync(self, year: str, areaCode: str, typeArr: str) -> BaseResponse[List]:
"""同步查询按月度统计的在线率信息""" """同步查询按月度统计的在线率信息"""
...@@ -220,7 +230,7 @@ class RateClient(BaseHttpClient): ...@@ -220,7 +230,7 @@ class RateClient(BaseHttpClient):
class WarningClient(BaseHttpClient): class WarningClient(BaseHttpClient):
"""预警查询客户端""" """预警查询客户端"""
def query_warning_statistics(self, start_time: str, end_time: str, area_code: str) -> BaseResponse[List]: def query_warning_statistics(self, start_time: str, end_time: str, area_code: str) -> BaseResponse[Dict]:
"""同步查询预警统计信息""" """同步查询预警统计信息"""
data = self._request_sync( data = self._request_sync(
"POST", "POST",
...@@ -231,7 +241,7 @@ class WarningClient(BaseHttpClient): ...@@ -231,7 +241,7 @@ class WarningClient(BaseHttpClient):
'areaCode': area_code 'areaCode': area_code
} }
) )
return BaseResponse[List](**data) return BaseResponse[Dict](**data)
def query_warning_month_statistics(self, year: str, areaCode: str) -> BaseResponse[List]: def query_warning_month_statistics(self, year: str, areaCode: str) -> BaseResponse[List]:
"""同步查询按月度统计的预警统计信息""" """同步查询按月度统计的预警统计信息"""
......
...@@ -156,7 +156,8 @@ class RegionRateTool(BaseRateTool): ...@@ -156,7 +156,8 @@ class RegionRateTool(BaseRateTool):
month_data = self._extract_rate_data(item) month_data = self._extract_rate_data(item)
result_data[item['month']] = month_data result_data[item['month']] = month_data
# 排序 # 排序
result_data = sorted(result_data.items(), key=lambda x: x[0])
self.logger.debug(f"查询结果: {df.resultdata}") self.logger.debug(f"查询结果: {df.resultdata}")
# markdown = self.to_markdown(df.resultdata) # markdown = self.to_markdown(df.resultdata)
......
...@@ -7,30 +7,128 @@ ...@@ -7,30 +7,128 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Any, Dict, Type from typing import Any, Dict, Type
import logging
from langchain_core.tools import BaseTool
from .http_tools import WarningClient, const_base_url from .http_tools import WarningClient, const_base_url
from .code import AreaCodeTool
from ..utils.logger import get_logger
code_tool = AreaCodeTool()
class WarningArgs(BaseModel): class WarningArgs(BaseModel):
"""预警查询参数""" """预警查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)") start_time: str = Field("", description="开始时间 (YYYY-MM-DD HH:mm:ss)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)") end_time: str = Field("", description="结束时间 (YYYY-MM-DD HH:mm:ss)")
region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串") region_name: str = Field("", description="地区名称,如果要查询全国数据,请输入空字符串")
month_statistics: bool = Field(False, description="是否按月度查询一年内的统计结果,默认不需要")
class WarningTool(BaseTool): class WarningTool(BaseTool):
"""查询预警处置和虚警情况""" """查询预警处置和虚警情况"""
name: str = "warning_statistics" name: str = "warning_statistics"
description: str = "查询不同时间段不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持按月度统计一年内的虚警率,处置率。" description: str = "查询一定时间范围内不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持查询指定年份全年按月度统计的虚警率,处置率等信息。"
args_schema: Type[BaseModel] = WarningArgs args_schema: Type[BaseModel] = WarningArgs
client: Any = Field(None, exclude=True) client: Any = Field(None, exclude=True)
logger: logging.Logger = Field(None, exclude=True)
def __init__(self, base_url: str = const_base_url, **data): def __init__(self, base_url: str = const_base_url, **data):
super().__init__(**data) super().__init__(**data)
self.client = WarningClient(base_url=base_url) self.client = WarningClient(base_url=base_url)
self.logger = get_logger("WarningTool")
self.logger.info(f"初始化 WarningTool,base_url: {base_url}") self.logger.info(f"初始化 WarningTool,base_url: {base_url}")
def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False) -> Dict[str, Any]: def _run(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]:
return self.get_warning_statistics(start_time, end_time, region_name, month_statistics) code = ""
if region_name != "":
self.logger.debug(f"查找区域代码: {region_name}")
codes = code_tool.find_code(region_name)
if codes is None or len(codes) == 0:
error_msg = f'未找到匹配的区域代码: {region_name}'
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
code = codes[0][1]
self.logger.debug(f"找到区域代码: {code}")
year = start_time.split("-")[0]
detail = self.get_warning_statistics(start_time, end_time, region_name, code)
monthly = self.get_warning_statistics_of_month(year, region_name, code)
if monthly['code'] != 200:
return monthly
return {
"整体详细数据": detail['data'] if detail['code'] == 200 else detail['message'],
"各月数据": monthly['data'] if monthly['code'] == 200 else monthly['message']
}
def get_warning_statistics(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False) -> Dict[str, Any]: def get_warning_statistics(self, start_time: str, end_time: str, region_name: str="", code: str="") -> Dict[str, Any]:
pass try:
\ No newline at end of file response = self.client.query_warning_statistics(start_time, end_time, code)
self.logger.debug(f"API响应: {response}")
if response.type != 1 or len(response.resultdata) == 0:
error_msg = f"查询失败: {response.message},请检查是否有相关数据权限"
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
data = {
"预警消息个数": response.resultdata["num"],
"处置消息个数": response.resultdata["closenum"],
"处置率": response.resultdata["closeper"],
"虚警消息个数": response.resultdata["falsenum"],
"虚警率": response.resultdata["falseper"],
"红色预警消息个数": response.resultdata["rednum"],
"红色处置消息个数": response.resultdata["redcloseper"],
"橙色预警消息个数": response.resultdata["orangenum"],
"橙色处置消息个数": response.resultdata["orangecloseper"],
"黄色预警消息个数": response.resultdata["yellownum"],
"黄色处置消息个数": response.resultdata["yellowcloseper"],
"蓝色预警消息个数": response.resultdata["bluenum"],
"蓝色处置消息个数": response.resultdata["bluecloseper"],
"数据异常消息个数": response.resultdata["datanum"],
"数据异常消息占比": response.resultdata["datacloseper"],
"设备维护": response.resultdata["devicemainnum"],
"设备维护占比": response.resultdata["devicemaincloseper"],
"设备遭到破坏": response.resultdata["damagenum"],
"设备遭到破坏占比": response.resultdata["damagecloseper"],
"模型待优化": response.resultdata["modelnum"],
"模型待优化占比": response.resultdata["modelcloseper"],
}
return {
'code': 200,
'data': data
}
except Exception as e:
self.logger.error(f"查询预警统计信息失败: {e}")
return {'code': 400, 'message': str(e)}
def get_warning_statistics_of_month(self, year: str, region_name: str="", code: str="") -> Dict[str, Any]:
try:
response = self.client.query_warning_month_statistics(year, code)
self.logger.debug(f"API响应: {response}")
if response.type != 1 or len(response.resultdata) == 0:
error_msg = f"查询失败: {response.message},请检查是否有相关数据权限"
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
data = {}
for item in response.resultdata:
month = {
"预警数量": item["num"],
"处置率": item["closeper"],
"处置数量": item["closenum"],
"虚警率": item["falseper"],
"虚警数量": item["falsenum"],
}
data[item["month"]] = month
data = sorted(data.items(), key=lambda x: x[0])
return {
'code': 200,
'data': data
}
except Exception as e:
self.logger.error(f"查询预警统计信息失败: {e}")
return {'code': 400, 'message': str(e)}
from langchain_openai import ChatOpenAI
from rich.console import Console
from rich.table import Table
import sys,os
sys.path.append("../")
from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
from src.agent.tool_warn import WarningTool
def run_examples():
# 初始化 rich console
console = Console()
# 初始化 LLM
llm = ChatOpenAI(
openai_api_key="xxxxxx",
openai_api_base="http://192.168.10.14:8000/v1",
model_name="Qwen2-7B",
verbose=True
)
base_url = "http://172.30.0.37:30007"
# 初始化工具
tools = [
RegionRateTool(base_url=base_url),
RankingRateTool(base_url=base_url),
MonitorPointTool(base_url=base_url),
WarningTool(base_url=base_url),
]
tool_dict = {tool.name: tool for tool in tools}
# 初始化 ToolPicker
picker = ToolPicker(llm, tools)
# 测试案例和预期结果
test_cases = [
{
"query": "查询2024年4月到5月甘肃省的预警情况",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2024-04-01 00:00:00",
"end_time": "2024-05-31 23:59:59",
"region_name": "甘肃省",
}
}
},{
"query": "查询2024年4月到5月甘肃省的处置率是多少",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2024-04-01 00:00:00",
"end_time": "2024-05-31 23:59:59",
"region_name": "甘肃省",
}
}
},{
"query": "查询2024年甘肃省各月的预警情况总体分析",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2024-01-01 00:00:00",
"end_time": "2024-12-31 23:59:59",
"region_name": "甘肃省",
}
}
},{
"query": "查询2024年甘肃省上半年的虚警率是多少",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2024-01-01 00:00:00",
"end_time": "2024-06-30 23:59:59",
"region_name": "甘肃省",
}
}
}
]
# 为每个测试案例创建一个表格
for i, case in enumerate(test_cases, 1):
console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
table = Table(title=f"查询: {case['query']}")
table.add_column("项目", style="cyan")
table.add_column("预期结果", style="green")
table.add_column("实际结果", style="yellow")
table.add_column("是否匹配", style="magenta")
try:
result = picker.pick(case["query"])
# 添加工具比较行
expected_tool = case["expected"]["tool"]
actual_tool = result["tool"]
table.add_row(
"选择的工具",
expected_tool,
actual_tool,
"✓" if expected_tool == actual_tool else "✗"
)
# 添加参数比较行
for param_key in case["expected"]["params"]:
expected_value = str(case["expected"]["params"][param_key])
actual_value = str(result["params"].get(param_key, "未提供"))
table.add_row(
f"参数: {param_key}",
expected_value,
actual_value,
"✓" if expected_value == actual_value else "✗"
)
tool = tool_dict[result["tool"]]
params = result["params"]
result = tool.invoke(params)
print(result)
except Exception as e:
table.add_row("错误", "", str(e), "✗")
console.print(table)
console.print("=" * 80)
if __name__ == "__main__":
run_examples()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment