Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
991ebf04
Commit
991ebf04
authored
Nov 26, 2024
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
预警接口及工具
parent
d7261e1c
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
255 additions
and
17 deletions
+255
-17
http_tools.py
src/agent/http_tools.py
+16
-6
tool_rate.py
src/agent/tool_rate.py
+2
-1
tool_warn.py
src/agent/tool_warn.py
+107
-10
run_tool_picker_warn.py
test/run_tool_picker_warn.py
+130
-0
No files found.
src/agent/http_tools.py
View file @
991ebf04
...
...
@@ -3,6 +3,9 @@ from typing import TypeVar, Generic, Any, Optional, List, Dict
from
pydantic
import
BaseModel
from
urllib.parse
import
urljoin
import
time
from
..utils.logger
import
get_logger
# 泛型类型定义
T
=
TypeVar
(
'T'
)
...
...
@@ -30,7 +33,8 @@ class BaseHttpClient:
"""基础HTTP客户端"""
def
__init__
(
self
,
base_url
:
str
=
const_base_url
):
self
.
base_url
=
base_url
.
rstrip
(
'/'
)
self
.
timeout
=
30.0
self
.
timeout
=
60.0
self
.
logger
=
get_logger
(
self
.
__class__
.
__name__
)
async
def
_request_async
(
self
,
method
:
str
,
endpoint
:
str
,
**
kwargs
)
->
Any
:
"""通用异步请求方法"""
...
...
@@ -42,11 +46,17 @@ class BaseHttpClient:
def
_request_sync
(
self
,
method
:
str
,
endpoint
:
str
,
**
kwargs
)
->
Any
:
"""通用同步请求方法"""
self
.
logger
.
info
(
f
"请求URL: {urljoin(self.base_url, endpoint)},请求参数: {kwargs}"
)
start_time
=
time
.
time
()
result
=
None
with
httpx
.
Client
(
timeout
=
self
.
timeout
)
as
client
:
url
=
urljoin
(
self
.
base_url
,
endpoint
)
response
=
client
.
request
(
method
,
url
,
**
kwargs
)
response
.
raise_for_status
()
return
response
.
json
()
result
=
response
.
json
()
end_time
=
time
.
time
()
self
.
logger
.
info
(
f
"请求耗时: {end_time - start_time}秒"
)
return
result
class
MonitorPoint
(
BaseModel
):
...
...
@@ -195,14 +205,14 @@ class RateClient(BaseHttpClient):
)
return
BaseResponse
[
List
](
**
data
)
async
def
query_rates_ranking
(
self
,
rank_type
:
int
)
->
BaseResponse
[
Lis
t
]:
async
def
query_rates_ranking
(
self
,
rank_type
:
int
)
->
BaseResponse
[
Dic
t
]:
"""异步查询在线率排名信息"""
data
=
await
self
.
_request_async
(
"POST"
,
const_url_rate_ranking
,
json
=
{
'type'
:
rank_type
}
)
return
BaseResponse
[
Lis
t
](
**
data
)
return
BaseResponse
[
Dic
t
](
**
data
)
def
query_rates_month_sync
(
self
,
year
:
str
,
areaCode
:
str
,
typeArr
:
str
)
->
BaseResponse
[
List
]:
"""同步查询按月度统计的在线率信息"""
...
...
@@ -220,7 +230,7 @@ class RateClient(BaseHttpClient):
class
WarningClient
(
BaseHttpClient
):
"""预警查询客户端"""
def
query_warning_statistics
(
self
,
start_time
:
str
,
end_time
:
str
,
area_code
:
str
)
->
BaseResponse
[
Lis
t
]:
def
query_warning_statistics
(
self
,
start_time
:
str
,
end_time
:
str
,
area_code
:
str
)
->
BaseResponse
[
Dic
t
]:
"""同步查询预警统计信息"""
data
=
self
.
_request_sync
(
"POST"
,
...
...
@@ -231,7 +241,7 @@ class WarningClient(BaseHttpClient):
'areaCode'
:
area_code
}
)
return
BaseResponse
[
Lis
t
](
**
data
)
return
BaseResponse
[
Dic
t
](
**
data
)
def
query_warning_month_statistics
(
self
,
year
:
str
,
areaCode
:
str
)
->
BaseResponse
[
List
]:
"""同步查询按月度统计的预警统计信息"""
...
...
src/agent/tool_rate.py
View file @
991ebf04
...
...
@@ -156,7 +156,8 @@ class RegionRateTool(BaseRateTool):
month_data
=
self
.
_extract_rate_data
(
item
)
result_data
[
item
[
'month'
]]
=
month_data
# 排序
result_data
=
sorted
(
result_data
.
items
(),
key
=
lambda
x
:
x
[
0
])
self
.
logger
.
debug
(
f
"查询结果: {df.resultdata}"
)
# markdown = self.to_markdown(df.resultdata)
...
...
src/agent/tool_warn.py
View file @
991ebf04
...
...
@@ -7,30 +7,128 @@
from
pydantic
import
BaseModel
,
Field
from
typing
import
Any
,
Dict
,
Type
import
logging
from
langchain_core.tools
import
BaseTool
from
.http_tools
import
WarningClient
,
const_base_url
from
.code
import
AreaCodeTool
from
..utils.logger
import
get_logger
code_tool
=
AreaCodeTool
()
class
WarningArgs
(
BaseModel
):
"""预警查询参数"""
start_time
:
str
=
Field
(
...
,
description
=
"开始时间 (YYYY-MM-DD)"
)
end_time
:
str
=
Field
(
...
,
description
=
"结束时间 (YYYY-MM-DD)"
)
region_name
:
str
=
Field
(
...
,
description
=
"地区名称,如果要查询全国数据,请输入空字符串"
)
month_statistics
:
bool
=
Field
(
False
,
description
=
"是否按月度查询一年内的统计结果,默认不需要"
)
start_time
:
str
=
Field
(
""
,
description
=
"开始时间 (YYYY-MM-DD HH:mm:ss)"
)
end_time
:
str
=
Field
(
""
,
description
=
"结束时间 (YYYY-MM-DD HH:mm:ss)"
)
region_name
:
str
=
Field
(
""
,
description
=
"地区名称,如果要查询全国数据,请输入空字符串"
)
class
WarningTool
(
BaseTool
):
"""查询预警处置和虚警情况"""
name
:
str
=
"warning_statistics"
description
:
str
=
"查询
不同时间段不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持按月度统计一年内的虚警率,处置率
。"
description
:
str
=
"查询
一定时间范围内不同地区预警处置和虚警情况,包括处置率、虚警率、蓝黄橙红数数量和占比统计。也支持查询指定年份全年按月度统计的虚警率,处置率等信息
。"
args_schema
:
Type
[
BaseModel
]
=
WarningArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
logger
:
logging
.
Logger
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
base_url
:
str
=
const_base_url
,
**
data
):
super
()
.
__init__
(
**
data
)
self
.
client
=
WarningClient
(
base_url
=
base_url
)
self
.
logger
=
get_logger
(
"WarningTool"
)
self
.
logger
.
info
(
f
"初始化 WarningTool,base_url: {base_url}"
)
def
_run
(
self
,
start_time
:
str
,
end_time
:
str
,
region_name
:
str
=
""
,
month_statistics
:
bool
=
False
)
->
Dict
[
str
,
Any
]:
return
self
.
get_warning_statistics
(
start_time
,
end_time
,
region_name
,
month_statistics
)
def
_run
(
self
,
start_time
:
str
,
end_time
:
str
,
region_name
:
str
=
""
)
->
Dict
[
str
,
Any
]:
code
=
""
if
region_name
!=
""
:
self
.
logger
.
debug
(
f
"查找区域代码: {region_name}"
)
codes
=
code_tool
.
find_code
(
region_name
)
if
codes
is
None
or
len
(
codes
)
==
0
:
error_msg
=
f
'未找到匹配的区域代码: {region_name}'
self
.
logger
.
warning
(
error_msg
)
return
{
'code'
:
400
,
'message'
:
error_msg
}
code
=
codes
[
0
][
1
]
self
.
logger
.
debug
(
f
"找到区域代码: {code}"
)
year
=
start_time
.
split
(
"-"
)[
0
]
detail
=
self
.
get_warning_statistics
(
start_time
,
end_time
,
region_name
,
code
)
monthly
=
self
.
get_warning_statistics_of_month
(
year
,
region_name
,
code
)
if
monthly
[
'code'
]
!=
200
:
return
monthly
return
{
"整体详细数据"
:
detail
[
'data'
]
if
detail
[
'code'
]
==
200
else
detail
[
'message'
],
"各月数据"
:
monthly
[
'data'
]
if
monthly
[
'code'
]
==
200
else
monthly
[
'message'
]
}
def
get_warning_statistics
(
self
,
start_time
:
str
,
end_time
:
str
,
region_name
:
str
=
""
,
month_statistics
:
bool
=
False
)
->
Dict
[
str
,
Any
]:
pass
\ No newline at end of file
def
get_warning_statistics
(
self
,
start_time
:
str
,
end_time
:
str
,
region_name
:
str
=
""
,
code
:
str
=
""
)
->
Dict
[
str
,
Any
]:
try
:
response
=
self
.
client
.
query_warning_statistics
(
start_time
,
end_time
,
code
)
self
.
logger
.
debug
(
f
"API响应: {response}"
)
if
response
.
type
!=
1
or
len
(
response
.
resultdata
)
==
0
:
error_msg
=
f
"查询失败: {response.message},请检查是否有相关数据权限"
self
.
logger
.
warning
(
error_msg
)
return
{
'code'
:
400
,
'message'
:
error_msg
}
data
=
{
"预警消息个数"
:
response
.
resultdata
[
"num"
],
"处置消息个数"
:
response
.
resultdata
[
"closenum"
],
"处置率"
:
response
.
resultdata
[
"closeper"
],
"虚警消息个数"
:
response
.
resultdata
[
"falsenum"
],
"虚警率"
:
response
.
resultdata
[
"falseper"
],
"红色预警消息个数"
:
response
.
resultdata
[
"rednum"
],
"红色处置消息个数"
:
response
.
resultdata
[
"redcloseper"
],
"橙色预警消息个数"
:
response
.
resultdata
[
"orangenum"
],
"橙色处置消息个数"
:
response
.
resultdata
[
"orangecloseper"
],
"黄色预警消息个数"
:
response
.
resultdata
[
"yellownum"
],
"黄色处置消息个数"
:
response
.
resultdata
[
"yellowcloseper"
],
"蓝色预警消息个数"
:
response
.
resultdata
[
"bluenum"
],
"蓝色处置消息个数"
:
response
.
resultdata
[
"bluecloseper"
],
"数据异常消息个数"
:
response
.
resultdata
[
"datanum"
],
"数据异常消息占比"
:
response
.
resultdata
[
"datacloseper"
],
"设备维护"
:
response
.
resultdata
[
"devicemainnum"
],
"设备维护占比"
:
response
.
resultdata
[
"devicemaincloseper"
],
"设备遭到破坏"
:
response
.
resultdata
[
"damagenum"
],
"设备遭到破坏占比"
:
response
.
resultdata
[
"damagecloseper"
],
"模型待优化"
:
response
.
resultdata
[
"modelnum"
],
"模型待优化占比"
:
response
.
resultdata
[
"modelcloseper"
],
}
return
{
'code'
:
200
,
'data'
:
data
}
except
Exception
as
e
:
self
.
logger
.
error
(
f
"查询预警统计信息失败: {e}"
)
return
{
'code'
:
400
,
'message'
:
str
(
e
)}
def
get_warning_statistics_of_month
(
self
,
year
:
str
,
region_name
:
str
=
""
,
code
:
str
=
""
)
->
Dict
[
str
,
Any
]:
try
:
response
=
self
.
client
.
query_warning_month_statistics
(
year
,
code
)
self
.
logger
.
debug
(
f
"API响应: {response}"
)
if
response
.
type
!=
1
or
len
(
response
.
resultdata
)
==
0
:
error_msg
=
f
"查询失败: {response.message},请检查是否有相关数据权限"
self
.
logger
.
warning
(
error_msg
)
return
{
'code'
:
400
,
'message'
:
error_msg
}
data
=
{}
for
item
in
response
.
resultdata
:
month
=
{
"预警数量"
:
item
[
"num"
],
"处置率"
:
item
[
"closeper"
],
"处置数量"
:
item
[
"closenum"
],
"虚警率"
:
item
[
"falseper"
],
"虚警数量"
:
item
[
"falsenum"
],
}
data
[
item
[
"month"
]]
=
month
data
=
sorted
(
data
.
items
(),
key
=
lambda
x
:
x
[
0
])
return
{
'code'
:
200
,
'data'
:
data
}
except
Exception
as
e
:
self
.
logger
.
error
(
f
"查询预警统计信息失败: {e}"
)
return
{
'code'
:
400
,
'message'
:
str
(
e
)}
test/run_tool_picker_warn.py
0 → 100644
View file @
991ebf04
from
langchain_openai
import
ChatOpenAI
from
rich.console
import
Console
from
rich.table
import
Table
import
sys
,
os
sys
.
path
.
append
(
"../"
)
from
src.server.tool_picker
import
ToolPicker
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_monitor
import
MonitorPointTool
from
src.agent.tool_warn
import
WarningTool
def
run_examples
():
# 初始化 rich console
console
=
Console
()
# 初始化 LLM
llm
=
ChatOpenAI
(
openai_api_key
=
"xxxxxx"
,
openai_api_base
=
"http://192.168.10.14:8000/v1"
,
model_name
=
"Qwen2-7B"
,
verbose
=
True
)
base_url
=
"http://172.30.0.37:30007"
# 初始化工具
tools
=
[
RegionRateTool
(
base_url
=
base_url
),
RankingRateTool
(
base_url
=
base_url
),
MonitorPointTool
(
base_url
=
base_url
),
WarningTool
(
base_url
=
base_url
),
]
tool_dict
=
{
tool
.
name
:
tool
for
tool
in
tools
}
# 初始化 ToolPicker
picker
=
ToolPicker
(
llm
,
tools
)
# 测试案例和预期结果
test_cases
=
[
{
"query"
:
"查询2024年4月到5月甘肃省的预警情况"
,
"expected"
:
{
"tool"
:
"warning_statistics"
,
"params"
:
{
"start_time"
:
"2024-04-01 00:00:00"
,
"end_time"
:
"2024-05-31 23:59:59"
,
"region_name"
:
"甘肃省"
,
}
}
},{
"query"
:
"查询2024年4月到5月甘肃省的处置率是多少"
,
"expected"
:
{
"tool"
:
"warning_statistics"
,
"params"
:
{
"start_time"
:
"2024-04-01 00:00:00"
,
"end_time"
:
"2024-05-31 23:59:59"
,
"region_name"
:
"甘肃省"
,
}
}
},{
"query"
:
"查询2024年甘肃省各月的预警情况总体分析"
,
"expected"
:
{
"tool"
:
"warning_statistics"
,
"params"
:
{
"start_time"
:
"2024-01-01 00:00:00"
,
"end_time"
:
"2024-12-31 23:59:59"
,
"region_name"
:
"甘肃省"
,
}
}
},{
"query"
:
"查询2024年甘肃省上半年的虚警率是多少"
,
"expected"
:
{
"tool"
:
"warning_statistics"
,
"params"
:
{
"start_time"
:
"2024-01-01 00:00:00"
,
"end_time"
:
"2024-06-30 23:59:59"
,
"region_name"
:
"甘肃省"
,
}
}
}
]
# 为每个测试案例创建一个表格
for
i
,
case
in
enumerate
(
test_cases
,
1
):
console
.
print
(
f
"
\n
[bold cyan]=== 测试案例 {i} ===[/bold cyan]"
)
table
=
Table
(
title
=
f
"查询: {case['query']}"
)
table
.
add_column
(
"项目"
,
style
=
"cyan"
)
table
.
add_column
(
"预期结果"
,
style
=
"green"
)
table
.
add_column
(
"实际结果"
,
style
=
"yellow"
)
table
.
add_column
(
"是否匹配"
,
style
=
"magenta"
)
try
:
result
=
picker
.
pick
(
case
[
"query"
])
# 添加工具比较行
expected_tool
=
case
[
"expected"
][
"tool"
]
actual_tool
=
result
[
"tool"
]
table
.
add_row
(
"选择的工具"
,
expected_tool
,
actual_tool
,
"✓"
if
expected_tool
==
actual_tool
else
"✗"
)
# 添加参数比较行
for
param_key
in
case
[
"expected"
][
"params"
]:
expected_value
=
str
(
case
[
"expected"
][
"params"
][
param_key
])
actual_value
=
str
(
result
[
"params"
]
.
get
(
param_key
,
"未提供"
))
table
.
add_row
(
f
"参数: {param_key}"
,
expected_value
,
actual_value
,
"✓"
if
expected_value
==
actual_value
else
"✗"
)
tool
=
tool_dict
[
result
[
"tool"
]]
params
=
result
[
"params"
]
result
=
tool
.
invoke
(
params
)
print
(
result
)
except
Exception
as
e
:
table
.
add_row
(
"错误"
,
""
,
str
(
e
),
"✗"
)
console
.
print
(
table
)
console
.
print
(
"="
*
80
)
if
__name__
==
"__main__"
:
run_examples
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment