Commit e0059c27 by tinywell

工具选择与执行分开,由程序管控流程方案验证

parent eddec199
......@@ -9,7 +9,7 @@ from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from src.server.agent_rate import new_rate_agent
from src.server.agent_rate import new_rate_agent, RateAgentV3
from src.server.classify import new_router_llm
from src.controller.request import GeoAgentRateRequest
......@@ -77,6 +77,7 @@ class AgentManager:
verbose=True
)
self.agent = new_rate_agent(self.llm, verbose=True, tool_base_url=tool_base_url)
self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url)
self.router_llm = new_router_llm(self.llm)
def get_llm(self):
......@@ -85,6 +86,9 @@ class AgentManager:
def get_agent(self):
return self.agent
def get_rate_agent(self):
return self.rate_agent
def get_router_llm(self):
return self.router_llm
......@@ -93,8 +97,10 @@ agent_manager = AgentManager()
@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
agent = agent_manager.get_agent()
rate_agent = agent_manager.get_rate_agent()
try:
res = agent.exec(prompt_args={"input": chat_request.query})
# res = rate_agent.run(chat_request.query)
except Exception as e:
print(f"处理请求失败, 错误信息: {str(e)},请重新提问")
return {
......
......@@ -15,6 +15,7 @@ from langchain import hub
from src.agent.tool_rate import RegionRateTool,RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
from src.server.tool_picker import ToolPicker, ToolRunner
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
# tools_renderer: ToolsRenderer = render_text_description_and_args,
......@@ -168,4 +169,20 @@ def new_rate_agent(llm, verbose: bool = False,**args):
class RateAgentV3:
def __init__(self, llm, tool_base_url: str):
tools = [
RegionRateTool(base_url=tool_base_url),
RankingRateTool(base_url=tool_base_url),
MonitorPointTool(base_url=tool_base_url)
]
self.picker = ToolPicker(llm, tools)
tools_dict = {}
for t in tools:
tools_dict[t.name] = t
self.runner = ToolRunner(llm, tools_dict)
def run(self, input: str):
picker_result = self.picker.pick(input)
return self.runner.run(input, picker_result["tool"], picker_result["params"])
from typing import List, Dict
from datetime import datetime
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain.tools.render import ToolsRenderer, render_text_description_and_args
from langchain_core.output_parsers import JsonOutputParser as JSONOutputParser
from langchain_core.tools import BaseTool
PICKER_SYSTEM_PROMPT = """
你是一个智能工具选择助手,你需要根据用户的问题选择最合适的工具,并提取出工具所需的参数。
工作流程:
1. 分析用户问题,确定所需工具
2. 提取并整理工具所需的参数
3. 返回工具名称和参数
请遵循以下规则:
- 工具名称必须从工具列表中选择: {{tool_names}}
- 返回格式:
```json
{{
"tool": "工具名称",
"params": {{"参数1": "值1", "参数2": "值2"}}
}}
```
工具列表详情:
{tools}
"""
PICKER_HUMAN_PROMPT = """
用户的问题是:{input}
"""
class ToolPicker:
def __init__(self, llm, tools: List):
self.tools = tools
self.llm = llm
date_now = datetime.now().strftime("%Y-%m-%d")
picker_human = f"今天是{date_now}\n\n{PICKER_HUMAN_PROMPT}"
prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template(PICKER_SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="chat_history", optional=True),
HumanMessagePromptTemplate.from_template(picker_human)
])
prompt = prompt.partial(tools=render_text_description_and_args(tools))
self.chain = prompt | self.llm | JSONOutputParser()
def pick(self, input: str):
print(input)
return self.chain.invoke({"input": input})
RUNNER_SYSTEM_PROMPT = """
你是一个擅长根据工具执行结果回答用户问题的助手。
工作流程:
1. 分析用户的问题
2. 根据用户的问题,解读工具执行结果,进行简要的分析说明
3. 返回用户问题的答案
请遵循以下规则:
- 工具执行结果中的数据必须使用 markdown 表格展示
- 确保数据的完整性, 不要遗漏数据
- 表格中的数据只能来源于工具执行结果
"""
RUNNER_HUMAN_PROMPT = """
用户的问题是:{input}
工具的执行结果是:{result}
"""
class ToolRunner:
def __init__(self, llm, tools: Dict[str, BaseTool]):
self.tools = tools
self.llm = llm
prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template(RUNNER_SYSTEM_PROMPT),
HumanMessagePromptTemplate.from_template(RUNNER_HUMAN_PROMPT)
])
self.chain = prompt | self.llm
def run(self, input: str, tool_name: str, params: Dict):
if tool_name not in self.tools:
raise ValueError(f"Tool {tool_name} not found")
tool = self.tools[tool_name]
result = tool.invoke(params)
return self.chain.invoke({"input": input, "result": result})
import sys,os
sys.path.append("../")
import pytest
from unittest.mock import Mock
from langchain_openai import ChatOpenAI
from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool,RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
@pytest.fixture
def mock_llm():
# llm = Mock(spec=ChatOpenAI)
# # 模拟 LLM 返回结果
# llm.invoke.return_value = {"content": """{"tool": "RegionRateTool", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": ""}}"""}
# return llm
llm = ChatOpenAI(
openai_api_key="xxxxxx",
openai_api_base="http://192.168.10.14:8000/v1",
model_name="Qwen2-7B",
verbose=True
)
return llm
# 使用参数化测试不同的场景
@pytest.mark.parametrize("query, expected_response, expected_tool", [
(
"请分析下今天全国各地区在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": ""}},
"region_online_rate"
),
(
"请分析下今天甘肃省设备在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": "甘肃省"}},
"region_online_rate"
),
(
"查询今年三季度甘肃省设备在线率情况",
{"tool": "region_online_rate", "params": {"start_time": "2024-07-01", "end_time": "2024-09-30", "region_name": "甘肃省"}},
"region_online_rate"
),
(
"查询2024年11月13日各地区排名情况",
{"tool": "online_rate_ranking", "params": {"rate_type": "1"}},
"online_rate_ranking"
),
(
"查询各厂商在线率排名情况",
{"tool": "online_rate_ranking", "params": {"rate_type": "2"}},
"online_rate_ranking"
),
(
"甘肃省监控点的状态如何?",
{"tool": "monitor_points_query", "params": {"key": "甘肃省"}},
"monitor_points_query"
),
])
def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool):
# 创建测试用的工具
test_tools = [
RegionRateTool(),
RankingRateTool(),
MonitorPointTool()
]
picker = ToolPicker(mock_llm, test_tools)
result = picker.pick(query)
# 验证结果
assert isinstance(result, dict)
assert result["tool"] == expected_tool
assert "params" in result
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment