Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
625c5230
Commit
625c5230
authored
Dec 09, 2024
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
优化 warn 工具描述,降低与在线率工具混淆
parent
abcc91a0
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
139 additions
and
15 deletions
+139
-15
tool_warn.py
src/agent/tool_warn.py
+2
-2
run_tool_picker_rate.py
test/run_tool_picker_rate.py
+80
-11
run_tool_picker_warn.py
test/run_tool_picker_warn.py
+27
-2
test_agent.py
test/test_agent.py
+30
-0
No files found.
src/agent/tool_warn.py
View file @
625c5230
...
@@ -22,12 +22,12 @@ class WarningArgs(BaseModel):
...
@@ -22,12 +22,12 @@ class WarningArgs(BaseModel):
start_time
:
str
=
Field
(
""
,
description
=
"开始时间 (YYYY-MM-DD HH:mm:ss)"
)
start_time
:
str
=
Field
(
""
,
description
=
"开始时间 (YYYY-MM-DD HH:mm:ss)"
)
end_time
:
str
=
Field
(
""
,
description
=
"结束时间 (YYYY-MM-DD HH:mm:ss)"
)
end_time
:
str
=
Field
(
""
,
description
=
"结束时间 (YYYY-MM-DD HH:mm:ss)"
)
region_name
:
str
=
Field
(
""
,
description
=
"地区名称,如果要查询全国数据,请输入空字符串"
)
region_name
:
str
=
Field
(
""
,
description
=
"地区名称,如果要查询全国数据,请输入空字符串"
)
query_type
:
str
=
Field
(
"1"
,
description
=
"查询类型,1 表示查询一段时间内的综合数据,2 表示查询指定年份全年按月度统计的虚警
率,处置率
等信息"
)
query_type
:
str
=
Field
(
"1"
,
description
=
"查询类型,1 表示查询一段时间内的综合数据,2 表示查询指定年份全年按月度统计的虚警
、处置
等信息"
)
class
WarningTool
(
BaseTool
):
class
WarningTool
(
BaseTool
):
"""查询预警处置和虚警情况"""
"""查询预警处置和虚警情况"""
name
:
str
=
"warning_statistics"
name
:
str
=
"warning_statistics"
description
:
str
=
"查询一定时间范围内不同地区预警处置和虚警情况,包括处置
率、虚警率、蓝黄橙红数数量和占比统计。也支持查询指定年份全年按月度统计的虚警率,处置率
等信息。"
description
:
str
=
"查询一定时间范围内不同地区预警处置和虚警情况,包括处置
情况、虚警情况、蓝黄橙红数数量等统计。也支持查询指定年份全年按月度统计的虚警情况,处置情况
等信息。"
args_schema
:
Type
[
BaseModel
]
=
WarningArgs
args_schema
:
Type
[
BaseModel
]
=
WarningArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
client
:
Any
=
Field
(
None
,
exclude
=
True
)
logger
:
logging
.
Logger
=
Field
(
None
,
exclude
=
True
)
logger
:
logging
.
Logger
=
Field
(
None
,
exclude
=
True
)
...
...
test/run_tool_picker_rate.py
View file @
625c5230
from
langchain_openai
import
ChatOpenAI
from
langchain_openai
import
ChatOpenAI
from
rich.console
import
Console
from
rich.console
import
Console
from
rich.table
import
Table
from
rich.table
import
Table
import
time
import
sys
,
os
import
sys
,
os
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
from
src.server.tool_picker
import
ToolPicker
from
src.server.tool_picker
import
ToolPicker
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_monitor
import
MonitorPointTool
from
src.agent.tool_monitor
import
MonitorPointTool
from
src.agent.tool_warn
import
WarningTool
def
run_examples
():
def
create_tool_picker
():
# 初始化 rich console
console
=
Console
()
# 初始化 LLM
llm
=
ChatOpenAI
(
llm
=
ChatOpenAI
(
openai_api_key
=
"xxxxxx"
,
openai_api_key
=
"xxxxxx"
,
openai_api_base
=
"http://192.168.10.14:8000/v1"
,
openai_api_base
=
"http://192.168.10.14:8000/v1"
,
model_name
=
"Qwen2-7B"
,
model_name
=
"Qwen2-7B"
,
verbose
=
True
verbose
=
True
)
)
base_url
=
"http://172.30.0.37:30007"
# 初始化工具
tools
=
[
RegionRateTool
(
base_url
=
base_url
),
RankingRateTool
(
base_url
=
base_url
),
MonitorPointTool
(
base_url
=
base_url
),
WarningTool
(
base_url
=
base_url
),
]
return
ToolPicker
(
llm
,
tools
)
def
run_examples
():
# 初始化 rich console
console
=
Console
()
# 初始化 LLM
base_url
=
"http://172.30.0.37:30007"
base_url
=
"http://172.30.0.37:30007"
# 初始化工具
# 初始化工具
tools
=
[
tools
=
[
RegionRateTool
(
base_url
=
base_url
),
RegionRateTool
(
base_url
=
base_url
),
RankingRateTool
(
base_url
=
base_url
),
RankingRateTool
(
base_url
=
base_url
),
MonitorPointTool
(
base_url
=
base_url
),
MonitorPointTool
(
base_url
=
base_url
),
WarningTool
(
base_url
=
base_url
),
]
]
tool_dict
=
{
tool
.
name
:
tool
for
tool
in
tools
}
tool_dict
=
{
tool
.
name
:
tool
for
tool
in
tools
}
...
@@ -193,9 +210,20 @@ def run_examples():
...
@@ -193,9 +210,20 @@ def run_examples():
"start_time"
:
"2023-01-01"
,
"end_time"
:
"2023-12-31"
,
"region_name"
:
"西藏"
,
"month_statistics"
:
True
"start_time"
:
"2023-01-01"
,
"end_time"
:
"2023-12-31"
,
"region_name"
:
"西藏"
,
"month_statistics"
:
True
}
}
}
}
},
{
"query"
:
"贵阳市各县区设备在线率是多少"
,
"expected"
:
{
"tool"
:
"region_online_rate"
,
"params"
:
{
"region_name"
:
"贵阳市"
}
}
}
}
]
]
tool_selected_success
=
0
# 为每个测试案例创建一个表格
# 为每个测试案例创建一个表格
for
i
,
case
in
enumerate
(
test_cases
,
1
):
for
i
,
case
in
enumerate
(
test_cases
,
1
):
console
.
print
(
f
"
\n
[bold cyan]=== 测试案例 {i} ===[/bold cyan]"
)
console
.
print
(
f
"
\n
[bold cyan]=== 测试案例 {i} ===[/bold cyan]"
)
...
@@ -219,6 +247,9 @@ def run_examples():
...
@@ -219,6 +247,9 @@ def run_examples():
"✓"
if
expected_tool
==
actual_tool
else
"✗"
"✓"
if
expected_tool
==
actual_tool
else
"✗"
)
)
if
expected_tool
==
actual_tool
:
tool_selected_success
+=
1
# 添加参数比较行
# 添加参数比较行
for
param_key
in
case
[
"expected"
][
"params"
]:
for
param_key
in
case
[
"expected"
][
"params"
]:
expected_value
=
str
(
case
[
"expected"
][
"params"
][
param_key
])
expected_value
=
str
(
case
[
"expected"
][
"params"
][
param_key
])
...
@@ -233,14 +264,53 @@ def run_examples():
...
@@ -233,14 +264,53 @@ def run_examples():
tool
=
tool_dict
[
result
[
"tool"
]]
tool
=
tool_dict
[
result
[
"tool"
]]
params
=
result
[
"params"
]
params
=
result
[
"params"
]
result
=
tool
.
invoke
(
params
)
#
result = tool.invoke(params)
print
(
result
)
#
print(result)
except
Exception
as
e
:
except
Exception
as
e
:
table
.
add_row
(
"错误"
,
""
,
str
(
e
),
"✗"
)
table
.
add_row
(
"错误"
,
""
,
str
(
e
),
"✗"
)
console
.
print
(
table
)
console
.
print
(
table
)
console
.
print
(
"="
*
80
)
console
.
print
(
"="
*
80
)
print
(
f
"工具选择成功率: {tool_selected_success} / {len(test_cases)} = {tool_selected_success / len(test_cases) * 100:.2f}
%
"
)
def
run_case_17
():
console
=
Console
()
picker
=
create_tool_picker
()
success_count
=
0
total_time
=
0
test_cases_count
=
50
for
i
in
range
(
test_cases_count
):
start_time
=
time
.
time
()
query
=
"贵阳市各县区设备在线率是多少"
expected
=
{
"tool"
:
"region_online_rate"
,
"params"
:
{
"region_name"
:
"贵阳市"
}
}
result
=
picker
.
pick
(
query
)
exec_time
=
time
.
time
()
-
start_time
total_time
+=
exec_time
table
=
Table
(
title
=
f
"第 {i+1:02d} 次测试结果 【{query}】"
)
table
.
add_column
(
"查询"
,
style
=
"cyan"
)
table
.
add_column
(
"预期工具"
,
style
=
"green"
)
table
.
add_column
(
"实际工具"
,
style
=
"yellow"
)
table
.
add_column
(
"是否匹配"
,
style
=
"magenta"
)
table
.
add_column
(
"耗时"
,
style
=
"magenta"
)
table
.
add_row
(
query
,
expected
[
"tool"
],
result
[
"tool"
],
"✓"
if
result
[
"tool"
]
==
expected
[
"tool"
]
else
"✗"
,
f
"{exec_time:.2f}秒"
)
if
result
[
"tool"
]
==
expected
[
"tool"
]:
success_count
+=
1
console
.
print
(
table
)
print
(
f
"工具选择成功率: {success_count} / {test_cases_count} = {success_count / test_cases_count * 100:.2f}
%
"
)
print
(
f
"平均耗时: {total_time / test_cases_count:.2f}秒 总共耗时: {total_time:.2f}秒"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
run_examples
()
# run_examples()
\ No newline at end of file
run_case_17
()
test/run_tool_picker_warn.py
View file @
625c5230
from
langchain_openai
import
ChatOpenAI
from
langchain_openai
import
ChatOpenAI
from
rich.console
import
Console
from
rich.console
import
Console
from
rich.table
import
Table
from
rich.table
import
Table
import
time
import
sys
,
os
import
sys
,
os
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
...
@@ -79,9 +79,23 @@ def run_examples():
...
@@ -79,9 +79,23 @@ def run_examples():
"query_type"
:
"1"
,
"query_type"
:
"1"
,
}
}
}
}
},
{
"query"
:
"2023 年,贵阳市每月处置率是多少"
,
"expected"
:
{
"tool"
:
"warning_statistics"
,
"params"
:
{
"start_time"
:
"2023-01-01 00:00:00"
,
"end_time"
:
"2023-12-31 23:59:59"
,
"region_name"
:
"贵阳市"
,
"query_type"
:
"2"
,
}
}
}
}
]
]
success_count
=
0
total_time
=
0
# 为每个测试案例创建一个表格
# 为每个测试案例创建一个表格
for
i
,
case
in
enumerate
(
test_cases
,
1
):
for
i
,
case
in
enumerate
(
test_cases
,
1
):
console
.
print
(
f
"
\n
[bold cyan]=== 测试案例 {i} ===[/bold cyan]"
)
console
.
print
(
f
"
\n
[bold cyan]=== 测试案例 {i} ===[/bold cyan]"
)
...
@@ -91,9 +105,13 @@ def run_examples():
...
@@ -91,9 +105,13 @@ def run_examples():
table
.
add_column
(
"预期结果"
,
style
=
"green"
)
table
.
add_column
(
"预期结果"
,
style
=
"green"
)
table
.
add_column
(
"实际结果"
,
style
=
"yellow"
)
table
.
add_column
(
"实际结果"
,
style
=
"yellow"
)
table
.
add_column
(
"是否匹配"
,
style
=
"magenta"
)
table
.
add_column
(
"是否匹配"
,
style
=
"magenta"
)
table
.
add_column
(
"耗时"
,
style
=
"magenta"
)
try
:
try
:
start_time
=
time
.
time
()
result
=
picker
.
pick
(
case
[
"query"
])
result
=
picker
.
pick
(
case
[
"query"
])
end_time
=
time
.
time
()
exec_time
=
end_time
-
start_time
total_time
+=
exec_time
# 添加工具比较行
# 添加工具比较行
expected_tool
=
case
[
"expected"
][
"tool"
]
expected_tool
=
case
[
"expected"
][
"tool"
]
...
@@ -115,6 +133,10 @@ def run_examples():
...
@@ -115,6 +133,10 @@ def run_examples():
actual_value
,
actual_value
,
"✓"
if
expected_value
==
actual_value
else
"✗"
"✓"
if
expected_value
==
actual_value
else
"✗"
)
)
if
expected_tool
==
actual_tool
:
success_count
+=
1
table
.
add_row
(
"耗时"
,
""
,
f
"{exec_time:.2f}秒"
,
""
)
# tool = tool_dict[result["tool"]]
# tool = tool_dict[result["tool"]]
# params = result["params"]
# params = result["params"]
...
@@ -128,5 +150,7 @@ def run_examples():
...
@@ -128,5 +150,7 @@ def run_examples():
console
.
print
(
table
)
console
.
print
(
table
)
console
.
print
(
"="
*
80
)
console
.
print
(
"="
*
80
)
print
(
f
"工具选择成功率: {success_count} / {len(test_cases)} = {success_count / len(test_cases) * 100:.2f}
%
"
)
print
(
f
"平均耗时: {total_time / len(test_cases):.2f}秒 总共耗时: {total_time:.2f}秒"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
run_examples
()
run_examples
()
\ No newline at end of file
test/test_agent.py
0 → 100644
View file @
625c5230
import
requests
questions
=
[
"贵阳市各县区设备在线率是多少"
,
"贵阳市各个县区 2024 年 11 月 2 日设备在线率是多少"
,
"各省实时设备在线率排名"
,
"各省实时设备在线率排名中,前 5名是哪些省"
,
"各省实时设备在线率排名中,最后 5 名是哪些省"
,
"各厂商设备在线率排名中,排名前 10的厂商"
,
"在厂商设备在线率排名中,上海华测设备在线率排名是多少"
,
"贵阳市的雨量传感器数量有多少 "
,
"贵阳市的监测点数量是多少"
,
"贵阳市有三维模型的监测点数量是多少"
,
"今年 ,贵阳市每月的设备在线率是多少?(注:如果需要按“每月”维度展示,问题中需添加 “每月” 字段,增加问题区分度)"
,
"2024 年 贵阳市每月的滑坡仪在线率是多少"
,
"贵阳市威胁人数在 200 人以上的滑坡监测点数量有多少"
,
"2024 年 5 月 20 日到 2024 年 11 月 27 日,贵阳市黄色预警消息占比是多少;"
,
"2024 年 10 月 1日到 2024 年 10 月31日,贵阳市处置率是多少"
,
"2024 年5 月到 2024 年11 月,贵阳市每月处置率是多少"
,
"2023 年,贵阳市每月处置率是多少"
,
"贵阳市预警等级为红色的滑坡监测点数量有多少"
]
url
=
"http://localhost:8088/api/agent/rate"
for
question
in
questions
:
print
(
question
)
http_response
=
requests
.
post
(
url
,
json
=
{
"query"
:
question
})
print
(
http_response
.
json
())
print
(
"--------------------------------"
*
5
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment