Commit 991ebf04 by tinywell

预警接口及工具

parent d7261e1c
......@@ -3,6 +3,9 @@ from typing import TypeVar, Generic, Any, Optional, List, Dict
from pydantic import BaseModel
from urllib.parse import urljoin
import time
from ..utils.logger import get_logger
# 泛型类型定义
T = TypeVar('T')
......@@ -30,7 +33,8 @@ class BaseHttpClient:
"""基础HTTP客户端"""
def __init__(self, base_url: str = const_base_url):
self.base_url = base_url.rstrip('/')
self.timeout = 30.0
self.timeout = 60.0
self.logger = get_logger(self.__class__.__name__)
async def _request_async(self, method: str, endpoint: str, **kwargs) -> Any:
"""通用异步请求方法"""
......@@ -42,11 +46,17 @@ class BaseHttpClient:
def _request_sync(self, method: str, endpoint: str, **kwargs) -> Any:
"""通用同步请求方法"""
self.logger.info(f"请求URL: {urljoin(self.base_url, endpoint)},请求参数: {kwargs}")
start_time = time.time()
result = None
with httpx.Client(timeout=self.timeout) as client:
url = urljoin(self.base_url, endpoint)
response = client.request(method, url, **kwargs)
response.raise_for_status()
return response.json()
result = response.json()
end_time = time.time()
self.logger.info(f"请求耗时: {end_time - start_time}秒")
return result
class MonitorPoint(BaseModel):
......@@ -195,14 +205,14 @@ class RateClient(BaseHttpClient):
)
return BaseResponse[List](**data)
async def query_rates_ranking(self, rank_type: int) -> BaseResponse[List]:
async def query_rates_ranking(self, rank_type: int) -> BaseResponse[Dict]:
"""异步查询在线率排名信息"""
data = await self._request_async(
"POST",
const_url_rate_ranking,
json={'type': rank_type}
)
return BaseResponse[List](**data)
return BaseResponse[Dict](**data)
def query_rates_month_sync(self, year: str, areaCode: str, typeArr: str) -> BaseResponse[List]:
"""同步查询按月度统计的在线率信息"""
......@@ -220,7 +230,7 @@ class RateClient(BaseHttpClient):
class WarningClient(BaseHttpClient):
"""预警查询客户端"""
def query_warning_statistics(self, start_time: str, end_time: str, area_code: str) -> BaseResponse[List]:
def query_warning_statistics(self, start_time: str, end_time: str, area_code: str) -> BaseResponse[Dict]:
"""同步查询预警统计信息"""
data = self._request_sync(
"POST",
......@@ -231,7 +241,7 @@ class WarningClient(BaseHttpClient):
'areaCode': area_code
}
)
return BaseResponse[List](**data)
return BaseResponse[Dict](**data)
def query_warning_month_statistics(self, year: str, areaCode: str) -> BaseResponse[List]:
"""同步查询按月度统计的预警统计信息"""
......
......@@ -156,6 +156,7 @@ class RegionRateTool(BaseRateTool):
month_data = self._extract_rate_data(item)
result_data[item['month']] = month_data
# 排序
result_data = sorted(result_data.items(), key=lambda x: x[0])
self.logger.debug(f"查询结果: {df.resultdata}")
# markdown = self.to_markdown(df.resultdata)
......
......@@ -7,30 +7,128 @@
from pydantic import BaseModel, Field
from typing import Any, Dict, Type
import logging
from langchain_core.tools import BaseTool
from .http_tools import WarningClient, const_base_url
from .code import AreaCodeTool
from ..utils.logger import get_logger
code_tool = AreaCodeTool()
class WarningArgs(BaseModel):
"""预警查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串")
month_statistics: bool = Field(False, description="是否按月度查询一年内的统计结果,默认不需要")
start_time: str = Field("", description="开始时间 (YYYY-MM-DD HH:mm:ss)")
end_time: str = Field("", description="结束时间 (YYYY-MM-DD HH:mm:ss)")
region_name: str = Field("", description="地区名称,如果要查询全国数据,请输入空字符串")
class WarningTool(BaseTool):
"""查询预警处置和虚警情况"""
name: str = "warning_statistics"
description: str = "查询不同时间段不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持按月度统计一年内的虚警率,处置率。"
description: str = "查询一定时间范围内不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持查询指定年份全年按月度统计的虚警率,处置率等信息。"
args_schema: Type[BaseModel] = WarningArgs
client: Any = Field(None, exclude=True)
logger: logging.Logger = Field(None, exclude=True)
def __init__(self, base_url: str = const_base_url, **data):
super().__init__(**data)
self.client = WarningClient(base_url=base_url)
self.logger = get_logger("WarningTool")
self.logger.info(f"初始化 WarningTool,base_url: {base_url}")
def _run(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False) -> Dict[str, Any]:
return self.get_warning_statistics(start_time, end_time, region_name, month_statistics)
def get_warning_statistics(self, start_time: str, end_time: str, region_name: str="", month_statistics: bool=False) -> Dict[str, Any]:
pass
\ No newline at end of file
def _run(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]:
code = ""
if region_name != "":
self.logger.debug(f"查找区域代码: {region_name}")
codes = code_tool.find_code(region_name)
if codes is None or len(codes) == 0:
error_msg = f'未找到匹配的区域代码: {region_name}'
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
code = codes[0][1]
self.logger.debug(f"找到区域代码: {code}")
year = start_time.split("-")[0]
detail = self.get_warning_statistics(start_time, end_time, region_name, code)
monthly = self.get_warning_statistics_of_month(year, region_name, code)
if monthly['code'] != 200:
return monthly
return {
"整体详细数据": detail['data'] if detail['code'] == 200 else detail['message'],
"各月数据": monthly['data'] if monthly['code'] == 200 else monthly['message']
}
def get_warning_statistics(self, start_time: str, end_time: str, region_name: str="", code: str="") -> Dict[str, Any]:
try:
response = self.client.query_warning_statistics(start_time, end_time, code)
self.logger.debug(f"API响应: {response}")
if response.type != 1 or len(response.resultdata) == 0:
error_msg = f"查询失败: {response.message},请检查是否有相关数据权限"
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
data = {
"预警消息个数": response.resultdata["num"],
"处置消息个数": response.resultdata["closenum"],
"处置率": response.resultdata["closeper"],
"虚警消息个数": response.resultdata["falsenum"],
"虚警率": response.resultdata["falseper"],
"红色预警消息个数": response.resultdata["rednum"],
"红色处置消息个数": response.resultdata["redcloseper"],
"橙色预警消息个数": response.resultdata["orangenum"],
"橙色处置消息个数": response.resultdata["orangecloseper"],
"黄色预警消息个数": response.resultdata["yellownum"],
"黄色处置消息个数": response.resultdata["yellowcloseper"],
"蓝色预警消息个数": response.resultdata["bluenum"],
"蓝色处置消息个数": response.resultdata["bluecloseper"],
"数据异常消息个数": response.resultdata["datanum"],
"数据异常消息占比": response.resultdata["datacloseper"],
"设备维护": response.resultdata["devicemainnum"],
"设备维护占比": response.resultdata["devicemaincloseper"],
"设备遭到破坏": response.resultdata["damagenum"],
"设备遭到破坏占比": response.resultdata["damagecloseper"],
"模型待优化": response.resultdata["modelnum"],
"模型待优化占比": response.resultdata["modelcloseper"],
}
return {
'code': 200,
'data': data
}
except Exception as e:
self.logger.error(f"查询预警统计信息失败: {e}")
return {'code': 400, 'message': str(e)}
def get_warning_statistics_of_month(self, year: str, region_name: str="", code: str="") -> Dict[str, Any]:
try:
response = self.client.query_warning_month_statistics(year, code)
self.logger.debug(f"API响应: {response}")
if response.type != 1 or len(response.resultdata) == 0:
error_msg = f"查询失败: {response.message},请检查是否有相关数据权限"
self.logger.warning(error_msg)
return {'code': 400, 'message': error_msg}
data = {}
for item in response.resultdata:
month = {
"预警数量": item["num"],
"处置率": item["closeper"],
"处置数量": item["closenum"],
"虚警率": item["falseper"],
"虚警数量": item["falsenum"],
}
data[item["month"]] = month
data = sorted(data.items(), key=lambda x: x[0])
return {
'code': 200,
'data': data
}
except Exception as e:
self.logger.error(f"查询预警统计信息失败: {e}")
return {'code': 400, 'message': str(e)}
from langchain_openai import ChatOpenAI
from rich.console import Console
from rich.table import Table
import sys,os
sys.path.append("../")
from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
from src.agent.tool_warn import WarningTool
def run_examples():
# 初始化 rich console
console = Console()
# 初始化 LLM
llm = ChatOpenAI(
openai_api_key="xxxxxx",
openai_api_base="http://192.168.10.14:8000/v1",
model_name="Qwen2-7B",
verbose=True
)
base_url = "http://172.30.0.37:30007"
# 初始化工具
tools = [
RegionRateTool(base_url=base_url),
RankingRateTool(base_url=base_url),
MonitorPointTool(base_url=base_url),
WarningTool(base_url=base_url),
]
tool_dict = {tool.name: tool for tool in tools}
# 初始化 ToolPicker
picker = ToolPicker(llm, tools)
# 测试案例和预期结果
test_cases = [
{
"query": "查询2024年4月到5月甘肃省的预警情况",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2024-04-01 00:00:00",
"end_time": "2024-05-31 23:59:59",
"region_name": "甘肃省",
}
}
},{
"query": "查询2024年4月到5月甘肃省的处置率是多少",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2024-04-01 00:00:00",
"end_time": "2024-05-31 23:59:59",
"region_name": "甘肃省",
}
}
},{
"query": "查询2024年甘肃省各月的预警情况总体分析",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2024-01-01 00:00:00",
"end_time": "2024-12-31 23:59:59",
"region_name": "甘肃省",
}
}
},{
"query": "查询2024年甘肃省上半年的虚警率是多少",
"expected": {
"tool": "warning_statistics",
"params": {
"start_time": "2024-01-01 00:00:00",
"end_time": "2024-06-30 23:59:59",
"region_name": "甘肃省",
}
}
}
]
# 为每个测试案例创建一个表格
for i, case in enumerate(test_cases, 1):
console.print(f"\n[bold cyan]=== 测试案例 {i} ===[/bold cyan]")
table = Table(title=f"查询: {case['query']}")
table.add_column("项目", style="cyan")
table.add_column("预期结果", style="green")
table.add_column("实际结果", style="yellow")
table.add_column("是否匹配", style="magenta")
try:
result = picker.pick(case["query"])
# 添加工具比较行
expected_tool = case["expected"]["tool"]
actual_tool = result["tool"]
table.add_row(
"选择的工具",
expected_tool,
actual_tool,
"✓" if expected_tool == actual_tool else "✗"
)
# 添加参数比较行
for param_key in case["expected"]["params"]:
expected_value = str(case["expected"]["params"][param_key])
actual_value = str(result["params"].get(param_key, "未提供"))
table.add_row(
f"参数: {param_key}",
expected_value,
actual_value,
"✓" if expected_value == actual_value else "✗"
)
tool = tool_dict[result["tool"]]
params = result["params"]
result = tool.invoke(params)
print(result)
except Exception as e:
table.add_row("错误", "", str(e), "✗")
console.print(table)
console.print("=" * 80)
if __name__ == "__main__":
run_examples()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment