Commit f50cc798 by tinywell

增加在线率接口、工具调用

parent 6a19aee9
This source diff could not be displayed because it is too large. You can view the blob instead.
import pandas as pd
from typing import List, Optional, Dict, Tuple
import os
class AreaCodeTool:
def __init__(self, csv_path: str = None):
"""
初始化行政区划代码工具
Args:
csv_path: CSV文件路径,如果为None则使用默认路径
"""
if csv_path is None:
# 获取当前文件所在目录
current_dir = os.path.dirname(os.path.abspath(__file__))
csv_path = os.path.join(current_dir, "area_code.csv")
# 读取CSV文件
self.df = pd.read_csv(csv_path, dtype={'code': str})
# 确保code列为字符串类型
self.df['code'] = self.df['code'].astype(str)
# 构建区域名称到代码的映射
self._build_name_maps()
def _build_name_maps(self):
"""构建区域名称到代码的映射"""
self.full_name_map = dict(zip(self.df['name'], self.df['code']))
# 构建省级映射
self.province_map = {}
# 构建市级映射
self.city_map = {}
# 构建区县级映射
self.district_map = {}
for _, row in self.df.iterrows():
name = row['name'].strip()
code = row['code']
parts = name.split('省' if '省' in name else '市')
if '省' in name:
province = parts[0] + '省'
self.province_map[province] = code
if len(parts) > 1 and parts[1]:
city_parts = parts[1].split('市')
if city_parts[0]:
city = city_parts[0] + '市'
self.city_map[city] = code
if len(city_parts) > 1 and city_parts[1]:
district = city_parts[1]
self.district_map[district] = code
else:
# 处理直辖市等特殊情况
if parts[0]:
self.city_map[parts[0] + '市'] = code
def find_code(self, area_name: str) -> List[Tuple[str, str]]:
"""
查找区域代码
Args:
area_name: 区域名称,可以是完整或部分名称
Returns:
List[Tuple[str, str]]: 返回匹配的(区域名称, 代码)列表
"""
results = []
# 尝试完整匹配
if area_name in self.full_name_map:
results.append((area_name, self.full_name_map[area_name]))
return results
# 尝试省级匹配
if area_name.endswith('省') and area_name in self.province_map:
results.append((area_name, self.province_map[area_name]))
# 尝试市级匹配
if area_name.endswith('市') and area_name in self.city_map:
results.append((area_name, self.city_map[area_name]))
# 尝试区县级匹配
if area_name in self.district_map:
results.append((area_name, self.district_map[area_name]))
# 模糊匹配
if not results:
mask = self.df['name'].str.contains(area_name, na=False)
matches = self.df[mask]
results.extend([(row['name'], row['code']) for _, row in matches.iterrows()])
return results
def get_full_name(self, code: str) -> Optional[str]:
"""
根据代码获取完整的区域名称
Args:
code: 区域代码
Returns:
Optional[str]: 完整的区域名称,如果未找到则返回None
"""
mask = self.df['code'] == code
matches = self.df[mask]
if not matches.empty:
return matches.iloc[0]['name']
return None
# 使用示例
def example_usage():
tool = AreaCodeTool()
# 测试不同类型的查询
test_cases = [
"安徽省",
"安庆市",
"迎江区",
"安徽省安庆市",
"安徽省安庆市迎江区",
"安庆" # 模糊查询
]
for query in test_cases:
print(f"\n查询: {query}")
results = tool.find_code(query)
for name, code in results:
print(f"匹配结果: {name} -> {code}")
# 测试代码反查
code = "340802" # 安徽省安庆市迎江区
full_name = tool.get_full_name(code)
if full_name:
print(f"\n代码反查: {code} -> {full_name}")
if __name__ == "__main__":
example_usage()
import httpx
from typing import List, Optional, Dict, TypeVar, Generic, Any
from pydantic import BaseModel
import asyncio
from urllib.parse import urljoin
# 泛型类型定义
T = TypeVar('T')
class BaseResponse(BaseModel, Generic[T]):
"""通用响应模型"""
type: int
resultcode: int
message: str
resultdata: T
class BaseHttpClient:
"""基础HTTP客户端"""
def __init__(self, base_url: str = "http://localhost:5001"):
self.base_url = base_url.rstrip('/')
self.timeout = 30.0
async def _request_async(self, method: str, endpoint: str, **kwargs) -> Any:
"""通用异步请求方法"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
url = urljoin(self.base_url, endpoint)
response = await client.request(method, url, **kwargs)
response.raise_for_status()
return response.json()
def _request_sync(self, method: str, endpoint: str, **kwargs) -> Any:
"""通用同步请求方法"""
with httpx.Client(timeout=self.timeout) as client:
url = urljoin(self.base_url, endpoint)
response = client.request(method, url, **kwargs)
response.raise_for_status()
return response.json()
class MonitorPoint(BaseModel):
"""监测点数据模型"""
MONITORPOINTCODE: str
MONITORPOINTNAME: str
LOCATION: str
LATITUDE: str
LONGITUDE: str
ELEVATION: str
BUILDUNIT: str
MONITORUNIT: str
YWUNIT: str
SGDW: Optional[str] = None
MANUFACTURER: str = ""
class MonitorClient(BaseHttpClient):
"""监测点查询客户端"""
async def query_points(self, key: str) -> BaseResponse[List[MonitorPoint]]:
"""异步查询监测点信息"""
data = await self._request_async(
"POST",
"/api/monitor/points",
json={"key": key}
)
return BaseResponse[List[MonitorPoint]](**data)
def query_points_sync(self, key: str) -> BaseResponse[List[MonitorPoint]]:
"""同步查询监测点信息"""
data = self._request_sync(
"POST",
"/api/monitor/points",
json={"key": key}
)
return BaseResponse[List[MonitorPoint]](**data)
# 示例:添加新的数据接口客户端
class RateClient(BaseHttpClient):
"""在线率查询客户端"""
async def query_rates(self, params: Dict) -> BaseResponse[Dict]:
"""异步查询在线率信息"""
data = await self._request_async(
"POST",
"/api/device/rate",
json=params
)
return BaseResponse[Dict](**data)
def query_rates_sync(self, params: Dict) -> BaseResponse[Dict]:
"""同步查询在线率信息"""
data = self._request_sync(
"POST",
"/api/device/rate",
json=params
)
return BaseResponse[Dict](**data)
def query_rates_ranking_sync(self, params: Dict) -> BaseResponse[Dict]:
"""同步查询在线率排名信息"""
data = self._request_sync(
"POST",
"/api/device/rate/ranking",
json=params
)
return BaseResponse[Dict](**data)
async def query_rates_ranking(self, params: Dict) -> BaseResponse[Dict]:
"""异步查询在线率排名信息"""
data = await self._request_async(
"POST",
"/api/device/rate/ranking",
json=params
)
return BaseResponse[Dict](**data)
# 使用示例
async def example_async_usage():
# 监测点查询示例
monitor_client = MonitorClient()
try:
response = await monitor_client.query_points("湖南")
if response.resultcode == 1 and response.resultdata:
for point in response.resultdata:
print(f"监测点名称: {point.MONITORPOINTNAME}")
print(f"位置: {point.LOCATION}")
print("---")
except httpx.HTTPError as e:
print(f"HTTP请求错误: {e}")
# 在线率查询示例
rate_client = RateClient()
try:
response = await rate_client.query_rates({"region": "湖南", "type": "residential"})
if response.resultcode == 1:
print(f"在线率数据: {response.resultdata}")
except httpx.HTTPError as e:
print(f"HTTP请求错误: {e}")
if __name__ == "__main__":
asyncio.run(example_async_usage())
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional,List
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool
from src.server.http_tools import MonitorClient
from .http_tools import MonitorClient, RateClient
class MonitorPointResponse():
"""监测点查询结果"""
status: str = Field(..., description="状态")
message: str = Field(..., description="消息")
points: List[Dict[str, Any]] = Field(..., description="监测点列表")
class MonitorPointArgs(BaseModel):
"""监测点查询参数"""
......@@ -45,11 +51,11 @@ class MonitorPointTool(BaseTool):
# 格式化返回结果
if not response.resultdata:
return {
"status": "success",
"message": f"未在{key}找到监测点信息",
"points": []
}
return MonitorPointResponse(
status="success",
message=f"未在{key}找到监测点信息",
points=[]
)
# 提取关键信息并格式化
points_info = []
......@@ -64,18 +70,18 @@ class MonitorPointTool(BaseTool):
"监测单位": point.MONITORUNIT
})
return {
"status": "success",
"message": f"在{key}找到{len(points_info)}个监测点",
"points": points_info
}
return MonitorPointResponse(
status="success",
message=f"在{key}找到{len(points_info)}个监测点",
points=points_info
)
except Exception as e:
return {
"status": "error",
"message": f"查询失败: {str(e)}",
"points": []
}
return MonitorPointResponse(
status="error",
message=f"查询失败: {str(e)}",
points=[]
)
def _arun(self, key: str) -> Dict[str, Any]:
"""
......
......@@ -7,6 +7,10 @@ from pydantic import BaseModel, Field
from typing import Type
from langchain_core.tools import BaseTool
from .http_tools import RateClient
from .code import AreaCodeTool
code_tool = AreaCodeTool()
class BaseRateTool(BaseTool):
"""设备在线率分析基础工具类"""
db: Any = Field(None, exclude=True)
......@@ -69,45 +73,41 @@ class RegionRateTool(BaseRateTool):
name = "region_online_rate"
description = "查询指定地区在指定时间段内的设备在线率"
args_schema: Type[BaseModel] = RegionRateArgs
client: Any = Field(None, exclude=True)
def __init__(self, base_url: str = "http://localhost:5001", **data):
super().__init__(**data)
self.client = RateClient(base_url=base_url)
def _run(self, region_name: str, start_time: str, end_time: str) -> Dict[str, Any]:
return self.get_region_online_rate(region_name, start_time, end_time)
def get_region_online_rate(self, region_name: str, start_time: str, end_time: str) -> Dict[str, Any]:
# 保持原有的实现...
# 查询数据
sql = """region"""
df = self.db.query(sql, {
'region_name': region_name,
'start_time': start_time,
'end_time': end_time
code = code_tool.find_code(region_name)
if not code:
return {
'code': 400,
'message': f'未找到匹配的区域代码: {region_name}'
}
df = self.client.query_rates_sync({
'startDate': start_time,
'endDate': end_time,
'areaCode': code[0][1]
})
# 生成图表
fig = px.line(df,
x='date',
y='online_rate',
title=f'{region_name}设备在线率趋势')
fig.update_layout(
xaxis_title='日期',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据
data = {
'region': region_name,
'average_rate': df['online_rate'].mean(),
'max_rate': df['online_rate'].max(),
'min_rate': df['online_rate'].min(),
'total_devices': df['device_count'].mean(),
'region_code': code[0][1],
'rate_data': df.resultdata,
'date_range': {
'start': start_time,
'end': end_time
}
}
return self.format_response(data, fig)
return data
class RankingRateArgs(BaseModel):
"""排名查询参数"""
......
import httpx
from typing import List, Optional, Dict
from pydantic import BaseModel
import asyncio
from urllib.parse import urljoin
class MonitorPoint(BaseModel):
MONITORPOINTCODE: str
MONITORPOINTNAME: str
LOCATION: str
LATITUDE: str
LONGITUDE: str
ELEVATION: str
BUILDUNIT: str
MONITORUNIT: str
YWUNIT: str
SGDW: Optional[str] = None
MANUFACTURER: str = ""
class QueryResponse(BaseModel):
type: int
resultcode: int
message: str
resultdata: List[MonitorPoint]
class MonitorClient:
def __init__(self, base_url: str = "http://localhost:5001"):
"""
初始化监测点查询客户端
Args:
base_url: API服务器基础URL
"""
self.base_url = base_url.rstrip('/')
self.timeout = 30.0
async def query_points(self, key: str) -> QueryResponse:
"""
异步查询监测点信息
Args:
key: 行政区划关键字(省/市/区县级别均可)
Returns:
QueryResponse: 查询响应对象
Raises:
httpx.HTTPError: 当HTTP请求失败时
ValueError: 当响应数据格式不正确时
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
url = urljoin(self.base_url, "/api/monitor/points")
response = await client.post(url, json={"key": key})
response.raise_for_status()
return QueryResponse(**response.json())
def query_points_sync(self, key: str) -> QueryResponse:
"""
同步查询监测点信息
Args:
key: 行政区划关键字(省/市/区县级别均可)
Returns:
QueryResponse: 查询响应对象
Raises:
httpx.HTTPError: 当HTTP请求失败时
ValueError: 当响应数据格式不正确时
"""
with httpx.Client(timeout=self.timeout) as client:
url = urljoin(self.base_url, "/api/monitor/points")
response = client.post(url, json={"key": key})
response.raise_for_status()
return QueryResponse(**response.json())
# 使用示例
async def example_async_usage():
client = MonitorClient()
try:
# 异步查询示例
response = await client.query_points("湖南")
if response.resultcode == 1 and response.resultdata:
for point in response.resultdata:
print(f"监测点名称: {point.MONITORPOINTNAME}")
print(f"位置: {point.LOCATION}")
print(f"经纬度: {point.LONGITUDE}, {point.LATITUDE}")
print("---")
except httpx.HTTPError as e:
print(f"HTTP请求错误: {e}")
except Exception as e:
print(f"发生错误: {e}")
def example_sync_usage():
client = MonitorClient()
try:
# 同步查询示例
response = client.query_points_sync("长沙")
if response.resultcode == 1 and response.resultdata:
for point in response.resultdata:
print(f"监测点名称: {point.MONITORPOINTNAME}")
print(f"位置: {point.LOCATION}")
print(f"经纬度: {point.LONGITUDE}, {point.LATITUDE}")
print("---")
except httpx.HTTPError as e:
print(f"HTTP请求错误: {e}")
except Exception as e:
print(f"发生错误: {e}")
if __name__ == "__main__":
# 异步调用示例
asyncio.run(example_async_usage())
# 同步调用示例
example_sync_usage()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment