from typing import List, Dict
from datetime import datetime
import time
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain.tools.render import render_text_description_and_args
from langchain_core.output_parsers import JsonOutputParser as JSONOutputParser
from langchain_core.tools import BaseTool
from ..utils.logger import get_logger
PICKER_SYSTEM_PROMPT = """
你是一个智能工具选择助手,你需要根据用户的问题选择最合适的工具,并提取出工具所需的参数。
工作流程:
1. 分析用户问题,确定所需工具
2. 提取并整理工具所需的参数
3. 返回工具名称和参数
请遵循以下规则:
- 工具名称必须从工具列表中选择: {{tool_names}}
- 返回格式:
```json
{{
"tool": "工具名称",
"params": {{"参数1": "值1", "参数2": "值2"}}
}}
```
工具列表详情:
{tools}
"""
PICKER_HUMAN_PROMPT = """
用户的问题是:{input}
"""
class ToolPicker:
def __init__(self, llm, tools: List):
self.tools = tools
self.llm = llm
self.logger = get_logger("ToolPicker")
date_now = datetime.now().strftime("%Y-%m-%d")
picker_human = f"今天是{date_now}\n\n{PICKER_HUMAN_PROMPT}"
prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template(PICKER_SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="chat_history", optional=True),
HumanMessagePromptTemplate.from_template(picker_human)
])
prompt = prompt.partial(tools=render_text_description_and_args(tools))
self.chain = prompt | self.llm | JSONOutputParser()
def pick(self, input: str):
self.logger.info(f"Received input: {input}")
try:
result = self.chain.invoke({"input": input})
self.logger.info(f"Selected tool: {result['tool']} with params: {result['params']}")
return result
except Exception as e:
self.logger.error(f"Error picking tool: {str(e)}")
raise
RUNNER_SYSTEM_PROMPT_V1 = """
你是一个擅长根据工具执行结果回答用户问题的助手。
工作流程:
1. 对获取的数据进行简要的分析和解读
2. 结合数据对用户问题进行回答
注意:
- 不要展示原数据,只展示分析和解读后的结果
- 不要提到工具调用
"""
# "这一版的问答效果相对来说比较好" -- wukai
RUNNER_SYSTEM_PROMPT_V0 = """
你是一个擅长根据工具执行结果回答用户问题的助手。
工作流程:
1. 分析用户的问题
2. 根据用户的问题,解读工具执行结果,进行简要的分析说明
3. 返回用户问题的答案
请遵循以下规则:
- 工具执行结果中的数据必须使用 markdown 表格展示
- 确保数据的完整性, 不要遗漏数据
- 表格中的数据只能来源于工具执行结果,不得添加或修改
- 涉及数量的问题(如个数、总数等),必须严格按照工具返回的结果数据回答,不得进行任何推断或估算
- 如果工具结果中没有相关数据,应明确告知无法得知具体数量,而不是给出推测的数字
"""
RUNNER_HUMAN_PROMPT = """
用户的问题是:{input}
工具的执行结果是:{result}
"""
class ToolRunner:
def __init__(self, llm, tools: Dict[str, BaseTool],version: str = "v1"):
self.tools = tools
self.llm = llm
self.logger = get_logger("ToolRunner")
self.version = version
system_prompt = RUNNER_SYSTEM_PROMPT_V1
if version != "v1":
system_prompt = RUNNER_SYSTEM_PROMPT_V0
prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template(system_prompt),
HumanMessagePromptTemplate.from_template(RUNNER_HUMAN_PROMPT)
])
self.chain = prompt | self.llm
def run(self, input: str, tool_name: str, params: Dict):
start_time = time.time()
self.logger.info(f"开始执行工具 '{tool_name}',参数: {params}")
if tool_name not in self.tools:
self.logger.error(f"工具 {tool_name} 未找到")
raise ValueError(f"Tool {tool_name} not found")
try:
# 工具执行
tool_start = time.time()
tool = self.tools[tool_name]
self.logger.info(f"调用工具 {tool_name}")
result = tool.invoke(params)
tool_time = time.time() - tool_start
self.logger.info(f"工具执行完成,耗时: {tool_time:.2f}秒")
self.logger.debug(f"工具执行结果: {result}")
# 提取 markdown 表格
table = ""
if "markdown" in result:
table = result["markdown"]
del result["markdown"]
if "summary" in result:
summary = result["summary"]
result["summary"] = f"针对问题 {input},我们{summary}"
# LLM 解释结果
llm_start = time.time()
self.logger.info("开始 LLM 结果解释")
llm_result = self.chain.invoke({"input": input, "result": result})
llm_time = time.time() - llm_start
self.logger.info(f"LLM 解释完成,耗时: {llm_time:.2f}秒")
# 组装响应
response = {
"output": llm_result.content,
"table": table
}
total_time = time.time() - start_time
self.logger.info(f"工具执行完成,总耗时: {total_time:.2f}秒,其中工具执行: {tool_time:.2f}秒,LLM解释: {llm_time:.2f}秒")
# 记录性能指标
self.logger.info(
"性能指标 - "
f"总耗时: {total_time:.2f}秒, "
f"工具执行: {tool_time:.2f}秒 ({(tool_time/total_time)*100:.1f}%), "
f"LLM解释: {llm_time:.2f}秒 ({(llm_time/total_time)*100:.1f}%)"
)
return response
except Exception as e:
error_time = time.time() - start_time
self.logger.error(f"工具执行失败,耗时: {error_time:.2f}秒,错误: {str(e)}", exc_info=True)
raise