Commit 23a22fdc by tinywell

接口根据实际修改

parent f1fb0cb1
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import random
class FakeDataGenerator:
"""生成用于测试在线率分析工具的模拟数据"""
def __init__(self):
# 省份列表
self.provinces = [
'北京市', '上海市', '广东省', '江苏省', '浙江省',
'山东省', '河南省', '四川省', '湖北省', '福建省'
]
# 制造商列表
self.manufacturers = [
'华为', '中兴', '烽火', '诺基亚', '爱立信',
'思科', '新华三', '锐捷', '迈普', '东方通信'
]
# 基础在线率范围
self.base_rate_range = (0.85, 0.98)
def generate_region_data(self, region_name: str, start_time: str, end_time: str) -> pd.DataFrame:
"""生成指定地区和时间段的在线率数据"""
start_date = datetime.strptime(start_time, '%Y-%m-%d')
end_date = datetime.strptime(end_time, '%Y-%m-%d')
date_range = pd.date_range(start_date, end_date, freq='D')
data = []
base_rate = random.uniform(*self.base_rate_range)
for date in date_range:
# 添加一些随机波动
daily_rate = min(1.0, max(0.0, base_rate + random.uniform(-0.05, 0.05)))
data.append({
'date': date,
'region': region_name,
'online_rate': daily_rate,
'device_count': random.randint(1000, 5000)
})
return pd.DataFrame(data)
def generate_ranking_data(self, rank_type: int) -> pd.DataFrame:
"""生成排名数据"""
if rank_type == 1: # 省份排名
entities = self.provinces
else: # 厂商排名
entities = self.manufacturers
data = []
for entity in entities:
base_rate = random.uniform(*self.base_rate_range)
data.append({
'name': entity,
'online_rate': base_rate,
'device_count': random.randint(5000, 20000),
'offline_count': random.randint(100, 1000)
})
return pd.DataFrame(data).sort_values('online_rate', ascending=False)
def generate_national_trend(self, start_time: str, end_time: str) -> pd.DataFrame:
"""生成全国在线率趋势数据"""
start_date = datetime.strptime(start_time, '%Y-%m-%d')
end_date = datetime.strptime(end_time, '%Y-%m-%d')
date_range = pd.date_range(start_date, end_date, freq='D')
data = []
base_rate = random.uniform(*self.base_rate_range)
trend = np.linspace(-0.02, 0.02, len(date_range)) # 添加轻微的趋势
for i, date in enumerate(date_range):
# 基础在线率 + 趋势 + 随机波动
daily_rate = min(1.0, max(0.0, base_rate + trend[i] + random.uniform(-0.02, 0.02)))
data.append({
'date': date,
'online_rate': daily_rate,
'total_devices': random.randint(50000, 100000),
'online_devices': random.randint(40000, 90000)
})
return pd.DataFrame(data)
class MockDBConnection:
"""模拟数据库连接类"""
def __init__(self):
self.fake_data = FakeDataGenerator()
def query(self, sql: str, params: dict = None) -> pd.DataFrame:
"""模拟SQL查询"""
# 根据SQL语句特征返回相应的模拟数据
if 'region' in sql.lower():
return self.fake_data.generate_region_data(
params.get('region_name', '北京市'),
params.get('start_time', '2024-01-01'),
params.get('end_time', '2024-01-07')
)
elif 'rank' in sql.lower():
return self.fake_data.generate_ranking_data(
params.get('type', 1)
)
elif 'national' in sql.lower():
return self.fake_data.generate_national_trend(
params.get('start_time', '2024-01-01'),
params.get('end_time', '2024-01-07')
)
else:
return pd.DataFrame() # 默认返回空数据框
...@@ -7,6 +7,11 @@ from urllib.parse import urljoin ...@@ -7,6 +7,11 @@ from urllib.parse import urljoin
# 泛型类型定义 # 泛型类型定义
T = TypeVar('T') T = TypeVar('T')
const_base_url = "http://172.30.0.37:30007"
const_url_point = "/cigem/getMonitorPointAll"
const_url_rate = "/cigem/getAvgOnlineRate"
const_url_rate_ranking = "/cigem/getOnlineRateRank"
class BaseResponse(BaseModel, Generic[T]): class BaseResponse(BaseModel, Generic[T]):
"""通用响应模型""" """通用响应模型"""
type: int type: int
...@@ -16,7 +21,7 @@ class BaseResponse(BaseModel, Generic[T]): ...@@ -16,7 +21,7 @@ class BaseResponse(BaseModel, Generic[T]):
class BaseHttpClient: class BaseHttpClient:
"""基础HTTP客户端""" """基础HTTP客户端"""
def __init__(self, base_url: str = "http://localhost:5001"): def __init__(self, base_url: str = const_base_url):
self.base_url = base_url.rstrip('/') self.base_url = base_url.rstrip('/')
self.timeout = 30.0 self.timeout = 30.0
...@@ -57,7 +62,7 @@ class MonitorClient(BaseHttpClient): ...@@ -57,7 +62,7 @@ class MonitorClient(BaseHttpClient):
"""异步查询监测点信息""" """异步查询监测点信息"""
data = await self._request_async( data = await self._request_async(
"POST", "POST",
"/api/monitor/points", const_url_point,
json={"key": key} json={"key": key}
) )
return BaseResponse[List[MonitorPoint]](**data) return BaseResponse[List[MonitorPoint]](**data)
...@@ -66,7 +71,7 @@ class MonitorClient(BaseHttpClient): ...@@ -66,7 +71,7 @@ class MonitorClient(BaseHttpClient):
"""同步查询监测点信息""" """同步查询监测点信息"""
data = self._request_sync( data = self._request_sync(
"POST", "POST",
"/api/monitor/points", const_url_point,
json={"key": key} json={"key": key}
) )
return BaseResponse[List[MonitorPoint]](**data) return BaseResponse[List[MonitorPoint]](**data)
...@@ -78,7 +83,7 @@ class RateClient(BaseHttpClient): ...@@ -78,7 +83,7 @@ class RateClient(BaseHttpClient):
"""异步查询在线率信息""" """异步查询在线率信息"""
data = await self._request_async( data = await self._request_async(
"POST", "POST",
"/api/device/rate", const_url_rate,
json={ json={
'areaCode': areacode, 'areaCode': areacode,
'startDate': startDate, 'startDate': startDate,
...@@ -91,7 +96,7 @@ class RateClient(BaseHttpClient): ...@@ -91,7 +96,7 @@ class RateClient(BaseHttpClient):
"""同步查询在线率信息""" """同步查询在线率信息"""
data = self._request_sync( data = self._request_sync(
"POST", "POST",
"/api/device/rate", const_url_rate,
json={ json={
'areaCode': areacode, 'areaCode': areacode,
'startDate': startDate, 'startDate': startDate,
...@@ -104,7 +109,7 @@ class RateClient(BaseHttpClient): ...@@ -104,7 +109,7 @@ class RateClient(BaseHttpClient):
"""同步查询在线率排名信息""" """同步查询在线率排名信息"""
data = self._request_sync( data = self._request_sync(
"POST", "POST",
"/api/device/rate/ranking", const_url_rate_ranking,
json={'type': rank_type} json={'type': rank_type}
) )
return BaseResponse[Dict](**data) return BaseResponse[Dict](**data)
...@@ -113,7 +118,7 @@ class RateClient(BaseHttpClient): ...@@ -113,7 +118,7 @@ class RateClient(BaseHttpClient):
"""异步查询在线率排名信息""" """异步查询在线率排名信息"""
data = await self._request_async( data = await self._request_async(
"POST", "POST",
"/api/device/rate/ranking", const_url_rate_ranking,
json={'type': rank_type} json={'type': rank_type}
) )
return BaseResponse[Dict](**data) return BaseResponse[Dict](**data)
......
...@@ -7,17 +7,15 @@ from pydantic import BaseModel, Field ...@@ -7,17 +7,15 @@ from pydantic import BaseModel, Field
from typing import Type from typing import Type
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from .http_tools import RateClient from .http_tools import RateClient, const_base_url
from .code import AreaCodeTool from .code import AreaCodeTool
code_tool = AreaCodeTool() code_tool = AreaCodeTool()
class BaseRateTool(BaseTool): class BaseRateTool(BaseTool):
"""设备在线率分析基础工具类""" """设备在线率分析基础工具类"""
db: Any = Field(None, exclude=True)
def __init__(self, db_connection, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)
self.db = db_connection
def format_response(self, data: Dict[str, Any], chart: go.Figure) -> Dict[str, Any]: def format_response(self, data: Dict[str, Any], chart: go.Figure) -> Dict[str, Any]:
"""格式化返回结果""" """格式化返回结果"""
...@@ -75,7 +73,7 @@ class RegionRateTool(BaseRateTool): ...@@ -75,7 +73,7 @@ class RegionRateTool(BaseRateTool):
args_schema: Type[BaseModel] = RegionRateArgs args_schema: Type[BaseModel] = RegionRateArgs
client: Any = Field(None, exclude=True) client: Any = Field(None, exclude=True)
def __init__(self, base_url: str = "http://localhost:5001", **data): def __init__(self, base_url: str = const_base_url, **data):
super().__init__(**data) super().__init__(**data)
self.client = RateClient(base_url=base_url) self.client = RateClient(base_url=base_url)
...@@ -118,74 +116,53 @@ class RankingRateTool(BaseRateTool): ...@@ -118,74 +116,53 @@ class RankingRateTool(BaseRateTool):
name = "online_rate_ranking" name = "online_rate_ranking"
description = "查询设备在线率的排名数据,可查询省份排名或厂商排名" description = "查询设备在线率的排名数据,可查询省份排名或厂商排名"
args_schema: Type[BaseModel] = RankingRateArgs args_schema: Type[BaseModel] = RankingRateArgs
client: Any = Field(None, exclude=True)
def __init__(self, base_url: str = const_base_url, **data):
super().__init__(**data)
self.client = RateClient(base_url=base_url)
def _run(self, rate_type: int) -> Dict[str, Any]: def _run(self, rate_type: int) -> Dict[str, Any]:
return self.get_ranking_data(rate_type) return self.get_ranking_data(rate_type)
def get_ranking_data(self, rank_type: int) -> Dict[str, Any]: def get_ranking_data(self, rate_type: int) -> Dict[str, Any]:
if rank_type == 1: if rate_type == 1:
return self._get_province_ranking() return self._get_province_ranking()
else: else:
return self._get_manufacturer_ranking() return self._get_manufacturer_ranking()
def _get_province_ranking(self) -> Dict[str, Any]: def _get_province_ranking(self) -> Dict[str, Any]:
"""获取省份在线率排名""" """获取省份在线率排名"""
sql = """rank""" df = self.client.query_rates_ranking_sync(rank_type=1)
df = self.db.query(sql, {'type': 1}) # df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 生成排名图表
fig = px.bar(df,
x='name',
y='online_rate',
title='各省份设备在线率排名')
fig.update_layout(
xaxis_title='省份',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据 # 准备数据
data = { data = {
'rankings': df.to_dict('records'), 'rankings': df.resultdata,
'total_provinces': len(df), 'total_provinces': len(df.resultdata),
'average_rate': df['online_rate'].mean(),
'best_province': { 'best_province': {
'name': df.iloc[0]['name'], 'name': df.resultdata[0]['name'],
'rate': df.iloc[0]['online_rate'] 'rate': df.resultdata[0]['onlineRate']
} }
} }
return self.format_response(data, fig) return data
def _get_manufacturer_ranking(self) -> Dict[str, Any]: def _get_manufacturer_ranking(self) -> Dict[str, Any]:
"""获取厂商在线率排名""" """获取厂商在线率排名"""
sql = """rank""" df = self.client.query_rates_ranking_sync(rank_type=2)
df = self.db.query(sql, {'type': 2}) print("厂商数据:", df.resultdata)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 生成排名图表
fig = px.bar(df,
x='name',
y='online_rate',
title='各设备厂商在线率排名')
fig.update_layout(
xaxis_title='厂商',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据 # 准备数据
data = { data = {
'rankings': df.to_dict('records'), 'rankings': df.resultdata,
'total_manufacturers': len(df), 'total_manufacturers': len(df.resultdata),
'average_rate': df['online_rate'].mean(),
'best_manufacturer': { 'best_manufacturer': {
'name': df.iloc[0]['name'], 'name': df.resultdata[0]['name'],
'rate': df.iloc[0]['online_rate'] 'rate': df.resultdata[0]['onlineRate']
} }
} }
return self.format_response(data, fig) return data
class NationalTrendArgs(BaseModel): class NationalTrendArgs(BaseModel):
"""全国趋势查询参数""" """全国趋势查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)") start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
...@@ -204,45 +181,11 @@ class NationalTrendTool(BaseRateTool): ...@@ -204,45 +181,11 @@ class NationalTrendTool(BaseRateTool):
""" """
接口三:获取全国在线率趋势 接口三:获取全国在线率趋势
""" """
sql = """national"""
df = self.db.query(sql, {
'start_time': start_time,
'end_time': end_time
})
# 生成趋势图表
fig = go.Figure()
# 添加在线率曲线
fig.add_trace(go.Scatter(
x=df['date'],
y=df['online_rate'],
name='在线率',
mode='lines+markers'
))
# 设置图表布局
fig.update_layout(
title='全国设备在线率趋势',
xaxis_title='日期',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据 # 准备数据
data = { data = {
'trend_data': df.to_dict('records'),
'statistics': {
'average_rate': df['online_rate'].mean(),
'max_rate': df['online_rate'].max(),
'min_rate': df['online_rate'].min(),
'average_devices': int(df['total_devices'].mean()),
'average_online': int(df['online_devices'].mean())
},
'date_range': {
'start': start_time,
'end': end_time
}
} }
return self.format_response(data, fig) return self.format_response(data, fig)
\ No newline at end of file
import sys import sys
import argparse
sys.path.append('../') sys.path.append('../')
from fastapi import FastAPI, Header from fastapi import FastAPI, Header
...@@ -17,12 +18,8 @@ app.add_middleware( ...@@ -17,12 +18,8 @@ app.add_middleware(
allow_headers=["*"], # 允许所有HTTP头 allow_headers=["*"], # 允许所有HTTP头
) )
base_llm = ChatOpenAI( global base_llm
openai_api_key='xxxxxxxxxxxxx', base_llm = None
openai_api_base='http://192.168.10.14:8000/v1',
model_name='Qwen2-7B',
verbose=True
)
@app.post('/api/agent/rate') @app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)): def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
...@@ -42,4 +39,19 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)): ...@@ -42,4 +39,19 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
} }
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8088) # 参数解析
parser = argparse.ArgumentParser(description="启动API服务")
parser.add_argument("--port", type=int, default=8088, help="API服务端口")
parser.add_argument("--host", type=str, default='0.0.0.0', help="API服务地址")
parser.add_argument("--llm", type=str, default='Qwen2-7B', help="API服务地址")
parser.add_argument("--api_base", type=str, default='http://192.168.10.14:8000/v1', help="API服务地址")
args = parser.parse_args()
base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
openai_api_base=args.api_base,
model_name=args.llm,
verbose=True
)
uvicorn.run(app, host=args.host, port=args.port)
...@@ -113,7 +113,7 @@ class QueryResponse(BaseModel): ...@@ -113,7 +113,7 @@ class QueryResponse(BaseModel):
message: str = "" message: str = ""
resultdata: List[MonitorPoint] resultdata: List[MonitorPoint]
@app.post("/api/monitor/points", response_model=QueryResponse) @app.post("/cigem/getMonitorPointAll", response_model=QueryResponse)
async def query_points(request: QueryRequest): async def query_points(request: QueryRequest):
"""检测点查询接口""" """检测点查询接口"""
print(f"进入 query_points 接口, 查询监测点信息: {request.key}") print(f"进入 query_points 接口, 查询监测点信息: {request.key}")
...@@ -153,7 +153,7 @@ class DeviceRateRequest(BaseModel): ...@@ -153,7 +153,7 @@ class DeviceRateRequest(BaseModel):
class DeviceRateItem(BaseModel): class DeviceRateItem(BaseModel):
name: str name: str
rate: str rate: float
class DeviceRateResponse(BaseModel): class DeviceRateResponse(BaseModel):
type: int = 1 type: int = 1
...@@ -200,12 +200,12 @@ def generate_rate_data(area_code: str) -> List[DeviceRateItem]: ...@@ -200,12 +200,12 @@ def generate_rate_data(area_code: str) -> List[DeviceRateItem]:
result_data.append(DeviceRateItem( result_data.append(DeviceRateItem(
name=sub_area["name"], name=sub_area["name"],
rate=f"{rate:.2f}" rate=rate
)) ))
return result_data return result_data
@app.post("/api/device/rate", response_model=DeviceRateResponse) @app.post("/cigem/getAvgOnlineRate", response_model=DeviceRateResponse)
async def query_device_rate(request: DeviceRateRequest): async def query_device_rate(request: DeviceRateRequest):
"""查询不同时间段不同地区设备在线率""" """查询不同时间段不同地区设备在线率"""
print(f"进入 query_device_rate 接口, 查询参数: {request}") print(f"进入 query_device_rate 接口, 查询参数: {request}")
...@@ -318,7 +318,7 @@ def generate_rate_ranking_data_by_province() -> List[RankingItem]: ...@@ -318,7 +318,7 @@ def generate_rate_ranking_data_by_province() -> List[RankingItem]:
result_data.sort(key=lambda x: x.onlineRate, reverse=True) result_data.sort(key=lambda x: x.onlineRate, reverse=True)
return result_data return result_data
@app.post("/api/device/rate/ranking", response_model=RankingResponse) @app.post("/cigem/getOnlineRateRank", response_model=RankingResponse)
async def query_device_rate_ranking(request: DeviceRateRankingRequest): async def query_device_rate_ranking(request: DeviceRateRankingRequest):
""" """
查询设备在线率排名 查询设备在线率排名
......
...@@ -15,7 +15,6 @@ from langchain import hub ...@@ -15,7 +15,6 @@ from langchain import hub
from src.agent.tool_rate import RegionRateTool,RankingRateTool,NationalTrendTool from src.agent.tool_rate import RegionRateTool,RankingRateTool,NationalTrendTool
from src.agent.fake_data_rate import MockDBConnection
from src.agent.tool_monitor import MonitorPointTool from src.agent.tool_monitor import MonitorPointTool
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None, # def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
...@@ -68,20 +67,13 @@ class RateAgent: ...@@ -68,20 +67,13 @@ class RateAgent:
# 适配 structured_chat_agent 的 prompt # 适配 structured_chat_agent 的 prompt
ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及监测设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。 ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及监测设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。
你可以处理以下三类核心任务:
1. 地区在线率分析:分析指定地区(省/市/区县)在特定时间段的设备在线率
2. 在线率排名分析:分析各省份或各厂商的在线率排名情况
3. 全国趋势分析:分析全国范围内在线率随时间的变化趋势
4. 监测点信息查询:查询指定地区的监测点信息
你需要: 你需要:
1. 理解用户意图,将用户问题映射到合适的分析类型 1. 理解用户意图,将用户问题映射到合适的分析类型
2. 确保必要参数完整,如果缺少参数则提示用户缺少参数 2. 确保必要参数完整,如果缺少参数则提示用户缺少参数
3. 如果参数完整,则调用相应的分析工具获取数据 3. 如果参数完整,则调用相应的分析工具获取数据
4. 生成清晰的分析报告,包括数据解读和markdown 格式的数据表格 4. 生成清晰的分析报告,包括数据解读
5. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议 5. 工具返回的数据务必用 markdown 格式的数据表格进行展示
6. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议
注意事项: 注意事项:
- 时间格式统一使用:YYYY-MM-DD - 时间格式统一使用:YYYY-MM-DD
...@@ -151,11 +143,10 @@ class RateAgentV2: ...@@ -151,11 +143,10 @@ class RateAgentV2:
def new_rate_agent(llm, verbose: bool = False,**args): def new_rate_agent(llm, verbose: bool = False,**args):
conn = MockDBConnection()
tools = [ tools = [
RegionRateTool(db_connection=conn), RegionRateTool(),
RankingRateTool(db_connection=conn), RankingRateTool(),
NationalTrendTool(db_connection=conn), NationalTrendTool(),
MonitorPointTool() MonitorPointTool()
] ]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment