Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
e0059c27
Commit
e0059c27
authored
Nov 13, 2024
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
工具选择与执行分开,由程序管控流程方案验证
parent
eddec199
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
197 additions
and
1 deletions
+197
-1
api.py
src/controller/api.py
+7
-1
agent_rate.py
src/server/agent_rate.py
+17
-0
tool_picker.py
src/server/tool_picker.py
+96
-0
test_tool_picker.py
test/test_tool_picker.py
+77
-0
No files found.
src/controller/api.py
View file @
e0059c27
...
...
@@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
import
uvicorn
from
src.server.agent_rate
import
new_rate_agent
from
src.server.agent_rate
import
new_rate_agent
,
RateAgentV3
from
src.server.classify
import
new_router_llm
from
src.controller.request
import
GeoAgentRateRequest
...
...
@@ -77,6 +77,7 @@ class AgentManager:
verbose
=
True
)
self
.
agent
=
new_rate_agent
(
self
.
llm
,
verbose
=
True
,
tool_base_url
=
tool_base_url
)
self
.
rate_agent
=
RateAgentV3
(
self
.
llm
,
tool_base_url
=
tool_base_url
)
self
.
router_llm
=
new_router_llm
(
self
.
llm
)
def
get_llm
(
self
):
...
...
@@ -85,6 +86,9 @@ class AgentManager:
def
get_agent
(
self
):
return
self
.
agent
def
get_rate_agent
(
self
):
return
self
.
rate_agent
def
get_router_llm
(
self
):
return
self
.
router_llm
...
...
@@ -93,8 +97,10 @@ agent_manager = AgentManager()
@app.post
(
'/api/agent/rate'
)
def
rate
(
chat_request
:
GeoAgentRateRequest
,
token
:
str
=
Header
(
None
)):
agent
=
agent_manager
.
get_agent
()
rate_agent
=
agent_manager
.
get_rate_agent
()
try
:
res
=
agent
.
exec
(
prompt_args
=
{
"input"
:
chat_request
.
query
})
# res = rate_agent.run(chat_request.query)
except
Exception
as
e
:
print
(
f
"处理请求失败, 错误信息: {str(e)},请重新提问"
)
return
{
...
...
src/server/agent_rate.py
View file @
e0059c27
...
...
@@ -15,6 +15,7 @@ from langchain import hub
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_monitor
import
MonitorPointTool
from
src.server.tool_picker
import
ToolPicker
,
ToolRunner
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
# tools_renderer: ToolsRenderer = render_text_description_and_args,
...
...
@@ -168,4 +169,20 @@ def new_rate_agent(llm, verbose: bool = False,**args):
class
RateAgentV3
:
def
__init__
(
self
,
llm
,
tool_base_url
:
str
):
tools
=
[
RegionRateTool
(
base_url
=
tool_base_url
),
RankingRateTool
(
base_url
=
tool_base_url
),
MonitorPointTool
(
base_url
=
tool_base_url
)
]
self
.
picker
=
ToolPicker
(
llm
,
tools
)
tools_dict
=
{}
for
t
in
tools
:
tools_dict
[
t
.
name
]
=
t
self
.
runner
=
ToolRunner
(
llm
,
tools_dict
)
def
run
(
self
,
input
:
str
):
picker_result
=
self
.
picker
.
pick
(
input
)
return
self
.
runner
.
run
(
input
,
picker_result
[
"tool"
],
picker_result
[
"params"
])
src/server/tool_picker.py
0 → 100644
View file @
e0059c27
from
typing
import
List
,
Dict
from
datetime
import
datetime
from
langchain_core.prompts
import
PromptTemplate
,
ChatPromptTemplate
,
SystemMessagePromptTemplate
,
MessagesPlaceholder
,
HumanMessagePromptTemplate
from
langchain.tools.render
import
ToolsRenderer
,
render_text_description_and_args
from
langchain_core.output_parsers
import
JsonOutputParser
as
JSONOutputParser
from
langchain_core.tools
import
BaseTool
PICKER_SYSTEM_PROMPT
=
"""
你是一个智能工具选择助手,你需要根据用户的问题选择最合适的工具,并提取出工具所需的参数。
工作流程:
1. 分析用户问题,确定所需工具
2. 提取并整理工具所需的参数
3. 返回工具名称和参数
请遵循以下规则:
- 工具名称必须从工具列表中选择: {{tool_names}}
- 返回格式:
```json
{{
"tool": "工具名称",
"params": {{"参数1": "值1", "参数2": "值2"}}
}}
```
工具列表详情:
{tools}
"""
PICKER_HUMAN_PROMPT
=
"""
用户的问题是:{input}
"""
class
ToolPicker
:
def
__init__
(
self
,
llm
,
tools
:
List
):
self
.
tools
=
tools
self
.
llm
=
llm
date_now
=
datetime
.
now
()
.
strftime
(
"
%
Y-
%
m-
%
d"
)
picker_human
=
f
"今天是{date_now}
\n\n
{PICKER_HUMAN_PROMPT}"
prompt
=
ChatPromptTemplate
.
from_messages
([
SystemMessagePromptTemplate
.
from_template
(
PICKER_SYSTEM_PROMPT
),
MessagesPlaceholder
(
variable_name
=
"chat_history"
,
optional
=
True
),
HumanMessagePromptTemplate
.
from_template
(
picker_human
)
])
prompt
=
prompt
.
partial
(
tools
=
render_text_description_and_args
(
tools
))
self
.
chain
=
prompt
|
self
.
llm
|
JSONOutputParser
()
def
pick
(
self
,
input
:
str
):
print
(
input
)
return
self
.
chain
.
invoke
({
"input"
:
input
})
RUNNER_SYSTEM_PROMPT
=
"""
你是一个擅长根据工具执行结果回答用户问题的助手。
工作流程:
1. 分析用户的问题
2. 根据用户的问题,解读工具执行结果,进行简要的分析说明
3. 返回用户问题的答案
请遵循以下规则:
- 工具执行结果中的数据必须使用 markdown 表格展示
- 确保数据的完整性, 不要遗漏数据
- 表格中的数据只能来源于工具执行结果
"""
RUNNER_HUMAN_PROMPT
=
"""
用户的问题是:{input}
工具的执行结果是:{result}
"""
class
ToolRunner
:
def
__init__
(
self
,
llm
,
tools
:
Dict
[
str
,
BaseTool
]):
self
.
tools
=
tools
self
.
llm
=
llm
prompt
=
ChatPromptTemplate
.
from_messages
([
SystemMessagePromptTemplate
.
from_template
(
RUNNER_SYSTEM_PROMPT
),
HumanMessagePromptTemplate
.
from_template
(
RUNNER_HUMAN_PROMPT
)
])
self
.
chain
=
prompt
|
self
.
llm
def
run
(
self
,
input
:
str
,
tool_name
:
str
,
params
:
Dict
):
if
tool_name
not
in
self
.
tools
:
raise
ValueError
(
f
"Tool {tool_name} not found"
)
tool
=
self
.
tools
[
tool_name
]
result
=
tool
.
invoke
(
params
)
return
self
.
chain
.
invoke
({
"input"
:
input
,
"result"
:
result
})
test/test_tool_picker.py
0 → 100644
View file @
e0059c27
import
sys
,
os
sys
.
path
.
append
(
"../"
)
import
pytest
from
unittest.mock
import
Mock
from
langchain_openai
import
ChatOpenAI
from
src.server.tool_picker
import
ToolPicker
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_monitor
import
MonitorPointTool
@pytest.fixture
def
mock_llm
():
# llm = Mock(spec=ChatOpenAI)
# # 模拟 LLM 返回结果
# llm.invoke.return_value = {"content": """{"tool": "RegionRateTool", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": ""}}"""}
# return llm
llm
=
ChatOpenAI
(
openai_api_key
=
"xxxxxx"
,
openai_api_base
=
"http://192.168.10.14:8000/v1"
,
model_name
=
"Qwen2-7B"
,
verbose
=
True
)
return
llm
# 使用参数化测试不同的场景
@pytest.mark.parametrize
(
"query, expected_response, expected_tool"
,
[
(
"请分析下今天全国各地区在线率情况"
,
{
"tool"
:
"region_online_rate"
,
"params"
:
{
"start_time"
:
"2024-11-13"
,
"end_time"
:
"2024-11-13"
,
"region_name"
:
""
}},
"region_online_rate"
),
(
"请分析下今天甘肃省设备在线率情况"
,
{
"tool"
:
"region_online_rate"
,
"params"
:
{
"start_time"
:
"2024-11-13"
,
"end_time"
:
"2024-11-13"
,
"region_name"
:
"甘肃省"
}},
"region_online_rate"
),
(
"查询今年三季度甘肃省设备在线率情况"
,
{
"tool"
:
"region_online_rate"
,
"params"
:
{
"start_time"
:
"2024-07-01"
,
"end_time"
:
"2024-09-30"
,
"region_name"
:
"甘肃省"
}},
"region_online_rate"
),
(
"查询2024年11月13日各地区排名情况"
,
{
"tool"
:
"online_rate_ranking"
,
"params"
:
{
"rate_type"
:
"1"
}},
"online_rate_ranking"
),
(
"查询各厂商在线率排名情况"
,
{
"tool"
:
"online_rate_ranking"
,
"params"
:
{
"rate_type"
:
"2"
}},
"online_rate_ranking"
),
(
"甘肃省监控点的状态如何?"
,
{
"tool"
:
"monitor_points_query"
,
"params"
:
{
"key"
:
"甘肃省"
}},
"monitor_points_query"
),
])
def
test_tool_picker_scenarios
(
mock_llm
,
query
,
expected_response
,
expected_tool
):
# 创建测试用的工具
test_tools
=
[
RegionRateTool
(),
RankingRateTool
(),
MonitorPointTool
()
]
picker
=
ToolPicker
(
mock_llm
,
test_tools
)
result
=
picker
.
pick
(
query
)
# 验证结果
assert
isinstance
(
result
,
dict
)
assert
result
[
"tool"
]
==
expected_tool
assert
"params"
in
result
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment