Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
23a22fdc
Commit
23a22fdc
authored
Nov 08, 2024
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
接口根据实际修改
parent
f1fb0cb1
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
69 additions
and
229 deletions
+69
-229
fake_data_rate.py
src/agent/fake_data_rate.py
+0
-112
http_tools.py
src/agent/http_tools.py
+12
-7
tool_rate.py
src/agent/tool_rate.py
+27
-83
api.py
src/controller/api.py
+19
-7
server.py
src/mock/server.py
+5
-5
agent_rate.py
src/server/agent_rate.py
+6
-15
No files found.
src/agent/fake_data_rate.py
deleted
100644 → 0
View file @
f1fb0cb1
import
pandas
as
pd
import
numpy
as
np
from
datetime
import
datetime
,
timedelta
import
random
class
FakeDataGenerator
:
"""生成用于测试在线率分析工具的模拟数据"""
def
__init__
(
self
):
# 省份列表
self
.
provinces
=
[
'北京市'
,
'上海市'
,
'广东省'
,
'江苏省'
,
'浙江省'
,
'山东省'
,
'河南省'
,
'四川省'
,
'湖北省'
,
'福建省'
]
# 制造商列表
self
.
manufacturers
=
[
'华为'
,
'中兴'
,
'烽火'
,
'诺基亚'
,
'爱立信'
,
'思科'
,
'新华三'
,
'锐捷'
,
'迈普'
,
'东方通信'
]
# 基础在线率范围
self
.
base_rate_range
=
(
0.85
,
0.98
)
def
generate_region_data
(
self
,
region_name
:
str
,
start_time
:
str
,
end_time
:
str
)
->
pd
.
DataFrame
:
"""生成指定地区和时间段的在线率数据"""
start_date
=
datetime
.
strptime
(
start_time
,
'
%
Y-
%
m-
%
d'
)
end_date
=
datetime
.
strptime
(
end_time
,
'
%
Y-
%
m-
%
d'
)
date_range
=
pd
.
date_range
(
start_date
,
end_date
,
freq
=
'D'
)
data
=
[]
base_rate
=
random
.
uniform
(
*
self
.
base_rate_range
)
for
date
in
date_range
:
# 添加一些随机波动
daily_rate
=
min
(
1.0
,
max
(
0.0
,
base_rate
+
random
.
uniform
(
-
0.05
,
0.05
)))
data
.
append
({
'date'
:
date
,
'region'
:
region_name
,
'online_rate'
:
daily_rate
,
'device_count'
:
random
.
randint
(
1000
,
5000
)
})
return
pd
.
DataFrame
(
data
)
def
generate_ranking_data
(
self
,
rank_type
:
int
)
->
pd
.
DataFrame
:
"""生成排名数据"""
if
rank_type
==
1
:
# 省份排名
entities
=
self
.
provinces
else
:
# 厂商排名
entities
=
self
.
manufacturers
data
=
[]
for
entity
in
entities
:
base_rate
=
random
.
uniform
(
*
self
.
base_rate_range
)
data
.
append
({
'name'
:
entity
,
'online_rate'
:
base_rate
,
'device_count'
:
random
.
randint
(
5000
,
20000
),
'offline_count'
:
random
.
randint
(
100
,
1000
)
})
return
pd
.
DataFrame
(
data
)
.
sort_values
(
'online_rate'
,
ascending
=
False
)
def
generate_national_trend
(
self
,
start_time
:
str
,
end_time
:
str
)
->
pd
.
DataFrame
:
"""生成全国在线率趋势数据"""
start_date
=
datetime
.
strptime
(
start_time
,
'
%
Y-
%
m-
%
d'
)
end_date
=
datetime
.
strptime
(
end_time
,
'
%
Y-
%
m-
%
d'
)
date_range
=
pd
.
date_range
(
start_date
,
end_date
,
freq
=
'D'
)
data
=
[]
base_rate
=
random
.
uniform
(
*
self
.
base_rate_range
)
trend
=
np
.
linspace
(
-
0.02
,
0.02
,
len
(
date_range
))
# 添加轻微的趋势
for
i
,
date
in
enumerate
(
date_range
):
# 基础在线率 + 趋势 + 随机波动
daily_rate
=
min
(
1.0
,
max
(
0.0
,
base_rate
+
trend
[
i
]
+
random
.
uniform
(
-
0.02
,
0.02
)))
data
.
append
({
'date'
:
date
,
'online_rate'
:
daily_rate
,
'total_devices'
:
random
.
randint
(
50000
,
100000
),
'online_devices'
:
random
.
randint
(
40000
,
90000
)
})
return
pd
.
DataFrame
(
data
)
class
MockDBConnection
:
"""模拟数据库连接类"""
def
__init__
(
self
):
self
.
fake_data
=
FakeDataGenerator
()
def
query
(
self
,
sql
:
str
,
params
:
dict
=
None
)
->
pd
.
DataFrame
:
"""模拟SQL查询"""
# 根据SQL语句特征返回相应的模拟数据
if
'region'
in
sql
.
lower
():
return
self
.
fake_data
.
generate_region_data
(
params
.
get
(
'region_name'
,
'北京市'
),
params
.
get
(
'start_time'
,
'2024-01-01'
),
params
.
get
(
'end_time'
,
'2024-01-07'
)
)
elif
'rank'
in
sql
.
lower
():
return
self
.
fake_data
.
generate_ranking_data
(
params
.
get
(
'type'
,
1
)
)
elif
'national'
in
sql
.
lower
():
return
self
.
fake_data
.
generate_national_trend
(
params
.
get
(
'start_time'
,
'2024-01-01'
),
params
.
get
(
'end_time'
,
'2024-01-07'
)
)
else
:
return
pd
.
DataFrame
()
# 默认返回空数据框
src/agent/http_tools.py
View file @
23a22fdc
...
...
@@ -7,6 +7,11 @@ from urllib.parse import urljoin
# 泛型类型定义
T
=
TypeVar
(
'T'
)
const_base_url
=
"http://172.30.0.37:30007"
const_url_point
=
"/cigem/getMonitorPointAll"
const_url_rate
=
"/cigem/getAvgOnlineRate"
const_url_rate_ranking
=
"/cigem/getOnlineRateRank"
class
BaseResponse
(
BaseModel
,
Generic
[
T
]):
"""通用响应模型"""
type
:
int
...
...
@@ -16,7 +21,7 @@ class BaseResponse(BaseModel, Generic[T]):
class
BaseHttpClient
:
"""基础HTTP客户端"""
def
__init__
(
self
,
base_url
:
str
=
"http://localhost:5001"
):
def
__init__
(
self
,
base_url
:
str
=
const_base_url
):
self
.
base_url
=
base_url
.
rstrip
(
'/'
)
self
.
timeout
=
30.0
...
...
@@ -57,7 +62,7 @@ class MonitorClient(BaseHttpClient):
"""异步查询监测点信息"""
data
=
await
self
.
_request_async
(
"POST"
,
"/api/monitor/points"
,
const_url_point
,
json
=
{
"key"
:
key
}
)
return
BaseResponse
[
List
[
MonitorPoint
]](
**
data
)
...
...
@@ -66,7 +71,7 @@ class MonitorClient(BaseHttpClient):
"""同步查询监测点信息"""
data
=
self
.
_request_sync
(
"POST"
,
"/api/monitor/points"
,
const_url_point
,
json
=
{
"key"
:
key
}
)
return
BaseResponse
[
List
[
MonitorPoint
]](
**
data
)
...
...
@@ -78,7 +83,7 @@ class RateClient(BaseHttpClient):
"""异步查询在线率信息"""
data
=
await
self
.
_request_async
(
"POST"
,
"/api/device/rate"
,
const_url_rate
,
json
=
{
'areaCode'
:
areacode
,
'startDate'
:
startDate
,
...
...
@@ -91,7 +96,7 @@ class RateClient(BaseHttpClient):
"""同步查询在线率信息"""
data
=
self
.
_request_sync
(
"POST"
,
"/api/device/rate"
,
const_url_rate
,
json
=
{
'areaCode'
:
areacode
,
'startDate'
:
startDate
,
...
...
@@ -104,7 +109,7 @@ class RateClient(BaseHttpClient):
"""同步查询在线率排名信息"""
data
=
self
.
_request_sync
(
"POST"
,
"/api/device/rate/ranking"
,
const_url_rate_ranking
,
json
=
{
'type'
:
rank_type
}
)
return
BaseResponse
[
Dict
](
**
data
)
...
...
@@ -113,7 +118,7 @@ class RateClient(BaseHttpClient):
"""异步查询在线率排名信息"""
data
=
await
self
.
_request_async
(
"POST"
,
"/api/device/rate/ranking"
,
const_url_rate_ranking
,
json
=
{
'type'
:
rank_type
}
)
return
BaseResponse
[
Dict
](
**
data
)
...
...
src/agent/tool_rate.py
View file @
23a22fdc
...
...
@@ -7,17 +7,15 @@ from pydantic import BaseModel, Field
from
typing
import
Type
from
langchain_core.tools
import
BaseTool
from
.http_tools
import
RateClient
from
.http_tools
import
RateClient
,
const_base_url
from
.code
import
AreaCodeTool
code_tool
=
AreaCodeTool
()
class
BaseRateTool
(
BaseTool
):
"""设备在线率分析基础工具类"""
db
:
Any
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
db_connection
,
**
data
):
def
__init__
(
self
,
**
data
):
super
()
.
__init__
(
**
data
)
self
.
db
=
db_connection
def
format_response
(
self
,
data
:
Dict
[
str
,
Any
],
chart
:
go
.
Figure
)
->
Dict
[
str
,
Any
]:
"""格式化返回结果"""
...
...
@@ -75,7 +73,7 @@ class RegionRateTool(BaseRateTool):
args_schema
:
Type
[
BaseModel
]
=
RegionRateArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
base_url
:
str
=
"http://localhost:5001"
,
**
data
):
def
__init__
(
self
,
base_url
:
str
=
const_base_url
,
**
data
):
super
()
.
__init__
(
**
data
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
...
...
@@ -118,74 +116,53 @@ class RankingRateTool(BaseRateTool):
name
=
"online_rate_ranking"
description
=
"查询设备在线率的排名数据,可查询省份排名或厂商排名"
args_schema
:
Type
[
BaseModel
]
=
RankingRateArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
base_url
:
str
=
const_base_url
,
**
data
):
super
()
.
__init__
(
**
data
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
def
_run
(
self
,
rate_type
:
int
)
->
Dict
[
str
,
Any
]:
return
self
.
get_ranking_data
(
rate_type
)
def
get_ranking_data
(
self
,
ra
nk
_type
:
int
)
->
Dict
[
str
,
Any
]:
if
ra
nk
_type
==
1
:
def
get_ranking_data
(
self
,
ra
te
_type
:
int
)
->
Dict
[
str
,
Any
]:
if
ra
te
_type
==
1
:
return
self
.
_get_province_ranking
()
else
:
return
self
.
_get_manufacturer_ranking
()
def
_get_province_ranking
(
self
)
->
Dict
[
str
,
Any
]:
"""获取省份在线率排名"""
sql
=
"""rank"""
df
=
self
.
db
.
query
(
sql
,
{
'type'
:
1
})
# 生成排名图表
fig
=
px
.
bar
(
df
,
x
=
'name'
,
y
=
'online_rate'
,
title
=
'各省份设备在线率排名'
)
fig
.
update_layout
(
xaxis_title
=
'省份'
,
yaxis_title
=
'在线率'
,
yaxis_tickformat
=
'.2
%
'
)
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
1
)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
data
=
{
'rankings'
:
df
.
to_dict
(
'records'
),
'total_provinces'
:
len
(
df
),
'average_rate'
:
df
[
'online_rate'
]
.
mean
(),
'rankings'
:
df
.
resultdata
,
'total_provinces'
:
len
(
df
.
resultdata
),
'best_province'
:
{
'name'
:
df
.
iloc
[
0
][
'name'
],
'rate'
:
df
.
iloc
[
0
][
'online_r
ate'
]
'name'
:
df
.
resultdata
[
0
][
'name'
],
'rate'
:
df
.
resultdata
[
0
][
'onlineR
ate'
]
}
}
return
self
.
format_response
(
data
,
fig
)
return
data
def
_get_manufacturer_ranking
(
self
)
->
Dict
[
str
,
Any
]:
"""获取厂商在线率排名"""
sql
=
"""rank"""
df
=
self
.
db
.
query
(
sql
,
{
'type'
:
2
})
# 生成排名图表
fig
=
px
.
bar
(
df
,
x
=
'name'
,
y
=
'online_rate'
,
title
=
'各设备厂商在线率排名'
)
fig
.
update_layout
(
xaxis_title
=
'厂商'
,
yaxis_title
=
'在线率'
,
yaxis_tickformat
=
'.2
%
'
)
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
2
)
print
(
"厂商数据:"
,
df
.
resultdata
)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
data
=
{
'rankings'
:
df
.
to_dict
(
'records'
),
'total_manufacturers'
:
len
(
df
),
'average_rate'
:
df
[
'online_rate'
]
.
mean
(),
'rankings'
:
df
.
resultdata
,
'total_manufacturers'
:
len
(
df
.
resultdata
),
'best_manufacturer'
:
{
'name'
:
df
.
iloc
[
0
][
'name'
],
'rate'
:
df
.
iloc
[
0
][
'online_r
ate'
]
'name'
:
df
.
resultdata
[
0
][
'name'
],
'rate'
:
df
.
resultdata
[
0
][
'onlineR
ate'
]
}
}
return
self
.
format_response
(
data
,
fig
)
return
data
class
NationalTrendArgs
(
BaseModel
):
"""全国趋势查询参数"""
start_time
:
str
=
Field
(
...
,
description
=
"开始时间 (YYYY-MM-DD)"
)
...
...
@@ -204,45 +181,11 @@ class NationalTrendTool(BaseRateTool):
"""
接口三:获取全国在线率趋势
"""
sql
=
"""national"""
df
=
self
.
db
.
query
(
sql
,
{
'start_time'
:
start_time
,
'end_time'
:
end_time
})
# 生成趋势图表
fig
=
go
.
Figure
()
# 添加在线率曲线
fig
.
add_trace
(
go
.
Scatter
(
x
=
df
[
'date'
],
y
=
df
[
'online_rate'
],
name
=
'在线率'
,
mode
=
'lines+markers'
))
# 设置图表布局
fig
.
update_layout
(
title
=
'全国设备在线率趋势'
,
xaxis_title
=
'日期'
,
yaxis_title
=
'在线率'
,
yaxis_tickformat
=
'.2
%
'
)
# 准备数据
data
=
{
'trend_data'
:
df
.
to_dict
(
'records'
),
'statistics'
:
{
'average_rate'
:
df
[
'online_rate'
]
.
mean
(),
'max_rate'
:
df
[
'online_rate'
]
.
max
(),
'min_rate'
:
df
[
'online_rate'
]
.
min
(),
'average_devices'
:
int
(
df
[
'total_devices'
]
.
mean
()),
'average_online'
:
int
(
df
[
'online_devices'
]
.
mean
())
},
'date_range'
:
{
'start'
:
start_time
,
'end'
:
end_time
}
}
return
self
.
format_response
(
data
,
fig
)
\ No newline at end of file
src/controller/api.py
View file @
23a22fdc
import
sys
import
argparse
sys
.
path
.
append
(
'../'
)
from
fastapi
import
FastAPI
,
Header
...
...
@@ -17,12 +18,8 @@ app.add_middleware(
allow_headers
=
[
"*"
],
# 允许所有HTTP头
)
base_llm
=
ChatOpenAI
(
openai_api_key
=
'xxxxxxxxxxxxx'
,
openai_api_base
=
'http://192.168.10.14:8000/v1'
,
model_name
=
'Qwen2-7B'
,
verbose
=
True
)
global
base_llm
base_llm
=
None
@app.post
(
'/api/agent/rate'
)
def
rate
(
chat_request
:
GeoAgentRateRequest
,
token
:
str
=
Header
(
None
)):
...
...
@@ -42,4 +39,19 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
}
if
__name__
==
"__main__"
:
uvicorn
.
run
(
app
,
host
=
'0.0.0.0'
,
port
=
8088
)
# 参数解析
parser
=
argparse
.
ArgumentParser
(
description
=
"启动API服务"
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8088
,
help
=
"API服务端口"
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
'0.0.0.0'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--llm"
,
type
=
str
,
default
=
'Qwen2-7B'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--api_base"
,
type
=
str
,
default
=
'http://192.168.10.14:8000/v1'
,
help
=
"API服务地址"
)
args
=
parser
.
parse_args
()
base_llm
=
ChatOpenAI
(
openai_api_key
=
'xxxxxxxxxxxxx'
,
openai_api_base
=
args
.
api_base
,
model_name
=
args
.
llm
,
verbose
=
True
)
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
)
src/mock/server.py
View file @
23a22fdc
...
...
@@ -113,7 +113,7 @@ class QueryResponse(BaseModel):
message
:
str
=
""
resultdata
:
List
[
MonitorPoint
]
@app.post
(
"/
api/monitor/points
"
,
response_model
=
QueryResponse
)
@app.post
(
"/
cigem/getMonitorPointAll
"
,
response_model
=
QueryResponse
)
async
def
query_points
(
request
:
QueryRequest
):
"""检测点查询接口"""
print
(
f
"进入 query_points 接口, 查询监测点信息: {request.key}"
)
...
...
@@ -153,7 +153,7 @@ class DeviceRateRequest(BaseModel):
class
DeviceRateItem
(
BaseModel
):
name
:
str
rate
:
str
rate
:
float
class
DeviceRateResponse
(
BaseModel
):
type
:
int
=
1
...
...
@@ -200,12 +200,12 @@ def generate_rate_data(area_code: str) -> List[DeviceRateItem]:
result_data
.
append
(
DeviceRateItem
(
name
=
sub_area
[
"name"
],
rate
=
f
"{rate:.2f}"
rate
=
rate
))
return
result_data
@app.post
(
"/
api/device/r
ate"
,
response_model
=
DeviceRateResponse
)
@app.post
(
"/
cigem/getAvgOnlineR
ate"
,
response_model
=
DeviceRateResponse
)
async
def
query_device_rate
(
request
:
DeviceRateRequest
):
"""查询不同时间段不同地区设备在线率"""
print
(
f
"进入 query_device_rate 接口, 查询参数: {request}"
)
...
...
@@ -318,7 +318,7 @@ def generate_rate_ranking_data_by_province() -> List[RankingItem]:
result_data
.
sort
(
key
=
lambda
x
:
x
.
onlineRate
,
reverse
=
True
)
return
result_data
@app.post
(
"/
api/device/rate/ranking
"
,
response_model
=
RankingResponse
)
@app.post
(
"/
cigem/getOnlineRateRank
"
,
response_model
=
RankingResponse
)
async
def
query_device_rate_ranking
(
request
:
DeviceRateRankingRequest
):
"""
查询设备在线率排名
...
...
src/server/agent_rate.py
View file @
23a22fdc
...
...
@@ -15,7 +15,6 @@ from langchain import hub
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
,
NationalTrendTool
from
src.agent.fake_data_rate
import
MockDBConnection
from
src.agent.tool_monitor
import
MonitorPointTool
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
...
...
@@ -68,20 +67,13 @@ class RateAgent:
# 适配 structured_chat_agent 的 prompt
ONLINE_RATE_SYSTEM_PROMPT
=
"""你是一个专门处理地质监测点信息及监测设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。
你可以处理以下三类核心任务:
1. 地区在线率分析:分析指定地区(省/市/区县)在特定时间段的设备在线率
2. 在线率排名分析:分析各省份或各厂商的在线率排名情况
3. 全国趋势分析:分析全国范围内在线率随时间的变化趋势
4. 监测点信息查询:查询指定地区的监测点信息
你需要:
1. 理解用户意图,将用户问题映射到合适的分析类型
2. 确保必要参数完整,如果缺少参数则提示用户缺少参数
3. 如果参数完整,则调用相应的分析工具获取数据
4. 生成清晰的分析报告,包括数据解读和markdown 格式的数据表格
5. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议
4. 生成清晰的分析报告,包括数据解读
5. 工具返回的数据务必用 markdown 格式的数据表格进行展示
6. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议
注意事项:
- 时间格式统一使用:YYYY-MM-DD
...
...
@@ -151,11 +143,10 @@ class RateAgentV2:
def
new_rate_agent
(
llm
,
verbose
:
bool
=
False
,
**
args
):
conn
=
MockDBConnection
()
tools
=
[
RegionRateTool
(
db_connection
=
conn
),
RankingRateTool
(
db_connection
=
conn
),
NationalTrendTool
(
db_connection
=
conn
),
RegionRateTool
(),
RankingRateTool
(),
NationalTrendTool
(),
MonitorPointTool
()
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment