Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
23a22fdc
Commit
23a22fdc
authored
9 months ago
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
接口根据实际修改
parent
f1fb0cb1
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
69 additions
and
229 deletions
+69
-229
fake_data_rate.py
src/agent/fake_data_rate.py
+0
-112
http_tools.py
src/agent/http_tools.py
+12
-7
tool_rate.py
src/agent/tool_rate.py
+27
-83
api.py
src/controller/api.py
+19
-7
server.py
src/mock/server.py
+5
-5
agent_rate.py
src/server/agent_rate.py
+6
-15
No files found.
src/agent/fake_data_rate.py
deleted
100644 → 0
View file @
f1fb0cb1
import
pandas
as
pd
import
numpy
as
np
from
datetime
import
datetime
,
timedelta
import
random
class
FakeDataGenerator
:
"""生成用于测试在线率分析工具的模拟数据"""
def
__init__
(
self
):
# 省份列表
self
.
provinces
=
[
'北京市'
,
'上海市'
,
'广东省'
,
'江苏省'
,
'浙江省'
,
'山东省'
,
'河南省'
,
'四川省'
,
'湖北省'
,
'福建省'
]
# 制造商列表
self
.
manufacturers
=
[
'华为'
,
'中兴'
,
'烽火'
,
'诺基亚'
,
'爱立信'
,
'思科'
,
'新华三'
,
'锐捷'
,
'迈普'
,
'东方通信'
]
# 基础在线率范围
self
.
base_rate_range
=
(
0.85
,
0.98
)
def
generate_region_data
(
self
,
region_name
:
str
,
start_time
:
str
,
end_time
:
str
)
->
pd
.
DataFrame
:
"""生成指定地区和时间段的在线率数据"""
start_date
=
datetime
.
strptime
(
start_time
,
'
%
Y-
%
m-
%
d'
)
end_date
=
datetime
.
strptime
(
end_time
,
'
%
Y-
%
m-
%
d'
)
date_range
=
pd
.
date_range
(
start_date
,
end_date
,
freq
=
'D'
)
data
=
[]
base_rate
=
random
.
uniform
(
*
self
.
base_rate_range
)
for
date
in
date_range
:
# 添加一些随机波动
daily_rate
=
min
(
1.0
,
max
(
0.0
,
base_rate
+
random
.
uniform
(
-
0.05
,
0.05
)))
data
.
append
({
'date'
:
date
,
'region'
:
region_name
,
'online_rate'
:
daily_rate
,
'device_count'
:
random
.
randint
(
1000
,
5000
)
})
return
pd
.
DataFrame
(
data
)
def
generate_ranking_data
(
self
,
rank_type
:
int
)
->
pd
.
DataFrame
:
"""生成排名数据"""
if
rank_type
==
1
:
# 省份排名
entities
=
self
.
provinces
else
:
# 厂商排名
entities
=
self
.
manufacturers
data
=
[]
for
entity
in
entities
:
base_rate
=
random
.
uniform
(
*
self
.
base_rate_range
)
data
.
append
({
'name'
:
entity
,
'online_rate'
:
base_rate
,
'device_count'
:
random
.
randint
(
5000
,
20000
),
'offline_count'
:
random
.
randint
(
100
,
1000
)
})
return
pd
.
DataFrame
(
data
)
.
sort_values
(
'online_rate'
,
ascending
=
False
)
def
generate_national_trend
(
self
,
start_time
:
str
,
end_time
:
str
)
->
pd
.
DataFrame
:
"""生成全国在线率趋势数据"""
start_date
=
datetime
.
strptime
(
start_time
,
'
%
Y-
%
m-
%
d'
)
end_date
=
datetime
.
strptime
(
end_time
,
'
%
Y-
%
m-
%
d'
)
date_range
=
pd
.
date_range
(
start_date
,
end_date
,
freq
=
'D'
)
data
=
[]
base_rate
=
random
.
uniform
(
*
self
.
base_rate_range
)
trend
=
np
.
linspace
(
-
0.02
,
0.02
,
len
(
date_range
))
# 添加轻微的趋势
for
i
,
date
in
enumerate
(
date_range
):
# 基础在线率 + 趋势 + 随机波动
daily_rate
=
min
(
1.0
,
max
(
0.0
,
base_rate
+
trend
[
i
]
+
random
.
uniform
(
-
0.02
,
0.02
)))
data
.
append
({
'date'
:
date
,
'online_rate'
:
daily_rate
,
'total_devices'
:
random
.
randint
(
50000
,
100000
),
'online_devices'
:
random
.
randint
(
40000
,
90000
)
})
return
pd
.
DataFrame
(
data
)
class
MockDBConnection
:
"""模拟数据库连接类"""
def
__init__
(
self
):
self
.
fake_data
=
FakeDataGenerator
()
def
query
(
self
,
sql
:
str
,
params
:
dict
=
None
)
->
pd
.
DataFrame
:
"""模拟SQL查询"""
# 根据SQL语句特征返回相应的模拟数据
if
'region'
in
sql
.
lower
():
return
self
.
fake_data
.
generate_region_data
(
params
.
get
(
'region_name'
,
'北京市'
),
params
.
get
(
'start_time'
,
'2024-01-01'
),
params
.
get
(
'end_time'
,
'2024-01-07'
)
)
elif
'rank'
in
sql
.
lower
():
return
self
.
fake_data
.
generate_ranking_data
(
params
.
get
(
'type'
,
1
)
)
elif
'national'
in
sql
.
lower
():
return
self
.
fake_data
.
generate_national_trend
(
params
.
get
(
'start_time'
,
'2024-01-01'
),
params
.
get
(
'end_time'
,
'2024-01-07'
)
)
else
:
return
pd
.
DataFrame
()
# 默认返回空数据框
This diff is collapsed.
Click to expand it.
src/agent/http_tools.py
View file @
23a22fdc
...
@@ -7,6 +7,11 @@ from urllib.parse import urljoin
...
@@ -7,6 +7,11 @@ from urllib.parse import urljoin
# 泛型类型定义
# 泛型类型定义
T
=
TypeVar
(
'T'
)
T
=
TypeVar
(
'T'
)
const_base_url
=
"http://172.30.0.37:30007"
const_url_point
=
"/cigem/getMonitorPointAll"
const_url_rate
=
"/cigem/getAvgOnlineRate"
const_url_rate_ranking
=
"/cigem/getOnlineRateRank"
class
BaseResponse
(
BaseModel
,
Generic
[
T
]):
class
BaseResponse
(
BaseModel
,
Generic
[
T
]):
"""通用响应模型"""
"""通用响应模型"""
type
:
int
type
:
int
...
@@ -16,7 +21,7 @@ class BaseResponse(BaseModel, Generic[T]):
...
@@ -16,7 +21,7 @@ class BaseResponse(BaseModel, Generic[T]):
class
BaseHttpClient
:
class
BaseHttpClient
:
"""基础HTTP客户端"""
"""基础HTTP客户端"""
def
__init__
(
self
,
base_url
:
str
=
"http://localhost:5001"
):
def
__init__
(
self
,
base_url
:
str
=
const_base_url
):
self
.
base_url
=
base_url
.
rstrip
(
'/'
)
self
.
base_url
=
base_url
.
rstrip
(
'/'
)
self
.
timeout
=
30.0
self
.
timeout
=
30.0
...
@@ -57,7 +62,7 @@ class MonitorClient(BaseHttpClient):
...
@@ -57,7 +62,7 @@ class MonitorClient(BaseHttpClient):
"""异步查询监测点信息"""
"""异步查询监测点信息"""
data
=
await
self
.
_request_async
(
data
=
await
self
.
_request_async
(
"POST"
,
"POST"
,
"/api/monitor/points"
,
const_url_point
,
json
=
{
"key"
:
key
}
json
=
{
"key"
:
key
}
)
)
return
BaseResponse
[
List
[
MonitorPoint
]](
**
data
)
return
BaseResponse
[
List
[
MonitorPoint
]](
**
data
)
...
@@ -66,7 +71,7 @@ class MonitorClient(BaseHttpClient):
...
@@ -66,7 +71,7 @@ class MonitorClient(BaseHttpClient):
"""同步查询监测点信息"""
"""同步查询监测点信息"""
data
=
self
.
_request_sync
(
data
=
self
.
_request_sync
(
"POST"
,
"POST"
,
"/api/monitor/points"
,
const_url_point
,
json
=
{
"key"
:
key
}
json
=
{
"key"
:
key
}
)
)
return
BaseResponse
[
List
[
MonitorPoint
]](
**
data
)
return
BaseResponse
[
List
[
MonitorPoint
]](
**
data
)
...
@@ -78,7 +83,7 @@ class RateClient(BaseHttpClient):
...
@@ -78,7 +83,7 @@ class RateClient(BaseHttpClient):
"""异步查询在线率信息"""
"""异步查询在线率信息"""
data
=
await
self
.
_request_async
(
data
=
await
self
.
_request_async
(
"POST"
,
"POST"
,
"/api/device/rate"
,
const_url_rate
,
json
=
{
json
=
{
'areaCode'
:
areacode
,
'areaCode'
:
areacode
,
'startDate'
:
startDate
,
'startDate'
:
startDate
,
...
@@ -91,7 +96,7 @@ class RateClient(BaseHttpClient):
...
@@ -91,7 +96,7 @@ class RateClient(BaseHttpClient):
"""同步查询在线率信息"""
"""同步查询在线率信息"""
data
=
self
.
_request_sync
(
data
=
self
.
_request_sync
(
"POST"
,
"POST"
,
"/api/device/rate"
,
const_url_rate
,
json
=
{
json
=
{
'areaCode'
:
areacode
,
'areaCode'
:
areacode
,
'startDate'
:
startDate
,
'startDate'
:
startDate
,
...
@@ -104,7 +109,7 @@ class RateClient(BaseHttpClient):
...
@@ -104,7 +109,7 @@ class RateClient(BaseHttpClient):
"""同步查询在线率排名信息"""
"""同步查询在线率排名信息"""
data
=
self
.
_request_sync
(
data
=
self
.
_request_sync
(
"POST"
,
"POST"
,
"/api/device/rate/ranking"
,
const_url_rate_ranking
,
json
=
{
'type'
:
rank_type
}
json
=
{
'type'
:
rank_type
}
)
)
return
BaseResponse
[
Dict
](
**
data
)
return
BaseResponse
[
Dict
](
**
data
)
...
@@ -113,7 +118,7 @@ class RateClient(BaseHttpClient):
...
@@ -113,7 +118,7 @@ class RateClient(BaseHttpClient):
"""异步查询在线率排名信息"""
"""异步查询在线率排名信息"""
data
=
await
self
.
_request_async
(
data
=
await
self
.
_request_async
(
"POST"
,
"POST"
,
"/api/device/rate/ranking"
,
const_url_rate_ranking
,
json
=
{
'type'
:
rank_type
}
json
=
{
'type'
:
rank_type
}
)
)
return
BaseResponse
[
Dict
](
**
data
)
return
BaseResponse
[
Dict
](
**
data
)
...
...
This diff is collapsed.
Click to expand it.
src/agent/tool_rate.py
View file @
23a22fdc
...
@@ -7,17 +7,15 @@ from pydantic import BaseModel, Field
...
@@ -7,17 +7,15 @@ from pydantic import BaseModel, Field
from
typing
import
Type
from
typing
import
Type
from
langchain_core.tools
import
BaseTool
from
langchain_core.tools
import
BaseTool
from
.http_tools
import
RateClient
from
.http_tools
import
RateClient
,
const_base_url
from
.code
import
AreaCodeTool
from
.code
import
AreaCodeTool
code_tool
=
AreaCodeTool
()
code_tool
=
AreaCodeTool
()
class
BaseRateTool
(
BaseTool
):
class
BaseRateTool
(
BaseTool
):
"""设备在线率分析基础工具类"""
"""设备在线率分析基础工具类"""
db
:
Any
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
db_connection
,
**
data
):
def
__init__
(
self
,
**
data
):
super
()
.
__init__
(
**
data
)
super
()
.
__init__
(
**
data
)
self
.
db
=
db_connection
def
format_response
(
self
,
data
:
Dict
[
str
,
Any
],
chart
:
go
.
Figure
)
->
Dict
[
str
,
Any
]:
def
format_response
(
self
,
data
:
Dict
[
str
,
Any
],
chart
:
go
.
Figure
)
->
Dict
[
str
,
Any
]:
"""格式化返回结果"""
"""格式化返回结果"""
...
@@ -75,7 +73,7 @@ class RegionRateTool(BaseRateTool):
...
@@ -75,7 +73,7 @@ class RegionRateTool(BaseRateTool):
args_schema
:
Type
[
BaseModel
]
=
RegionRateArgs
args_schema
:
Type
[
BaseModel
]
=
RegionRateArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
client
:
Any
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
base_url
:
str
=
"http://localhost:5001"
,
**
data
):
def
__init__
(
self
,
base_url
:
str
=
const_base_url
,
**
data
):
super
()
.
__init__
(
**
data
)
super
()
.
__init__
(
**
data
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
...
@@ -118,74 +116,53 @@ class RankingRateTool(BaseRateTool):
...
@@ -118,74 +116,53 @@ class RankingRateTool(BaseRateTool):
name
=
"online_rate_ranking"
name
=
"online_rate_ranking"
description
=
"查询设备在线率的排名数据,可查询省份排名或厂商排名"
description
=
"查询设备在线率的排名数据,可查询省份排名或厂商排名"
args_schema
:
Type
[
BaseModel
]
=
RankingRateArgs
args_schema
:
Type
[
BaseModel
]
=
RankingRateArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
base_url
:
str
=
const_base_url
,
**
data
):
super
()
.
__init__
(
**
data
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
def
_run
(
self
,
rate_type
:
int
)
->
Dict
[
str
,
Any
]:
def
_run
(
self
,
rate_type
:
int
)
->
Dict
[
str
,
Any
]:
return
self
.
get_ranking_data
(
rate_type
)
return
self
.
get_ranking_data
(
rate_type
)
def
get_ranking_data
(
self
,
ra
nk
_type
:
int
)
->
Dict
[
str
,
Any
]:
def
get_ranking_data
(
self
,
ra
te
_type
:
int
)
->
Dict
[
str
,
Any
]:
if
ra
nk
_type
==
1
:
if
ra
te
_type
==
1
:
return
self
.
_get_province_ranking
()
return
self
.
_get_province_ranking
()
else
:
else
:
return
self
.
_get_manufacturer_ranking
()
return
self
.
_get_manufacturer_ranking
()
def
_get_province_ranking
(
self
)
->
Dict
[
str
,
Any
]:
def
_get_province_ranking
(
self
)
->
Dict
[
str
,
Any
]:
"""获取省份在线率排名"""
"""获取省份在线率排名"""
sql
=
"""rank"""
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
1
)
df
=
self
.
db
.
query
(
sql
,
{
'type'
:
1
})
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 生成排名图表
fig
=
px
.
bar
(
df
,
x
=
'name'
,
y
=
'online_rate'
,
title
=
'各省份设备在线率排名'
)
fig
.
update_layout
(
xaxis_title
=
'省份'
,
yaxis_title
=
'在线率'
,
yaxis_tickformat
=
'.2
%
'
)
# 准备数据
# 准备数据
data
=
{
data
=
{
'rankings'
:
df
.
to_dict
(
'records'
),
'rankings'
:
df
.
resultdata
,
'total_provinces'
:
len
(
df
),
'total_provinces'
:
len
(
df
.
resultdata
),
'average_rate'
:
df
[
'online_rate'
]
.
mean
(),
'best_province'
:
{
'best_province'
:
{
'name'
:
df
.
iloc
[
0
][
'name'
],
'name'
:
df
.
resultdata
[
0
][
'name'
],
'rate'
:
df
.
iloc
[
0
][
'online_r
ate'
]
'rate'
:
df
.
resultdata
[
0
][
'onlineR
ate'
]
}
}
}
}
return
self
.
format_response
(
data
,
fig
)
return
data
def
_get_manufacturer_ranking
(
self
)
->
Dict
[
str
,
Any
]:
def
_get_manufacturer_ranking
(
self
)
->
Dict
[
str
,
Any
]:
"""获取厂商在线率排名"""
"""获取厂商在线率排名"""
sql
=
"""rank"""
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
2
)
df
=
self
.
db
.
query
(
sql
,
{
'type'
:
2
})
print
(
"厂商数据:"
,
df
.
resultdata
)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 生成排名图表
fig
=
px
.
bar
(
df
,
x
=
'name'
,
y
=
'online_rate'
,
title
=
'各设备厂商在线率排名'
)
fig
.
update_layout
(
xaxis_title
=
'厂商'
,
yaxis_title
=
'在线率'
,
yaxis_tickformat
=
'.2
%
'
)
# 准备数据
# 准备数据
data
=
{
data
=
{
'rankings'
:
df
.
to_dict
(
'records'
),
'rankings'
:
df
.
resultdata
,
'total_manufacturers'
:
len
(
df
),
'total_manufacturers'
:
len
(
df
.
resultdata
),
'average_rate'
:
df
[
'online_rate'
]
.
mean
(),
'best_manufacturer'
:
{
'best_manufacturer'
:
{
'name'
:
df
.
iloc
[
0
][
'name'
],
'name'
:
df
.
resultdata
[
0
][
'name'
],
'rate'
:
df
.
iloc
[
0
][
'online_r
ate'
]
'rate'
:
df
.
resultdata
[
0
][
'onlineR
ate'
]
}
}
}
}
return
self
.
format_response
(
data
,
fig
)
return
data
class
NationalTrendArgs
(
BaseModel
):
class
NationalTrendArgs
(
BaseModel
):
"""全国趋势查询参数"""
"""全国趋势查询参数"""
start_time
:
str
=
Field
(
...
,
description
=
"开始时间 (YYYY-MM-DD)"
)
start_time
:
str
=
Field
(
...
,
description
=
"开始时间 (YYYY-MM-DD)"
)
...
@@ -204,45 +181,11 @@ class NationalTrendTool(BaseRateTool):
...
@@ -204,45 +181,11 @@ class NationalTrendTool(BaseRateTool):
"""
"""
接口三:获取全国在线率趋势
接口三:获取全国在线率趋势
"""
"""
sql
=
"""national"""
df
=
self
.
db
.
query
(
sql
,
{
'start_time'
:
start_time
,
'end_time'
:
end_time
})
# 生成趋势图表
fig
=
go
.
Figure
()
# 添加在线率曲线
fig
.
add_trace
(
go
.
Scatter
(
x
=
df
[
'date'
],
y
=
df
[
'online_rate'
],
name
=
'在线率'
,
mode
=
'lines+markers'
))
# 设置图表布局
fig
.
update_layout
(
title
=
'全国设备在线率趋势'
,
xaxis_title
=
'日期'
,
yaxis_title
=
'在线率'
,
yaxis_tickformat
=
'.2
%
'
)
# 准备数据
# 准备数据
data
=
{
data
=
{
'trend_data'
:
df
.
to_dict
(
'records'
),
'statistics'
:
{
'average_rate'
:
df
[
'online_rate'
]
.
mean
(),
'max_rate'
:
df
[
'online_rate'
]
.
max
(),
'min_rate'
:
df
[
'online_rate'
]
.
min
(),
'average_devices'
:
int
(
df
[
'total_devices'
]
.
mean
()),
'average_online'
:
int
(
df
[
'online_devices'
]
.
mean
())
},
'date_range'
:
{
'start'
:
start_time
,
'end'
:
end_time
}
}
}
return
self
.
format_response
(
data
,
fig
)
return
self
.
format_response
(
data
,
fig
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
src/controller/api.py
View file @
23a22fdc
import
sys
import
sys
import
argparse
sys
.
path
.
append
(
'../'
)
sys
.
path
.
append
(
'../'
)
from
fastapi
import
FastAPI
,
Header
from
fastapi
import
FastAPI
,
Header
...
@@ -17,12 +18,8 @@ app.add_middleware(
...
@@ -17,12 +18,8 @@ app.add_middleware(
allow_headers
=
[
"*"
],
# 允许所有HTTP头
allow_headers
=
[
"*"
],
# 允许所有HTTP头
)
)
base_llm
=
ChatOpenAI
(
global
base_llm
openai_api_key
=
'xxxxxxxxxxxxx'
,
base_llm
=
None
openai_api_base
=
'http://192.168.10.14:8000/v1'
,
model_name
=
'Qwen2-7B'
,
verbose
=
True
)
@app.post
(
'/api/agent/rate'
)
@app.post
(
'/api/agent/rate'
)
def
rate
(
chat_request
:
GeoAgentRateRequest
,
token
:
str
=
Header
(
None
)):
def
rate
(
chat_request
:
GeoAgentRateRequest
,
token
:
str
=
Header
(
None
)):
...
@@ -42,4 +39,19 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
...
@@ -42,4 +39,19 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
}
}
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
uvicorn
.
run
(
app
,
host
=
'0.0.0.0'
,
port
=
8088
)
# 参数解析
parser
=
argparse
.
ArgumentParser
(
description
=
"启动API服务"
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8088
,
help
=
"API服务端口"
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
'0.0.0.0'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--llm"
,
type
=
str
,
default
=
'Qwen2-7B'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--api_base"
,
type
=
str
,
default
=
'http://192.168.10.14:8000/v1'
,
help
=
"API服务地址"
)
args
=
parser
.
parse_args
()
base_llm
=
ChatOpenAI
(
openai_api_key
=
'xxxxxxxxxxxxx'
,
openai_api_base
=
args
.
api_base
,
model_name
=
args
.
llm
,
verbose
=
True
)
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
)
This diff is collapsed.
Click to expand it.
src/mock/server.py
View file @
23a22fdc
...
@@ -113,7 +113,7 @@ class QueryResponse(BaseModel):
...
@@ -113,7 +113,7 @@ class QueryResponse(BaseModel):
message
:
str
=
""
message
:
str
=
""
resultdata
:
List
[
MonitorPoint
]
resultdata
:
List
[
MonitorPoint
]
@app.post
(
"/
api/monitor/points
"
,
response_model
=
QueryResponse
)
@app.post
(
"/
cigem/getMonitorPointAll
"
,
response_model
=
QueryResponse
)
async
def
query_points
(
request
:
QueryRequest
):
async
def
query_points
(
request
:
QueryRequest
):
"""检测点查询接口"""
"""检测点查询接口"""
print
(
f
"进入 query_points 接口, 查询监测点信息: {request.key}"
)
print
(
f
"进入 query_points 接口, 查询监测点信息: {request.key}"
)
...
@@ -153,7 +153,7 @@ class DeviceRateRequest(BaseModel):
...
@@ -153,7 +153,7 @@ class DeviceRateRequest(BaseModel):
class
DeviceRateItem
(
BaseModel
):
class
DeviceRateItem
(
BaseModel
):
name
:
str
name
:
str
rate
:
str
rate
:
float
class
DeviceRateResponse
(
BaseModel
):
class
DeviceRateResponse
(
BaseModel
):
type
:
int
=
1
type
:
int
=
1
...
@@ -200,12 +200,12 @@ def generate_rate_data(area_code: str) -> List[DeviceRateItem]:
...
@@ -200,12 +200,12 @@ def generate_rate_data(area_code: str) -> List[DeviceRateItem]:
result_data
.
append
(
DeviceRateItem
(
result_data
.
append
(
DeviceRateItem
(
name
=
sub_area
[
"name"
],
name
=
sub_area
[
"name"
],
rate
=
f
"{rate:.2f}"
rate
=
rate
))
))
return
result_data
return
result_data
@app.post
(
"/
api/device/r
ate"
,
response_model
=
DeviceRateResponse
)
@app.post
(
"/
cigem/getAvgOnlineR
ate"
,
response_model
=
DeviceRateResponse
)
async
def
query_device_rate
(
request
:
DeviceRateRequest
):
async
def
query_device_rate
(
request
:
DeviceRateRequest
):
"""查询不同时间段不同地区设备在线率"""
"""查询不同时间段不同地区设备在线率"""
print
(
f
"进入 query_device_rate 接口, 查询参数: {request}"
)
print
(
f
"进入 query_device_rate 接口, 查询参数: {request}"
)
...
@@ -318,7 +318,7 @@ def generate_rate_ranking_data_by_province() -> List[RankingItem]:
...
@@ -318,7 +318,7 @@ def generate_rate_ranking_data_by_province() -> List[RankingItem]:
result_data
.
sort
(
key
=
lambda
x
:
x
.
onlineRate
,
reverse
=
True
)
result_data
.
sort
(
key
=
lambda
x
:
x
.
onlineRate
,
reverse
=
True
)
return
result_data
return
result_data
@app.post
(
"/
api/device/rate/ranking
"
,
response_model
=
RankingResponse
)
@app.post
(
"/
cigem/getOnlineRateRank
"
,
response_model
=
RankingResponse
)
async
def
query_device_rate_ranking
(
request
:
DeviceRateRankingRequest
):
async
def
query_device_rate_ranking
(
request
:
DeviceRateRankingRequest
):
"""
"""
查询设备在线率排名
查询设备在线率排名
...
...
This diff is collapsed.
Click to expand it.
src/server/agent_rate.py
View file @
23a22fdc
...
@@ -15,7 +15,6 @@ from langchain import hub
...
@@ -15,7 +15,6 @@ from langchain import hub
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
,
NationalTrendTool
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
,
NationalTrendTool
from
src.agent.fake_data_rate
import
MockDBConnection
from
src.agent.tool_monitor
import
MonitorPointTool
from
src.agent.tool_monitor
import
MonitorPointTool
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
...
@@ -68,20 +67,13 @@ class RateAgent:
...
@@ -68,20 +67,13 @@ class RateAgent:
# 适配 structured_chat_agent 的 prompt
# 适配 structured_chat_agent 的 prompt
ONLINE_RATE_SYSTEM_PROMPT
=
"""你是一个专门处理地质监测点信息及监测设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。
ONLINE_RATE_SYSTEM_PROMPT
=
"""你是一个专门处理地质监测点信息及监测设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。
你可以处理以下三类核心任务:
1. 地区在线率分析:分析指定地区(省/市/区县)在特定时间段的设备在线率
2. 在线率排名分析:分析各省份或各厂商的在线率排名情况
3. 全国趋势分析:分析全国范围内在线率随时间的变化趋势
4. 监测点信息查询:查询指定地区的监测点信息
你需要:
你需要:
1. 理解用户意图,将用户问题映射到合适的分析类型
1. 理解用户意图,将用户问题映射到合适的分析类型
2. 确保必要参数完整,如果缺少参数则提示用户缺少参数
2. 确保必要参数完整,如果缺少参数则提示用户缺少参数
3. 如果参数完整,则调用相应的分析工具获取数据
3. 如果参数完整,则调用相应的分析工具获取数据
4. 生成清晰的分析报告,包括数据解读和markdown 格式的数据表格
4. 生成清晰的分析报告,包括数据解读
5. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议
5. 工具返回的数据务必用 markdown 格式的数据表格进行展示
6. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议
注意事项:
注意事项:
- 时间格式统一使用:YYYY-MM-DD
- 时间格式统一使用:YYYY-MM-DD
...
@@ -151,11 +143,10 @@ class RateAgentV2:
...
@@ -151,11 +143,10 @@ class RateAgentV2:
def
new_rate_agent
(
llm
,
verbose
:
bool
=
False
,
**
args
):
def
new_rate_agent
(
llm
,
verbose
:
bool
=
False
,
**
args
):
conn
=
MockDBConnection
()
tools
=
[
tools
=
[
RegionRateTool
(
db_connection
=
conn
),
RegionRateTool
(),
RankingRateTool
(
db_connection
=
conn
),
RankingRateTool
(),
NationalTrendTool
(
db_connection
=
conn
),
NationalTrendTool
(),
MonitorPointTool
()
MonitorPointTool
()
]
]
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment