Commit 23a22fdc by tinywell

接口根据实际修改

parent f1fb0cb1
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import random
class FakeDataGenerator:
"""生成用于测试在线率分析工具的模拟数据"""
def __init__(self):
# 省份列表
self.provinces = [
'北京市', '上海市', '广东省', '江苏省', '浙江省',
'山东省', '河南省', '四川省', '湖北省', '福建省'
]
# 制造商列表
self.manufacturers = [
'华为', '中兴', '烽火', '诺基亚', '爱立信',
'思科', '新华三', '锐捷', '迈普', '东方通信'
]
# 基础在线率范围
self.base_rate_range = (0.85, 0.98)
def generate_region_data(self, region_name: str, start_time: str, end_time: str) -> pd.DataFrame:
"""生成指定地区和时间段的在线率数据"""
start_date = datetime.strptime(start_time, '%Y-%m-%d')
end_date = datetime.strptime(end_time, '%Y-%m-%d')
date_range = pd.date_range(start_date, end_date, freq='D')
data = []
base_rate = random.uniform(*self.base_rate_range)
for date in date_range:
# 添加一些随机波动
daily_rate = min(1.0, max(0.0, base_rate + random.uniform(-0.05, 0.05)))
data.append({
'date': date,
'region': region_name,
'online_rate': daily_rate,
'device_count': random.randint(1000, 5000)
})
return pd.DataFrame(data)
def generate_ranking_data(self, rank_type: int) -> pd.DataFrame:
"""生成排名数据"""
if rank_type == 1: # 省份排名
entities = self.provinces
else: # 厂商排名
entities = self.manufacturers
data = []
for entity in entities:
base_rate = random.uniform(*self.base_rate_range)
data.append({
'name': entity,
'online_rate': base_rate,
'device_count': random.randint(5000, 20000),
'offline_count': random.randint(100, 1000)
})
return pd.DataFrame(data).sort_values('online_rate', ascending=False)
def generate_national_trend(self, start_time: str, end_time: str) -> pd.DataFrame:
"""生成全国在线率趋势数据"""
start_date = datetime.strptime(start_time, '%Y-%m-%d')
end_date = datetime.strptime(end_time, '%Y-%m-%d')
date_range = pd.date_range(start_date, end_date, freq='D')
data = []
base_rate = random.uniform(*self.base_rate_range)
trend = np.linspace(-0.02, 0.02, len(date_range)) # 添加轻微的趋势
for i, date in enumerate(date_range):
# 基础在线率 + 趋势 + 随机波动
daily_rate = min(1.0, max(0.0, base_rate + trend[i] + random.uniform(-0.02, 0.02)))
data.append({
'date': date,
'online_rate': daily_rate,
'total_devices': random.randint(50000, 100000),
'online_devices': random.randint(40000, 90000)
})
return pd.DataFrame(data)
class MockDBConnection:
"""模拟数据库连接类"""
def __init__(self):
self.fake_data = FakeDataGenerator()
def query(self, sql: str, params: dict = None) -> pd.DataFrame:
"""模拟SQL查询"""
# 根据SQL语句特征返回相应的模拟数据
if 'region' in sql.lower():
return self.fake_data.generate_region_data(
params.get('region_name', '北京市'),
params.get('start_time', '2024-01-01'),
params.get('end_time', '2024-01-07')
)
elif 'rank' in sql.lower():
return self.fake_data.generate_ranking_data(
params.get('type', 1)
)
elif 'national' in sql.lower():
return self.fake_data.generate_national_trend(
params.get('start_time', '2024-01-01'),
params.get('end_time', '2024-01-07')
)
else:
return pd.DataFrame() # 默认返回空数据框
......@@ -7,6 +7,11 @@ from urllib.parse import urljoin
# 泛型类型定义
T = TypeVar('T')
const_base_url = "http://172.30.0.37:30007"
const_url_point = "/cigem/getMonitorPointAll"
const_url_rate = "/cigem/getAvgOnlineRate"
const_url_rate_ranking = "/cigem/getOnlineRateRank"
class BaseResponse(BaseModel, Generic[T]):
"""通用响应模型"""
type: int
......@@ -16,7 +21,7 @@ class BaseResponse(BaseModel, Generic[T]):
class BaseHttpClient:
"""基础HTTP客户端"""
def __init__(self, base_url: str = "http://localhost:5001"):
def __init__(self, base_url: str = const_base_url):
self.base_url = base_url.rstrip('/')
self.timeout = 30.0
......@@ -57,7 +62,7 @@ class MonitorClient(BaseHttpClient):
"""异步查询监测点信息"""
data = await self._request_async(
"POST",
"/api/monitor/points",
const_url_point,
json={"key": key}
)
return BaseResponse[List[MonitorPoint]](**data)
......@@ -66,7 +71,7 @@ class MonitorClient(BaseHttpClient):
"""同步查询监测点信息"""
data = self._request_sync(
"POST",
"/api/monitor/points",
const_url_point,
json={"key": key}
)
return BaseResponse[List[MonitorPoint]](**data)
......@@ -78,7 +83,7 @@ class RateClient(BaseHttpClient):
"""异步查询在线率信息"""
data = await self._request_async(
"POST",
"/api/device/rate",
const_url_rate,
json={
'areaCode': areacode,
'startDate': startDate,
......@@ -91,7 +96,7 @@ class RateClient(BaseHttpClient):
"""同步查询在线率信息"""
data = self._request_sync(
"POST",
"/api/device/rate",
const_url_rate,
json={
'areaCode': areacode,
'startDate': startDate,
......@@ -104,7 +109,7 @@ class RateClient(BaseHttpClient):
"""同步查询在线率排名信息"""
data = self._request_sync(
"POST",
"/api/device/rate/ranking",
const_url_rate_ranking,
json={'type': rank_type}
)
return BaseResponse[Dict](**data)
......@@ -113,7 +118,7 @@ class RateClient(BaseHttpClient):
"""异步查询在线率排名信息"""
data = await self._request_async(
"POST",
"/api/device/rate/ranking",
const_url_rate_ranking,
json={'type': rank_type}
)
return BaseResponse[Dict](**data)
......
......@@ -7,17 +7,15 @@ from pydantic import BaseModel, Field
from typing import Type
from langchain_core.tools import BaseTool
from .http_tools import RateClient
from .http_tools import RateClient, const_base_url
from .code import AreaCodeTool
code_tool = AreaCodeTool()
class BaseRateTool(BaseTool):
"""设备在线率分析基础工具类"""
db: Any = Field(None, exclude=True)
def __init__(self, db_connection, **data):
def __init__(self, **data):
super().__init__(**data)
self.db = db_connection
def format_response(self, data: Dict[str, Any], chart: go.Figure) -> Dict[str, Any]:
"""格式化返回结果"""
......@@ -75,7 +73,7 @@ class RegionRateTool(BaseRateTool):
args_schema: Type[BaseModel] = RegionRateArgs
client: Any = Field(None, exclude=True)
def __init__(self, base_url: str = "http://localhost:5001", **data):
def __init__(self, base_url: str = const_base_url, **data):
super().__init__(**data)
self.client = RateClient(base_url=base_url)
......@@ -118,74 +116,53 @@ class RankingRateTool(BaseRateTool):
name = "online_rate_ranking"
description = "查询设备在线率的排名数据,可查询省份排名或厂商排名"
args_schema: Type[BaseModel] = RankingRateArgs
client: Any = Field(None, exclude=True)
def __init__(self, base_url: str = const_base_url, **data):
super().__init__(**data)
self.client = RateClient(base_url=base_url)
def _run(self, rate_type: int) -> Dict[str, Any]:
return self.get_ranking_data(rate_type)
def get_ranking_data(self, rank_type: int) -> Dict[str, Any]:
if rank_type == 1:
def get_ranking_data(self, rate_type: int) -> Dict[str, Any]:
if rate_type == 1:
return self._get_province_ranking()
else:
return self._get_manufacturer_ranking()
def _get_province_ranking(self) -> Dict[str, Any]:
"""获取省份在线率排名"""
sql = """rank"""
df = self.db.query(sql, {'type': 1})
# 生成排名图表
fig = px.bar(df,
x='name',
y='online_rate',
title='各省份设备在线率排名')
fig.update_layout(
xaxis_title='省份',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
df = self.client.query_rates_ranking_sync(rank_type=1)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
data = {
'rankings': df.to_dict('records'),
'total_provinces': len(df),
'average_rate': df['online_rate'].mean(),
'rankings': df.resultdata,
'total_provinces': len(df.resultdata),
'best_province': {
'name': df.iloc[0]['name'],
'rate': df.iloc[0]['online_rate']
'name': df.resultdata[0]['name'],
'rate': df.resultdata[0]['onlineRate']
}
}
return self.format_response(data, fig)
return data
def _get_manufacturer_ranking(self) -> Dict[str, Any]:
"""获取厂商在线率排名"""
sql = """rank"""
df = self.db.query(sql, {'type': 2})
# 生成排名图表
fig = px.bar(df,
x='name',
y='online_rate',
title='各设备厂商在线率排名')
fig.update_layout(
xaxis_title='厂商',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
df = self.client.query_rates_ranking_sync(rank_type=2)
print("厂商数据:", df.resultdata)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
data = {
'rankings': df.to_dict('records'),
'total_manufacturers': len(df),
'average_rate': df['online_rate'].mean(),
'rankings': df.resultdata,
'total_manufacturers': len(df.resultdata),
'best_manufacturer': {
'name': df.iloc[0]['name'],
'rate': df.iloc[0]['online_rate']
'name': df.resultdata[0]['name'],
'rate': df.resultdata[0]['onlineRate']
}
}
return self.format_response(data, fig)
return data
class NationalTrendArgs(BaseModel):
"""全国趋势查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
......@@ -204,45 +181,11 @@ class NationalTrendTool(BaseRateTool):
"""
接口三:获取全国在线率趋势
"""
sql = """national"""
df = self.db.query(sql, {
'start_time': start_time,
'end_time': end_time
})
# 生成趋势图表
fig = go.Figure()
# 添加在线率曲线
fig.add_trace(go.Scatter(
x=df['date'],
y=df['online_rate'],
name='在线率',
mode='lines+markers'
))
# 设置图表布局
fig.update_layout(
title='全国设备在线率趋势',
xaxis_title='日期',
yaxis_title='在线率',
yaxis_tickformat='.2%'
)
# 准备数据
data = {
'trend_data': df.to_dict('records'),
'statistics': {
'average_rate': df['online_rate'].mean(),
'max_rate': df['online_rate'].max(),
'min_rate': df['online_rate'].min(),
'average_devices': int(df['total_devices'].mean()),
'average_online': int(df['online_devices'].mean())
},
'date_range': {
'start': start_time,
'end': end_time
}
}
return self.format_response(data, fig)
\ No newline at end of file
import sys
import argparse
sys.path.append('../')
from fastapi import FastAPI, Header
......@@ -17,12 +18,8 @@ app.add_middleware(
allow_headers=["*"], # 允许所有HTTP头
)
base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
openai_api_base='http://192.168.10.14:8000/v1',
model_name='Qwen2-7B',
verbose=True
)
global base_llm
base_llm = None
@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
......@@ -42,4 +39,19 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
}
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8088)
# 参数解析
parser = argparse.ArgumentParser(description="启动API服务")
parser.add_argument("--port", type=int, default=8088, help="API服务端口")
parser.add_argument("--host", type=str, default='0.0.0.0', help="API服务地址")
parser.add_argument("--llm", type=str, default='Qwen2-7B', help="API服务地址")
parser.add_argument("--api_base", type=str, default='http://192.168.10.14:8000/v1', help="API服务地址")
args = parser.parse_args()
base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
openai_api_base=args.api_base,
model_name=args.llm,
verbose=True
)
uvicorn.run(app, host=args.host, port=args.port)
......@@ -113,7 +113,7 @@ class QueryResponse(BaseModel):
message: str = ""
resultdata: List[MonitorPoint]
@app.post("/api/monitor/points", response_model=QueryResponse)
@app.post("/cigem/getMonitorPointAll", response_model=QueryResponse)
async def query_points(request: QueryRequest):
"""检测点查询接口"""
print(f"进入 query_points 接口, 查询监测点信息: {request.key}")
......@@ -153,7 +153,7 @@ class DeviceRateRequest(BaseModel):
class DeviceRateItem(BaseModel):
name: str
rate: str
rate: float
class DeviceRateResponse(BaseModel):
type: int = 1
......@@ -200,12 +200,12 @@ def generate_rate_data(area_code: str) -> List[DeviceRateItem]:
result_data.append(DeviceRateItem(
name=sub_area["name"],
rate=f"{rate:.2f}"
rate=rate
))
return result_data
@app.post("/api/device/rate", response_model=DeviceRateResponse)
@app.post("/cigem/getAvgOnlineRate", response_model=DeviceRateResponse)
async def query_device_rate(request: DeviceRateRequest):
"""查询不同时间段不同地区设备在线率"""
print(f"进入 query_device_rate 接口, 查询参数: {request}")
......@@ -318,7 +318,7 @@ def generate_rate_ranking_data_by_province() -> List[RankingItem]:
result_data.sort(key=lambda x: x.onlineRate, reverse=True)
return result_data
@app.post("/api/device/rate/ranking", response_model=RankingResponse)
@app.post("/cigem/getOnlineRateRank", response_model=RankingResponse)
async def query_device_rate_ranking(request: DeviceRateRankingRequest):
"""
查询设备在线率排名
......
......@@ -15,7 +15,6 @@ from langchain import hub
from src.agent.tool_rate import RegionRateTool,RankingRateTool,NationalTrendTool
from src.agent.fake_data_rate import MockDBConnection
from src.agent.tool_monitor import MonitorPointTool
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
......@@ -68,20 +67,13 @@ class RateAgent:
# 适配 structured_chat_agent 的 prompt
ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及监测设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。
你可以处理以下三类核心任务:
1. 地区在线率分析:分析指定地区(省/市/区县)在特定时间段的设备在线率
2. 在线率排名分析:分析各省份或各厂商的在线率排名情况
3. 全国趋势分析:分析全国范围内在线率随时间的变化趋势
4. 监测点信息查询:查询指定地区的监测点信息
你需要:
1. 理解用户意图,将用户问题映射到合适的分析类型
2. 确保必要参数完整,如果缺少参数则提示用户缺少参数
3. 如果参数完整,则调用相应的分析工具获取数据
4. 生成清晰的分析报告,包括数据解读和markdown 格式的数据表格
5. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议
4. 生成清晰的分析报告,包括数据解读
5. 工具返回的数据务必用 markdown 格式的数据表格进行展示
6. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议
注意事项:
- 时间格式统一使用:YYYY-MM-DD
......@@ -151,11 +143,10 @@ class RateAgentV2:
def new_rate_agent(llm, verbose: bool = False,**args):
conn = MockDBConnection()
tools = [
RegionRateTool(db_connection=conn),
RankingRateTool(db_connection=conn),
NationalTrendTool(db_connection=conn),
RegionRateTool(),
RankingRateTool(),
NationalTrendTool(),
MonitorPointTool()
]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment