Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
625c5230
Commit
625c5230
authored
8 months ago
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
优化 warn 工具描述,降低与在线率工具混淆
parent
abcc91a0
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
138 additions
and
14 deletions
+138
-14
tool_warn.py
src/agent/tool_warn.py
+2
-2
run_tool_picker_rate.py
test/run_tool_picker_rate.py
+79
-10
run_tool_picker_warn.py
test/run_tool_picker_warn.py
+27
-2
test_agent.py
test/test_agent.py
+30
-0
No files found.
src/agent/tool_warn.py
View file @
625c5230
...
...
@@ -22,12 +22,12 @@ class WarningArgs(BaseModel):
start_time
:
str
=
Field
(
""
,
description
=
"开始时间 (YYYY-MM-DD HH:mm:ss)"
)
end_time
:
str
=
Field
(
""
,
description
=
"结束时间 (YYYY-MM-DD HH:mm:ss)"
)
region_name
:
str
=
Field
(
""
,
description
=
"地区名称,如果要查询全国数据,请输入空字符串"
)
query_type
:
str
=
Field
(
"1"
,
description
=
"查询类型,1 表示查询一段时间内的综合数据,2 表示查询指定年份全年按月度统计的虚警
率,处置率
等信息"
)
query_type
:
str
=
Field
(
"1"
,
description
=
"查询类型,1 表示查询一段时间内的综合数据,2 表示查询指定年份全年按月度统计的虚警
、处置
等信息"
)
class
WarningTool
(
BaseTool
):
"""查询预警处置和虚警情况"""
name
:
str
=
"warning_statistics"
description
:
str
=
"查询一定时间范围内不同地区预警处置和虚警情况,包括处置
率、虚警率、蓝黄橙红数数量和占比统计。也支持查询指定年份全年按月度统计的虚警率,处置率
等信息。"
description
:
str
=
"查询一定时间范围内不同地区预警处置和虚警情况,包括处置
情况、虚警情况、蓝黄橙红数数量等统计。也支持查询指定年份全年按月度统计的虚警情况,处置情况
等信息。"
args_schema
:
Type
[
BaseModel
]
=
WarningArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
logger
:
logging
.
Logger
=
Field
(
None
,
exclude
=
True
)
...
...
This diff is collapsed.
Click to expand it.
test/run_tool_picker_rate.py
View file @
625c5230
from
langchain_openai
import
ChatOpenAI
from
rich.console
import
Console
from
rich.table
import
Table
import
time
import
sys
,
os
sys
.
path
.
append
(
"../"
)
from
src.server.tool_picker
import
ToolPicker
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_monitor
import
MonitorPointTool
from
src.agent.tool_warn
import
WarningTool
def
run_examples
():
# 初始化 rich console
console
=
Console
()
# 初始化 LLM
def
create_tool_picker
():
llm
=
ChatOpenAI
(
openai_api_key
=
"xxxxxx"
,
openai_api_base
=
"http://192.168.10.14:8000/v1"
,
model_name
=
"Qwen2-7B"
,
verbose
=
True
)
base_url
=
"http://172.30.0.37:30007"
# 初始化工具
tools
=
[
RegionRateTool
(
base_url
=
base_url
),
RankingRateTool
(
base_url
=
base_url
),
MonitorPointTool
(
base_url
=
base_url
),
WarningTool
(
base_url
=
base_url
),
]
return
ToolPicker
(
llm
,
tools
)
def
run_examples
():
# 初始化 rich console
console
=
Console
()
# 初始化 LLM
base_url
=
"http://172.30.0.37:30007"
# 初始化工具
tools
=
[
RegionRateTool
(
base_url
=
base_url
),
RankingRateTool
(
base_url
=
base_url
),
MonitorPointTool
(
base_url
=
base_url
),
WarningTool
(
base_url
=
base_url
),
]
tool_dict
=
{
tool
.
name
:
tool
for
tool
in
tools
}
...
...
@@ -193,9 +210,20 @@ def run_examples():
"start_time"
:
"2023-01-01"
,
"end_time"
:
"2023-12-31"
,
"region_name"
:
"西藏"
,
"month_statistics"
:
True
}
}
},
{
"query"
:
"贵阳市各县区设备在线率是多少"
,
"expected"
:
{
"tool"
:
"region_online_rate"
,
"params"
:
{
"region_name"
:
"贵阳市"
}
}
}
]
tool_selected_success
=
0
# 为每个测试案例创建一个表格
for
i
,
case
in
enumerate
(
test_cases
,
1
):
console
.
print
(
f
"
\n
[bold cyan]=== 测试案例 {i} ===[/bold cyan]"
)
...
...
@@ -219,6 +247,9 @@ def run_examples():
"✓"
if
expected_tool
==
actual_tool
else
"✗"
)
if
expected_tool
==
actual_tool
:
tool_selected_success
+=
1
# 添加参数比较行
for
param_key
in
case
[
"expected"
][
"params"
]:
expected_value
=
str
(
case
[
"expected"
][
"params"
][
param_key
])
...
...
@@ -233,8 +264,8 @@ def run_examples():
tool
=
tool_dict
[
result
[
"tool"
]]
params
=
result
[
"params"
]
result
=
tool
.
invoke
(
params
)
print
(
result
)
#
result = tool.invoke(params)
#
print(result)
except
Exception
as
e
:
table
.
add_row
(
"错误"
,
""
,
str
(
e
),
"✗"
)
...
...
@@ -242,5 +273,44 @@ def run_examples():
console
.
print
(
table
)
console
.
print
(
"="
*
80
)
print
(
f
"工具选择成功率: {tool_selected_success} / {len(test_cases)} = {tool_selected_success / len(test_cases) * 100:.2f}
%
"
)
def
run_case_17
():
console
=
Console
()
picker
=
create_tool_picker
()
success_count
=
0
total_time
=
0
test_cases_count
=
50
for
i
in
range
(
test_cases_count
):
start_time
=
time
.
time
()
query
=
"贵阳市各县区设备在线率是多少"
expected
=
{
"tool"
:
"region_online_rate"
,
"params"
:
{
"region_name"
:
"贵阳市"
}
}
result
=
picker
.
pick
(
query
)
exec_time
=
time
.
time
()
-
start_time
total_time
+=
exec_time
table
=
Table
(
title
=
f
"第 {i+1:02d} 次测试结果 【{query}】"
)
table
.
add_column
(
"查询"
,
style
=
"cyan"
)
table
.
add_column
(
"预期工具"
,
style
=
"green"
)
table
.
add_column
(
"实际工具"
,
style
=
"yellow"
)
table
.
add_column
(
"是否匹配"
,
style
=
"magenta"
)
table
.
add_column
(
"耗时"
,
style
=
"magenta"
)
table
.
add_row
(
query
,
expected
[
"tool"
],
result
[
"tool"
],
"✓"
if
result
[
"tool"
]
==
expected
[
"tool"
]
else
"✗"
,
f
"{exec_time:.2f}秒"
)
if
result
[
"tool"
]
==
expected
[
"tool"
]:
success_count
+=
1
console
.
print
(
table
)
print
(
f
"工具选择成功率: {success_count} / {test_cases_count} = {success_count / test_cases_count * 100:.2f}
%
"
)
print
(
f
"平均耗时: {total_time / test_cases_count:.2f}秒 总共耗时: {total_time:.2f}秒"
)
if
__name__
==
"__main__"
:
run_examples
()
\ No newline at end of file
# run_examples()
run_case_17
()
This diff is collapsed.
Click to expand it.
test/run_tool_picker_warn.py
View file @
625c5230
from
langchain_openai
import
ChatOpenAI
from
rich.console
import
Console
from
rich.table
import
Table
import
time
import
sys
,
os
sys
.
path
.
append
(
"../"
)
...
...
@@ -79,9 +79,23 @@ def run_examples():
"query_type"
:
"1"
,
}
}
},
{
"query"
:
"2023 年,贵阳市每月处置率是多少"
,
"expected"
:
{
"tool"
:
"warning_statistics"
,
"params"
:
{
"start_time"
:
"2023-01-01 00:00:00"
,
"end_time"
:
"2023-12-31 23:59:59"
,
"region_name"
:
"贵阳市"
,
"query_type"
:
"2"
,
}
}
}
]
success_count
=
0
total_time
=
0
# 为每个测试案例创建一个表格
for
i
,
case
in
enumerate
(
test_cases
,
1
):
console
.
print
(
f
"
\n
[bold cyan]=== 测试案例 {i} ===[/bold cyan]"
)
...
...
@@ -91,9 +105,13 @@ def run_examples():
table
.
add_column
(
"预期结果"
,
style
=
"green"
)
table
.
add_column
(
"实际结果"
,
style
=
"yellow"
)
table
.
add_column
(
"是否匹配"
,
style
=
"magenta"
)
table
.
add_column
(
"耗时"
,
style
=
"magenta"
)
try
:
start_time
=
time
.
time
()
result
=
picker
.
pick
(
case
[
"query"
])
end_time
=
time
.
time
()
exec_time
=
end_time
-
start_time
total_time
+=
exec_time
# 添加工具比较行
expected_tool
=
case
[
"expected"
][
"tool"
]
...
...
@@ -115,6 +133,10 @@ def run_examples():
actual_value
,
"✓"
if
expected_value
==
actual_value
else
"✗"
)
if
expected_tool
==
actual_tool
:
success_count
+=
1
table
.
add_row
(
"耗时"
,
""
,
f
"{exec_time:.2f}秒"
,
""
)
# tool = tool_dict[result["tool"]]
# params = result["params"]
...
...
@@ -128,5 +150,7 @@ def run_examples():
console
.
print
(
table
)
console
.
print
(
"="
*
80
)
print
(
f
"工具选择成功率: {success_count} / {len(test_cases)} = {success_count / len(test_cases) * 100:.2f}
%
"
)
print
(
f
"平均耗时: {total_time / len(test_cases):.2f}秒 总共耗时: {total_time:.2f}秒"
)
if
__name__
==
"__main__"
:
run_examples
()
\ No newline at end of file
This diff is collapsed.
Click to expand it.
test/test_agent.py
0 → 100644
View file @
625c5230
import
requests
questions
=
[
"贵阳市各县区设备在线率是多少"
,
"贵阳市各个县区 2024 年 11 月 2 日设备在线率是多少"
,
"各省实时设备在线率排名"
,
"各省实时设备在线率排名中,前 5名是哪些省"
,
"各省实时设备在线率排名中,最后 5 名是哪些省"
,
"各厂商设备在线率排名中,排名前 10的厂商"
,
"在厂商设备在线率排名中,上海华测设备在线率排名是多少"
,
"贵阳市的雨量传感器数量有多少 "
,
"贵阳市的监测点数量是多少"
,
"贵阳市有三维模型的监测点数量是多少"
,
"今年 ,贵阳市每月的设备在线率是多少?(注:如果需要按“每月”维度展示,问题中需添加 “每月” 字段,增加问题区分度)"
,
"2024 年 贵阳市每月的滑坡仪在线率是多少"
,
"贵阳市威胁人数在 200 人以上的滑坡监测点数量有多少"
,
"2024 年 5 月 20 日到 2024 年 11 月 27 日,贵阳市黄色预警消息占比是多少;"
,
"2024 年 10 月 1日到 2024 年 10 月31日,贵阳市处置率是多少"
,
"2024 年5 月到 2024 年11 月,贵阳市每月处置率是多少"
,
"2023 年,贵阳市每月处置率是多少"
,
"贵阳市预警等级为红色的滑坡监测点数量有多少"
]
url
=
"http://localhost:8088/api/agent/rate"
for
question
in
questions
:
print
(
question
)
http_response
=
requests
.
post
(
url
,
json
=
{
"query"
:
question
})
print
(
http_response
.
json
())
print
(
"--------------------------------"
*
5
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment