tool_picker.py 5.87 KB
Newer Older
1 2
from typing import List, Dict
from datetime import datetime
tinywell committed
3
import time
4 5

from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
文靖昊 committed
6
from langchain.tools.render import  render_text_description_and_args
7 8
from langchain_core.output_parsers import JsonOutputParser as JSONOutputParser
from langchain_core.tools import BaseTool
9
from ..utils.logger import get_logger
10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41

PICKER_SYSTEM_PROMPT = """
你是一个智能工具选择助手,你需要根据用户的问题选择最合适的工具,并提取出工具所需的参数。

工作流程:
1. 分析用户问题,确定所需工具
2. 提取并整理工具所需的参数
3. 返回工具名称和参数

请遵循以下规则:
- 工具名称必须从工具列表中选择: {{tool_names}}
- 返回格式:
    ```json
    {{
        "tool": "工具名称", 
        "params": {{"参数1": "值1", "参数2": "值2"}}
    }}
    ```

工具列表详情:
{tools}

""" 

PICKER_HUMAN_PROMPT = """
用户的问题是:{input}
"""

class ToolPicker:
    def __init__(self, llm, tools: List):
        self.tools = tools
        self.llm = llm
42 43
        self.logger = get_logger("ToolPicker")
        
44 45 46 47 48 49 50 51 52 53 54 55
        date_now = datetime.now().strftime("%Y-%m-%d")
        picker_human = f"今天是{date_now}\n\n{PICKER_HUMAN_PROMPT}"
        prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(PICKER_SYSTEM_PROMPT),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
            HumanMessagePromptTemplate.from_template(picker_human)
        ])
        prompt = prompt.partial(tools=render_text_description_and_args(tools))

        self.chain = prompt | self.llm | JSONOutputParser()

    def pick(self, input: str):
56 57 58 59 60 61 62 63
        self.logger.info(f"Received input: {input}")
        try:
            result = self.chain.invoke({"input": input})
            self.logger.info(f"Selected tool: {result['tool']} with params: {result['params']}")
            return result
        except Exception as e:
            self.logger.error(f"Error picking tool: {str(e)}")
            raise
64

65 66 67


RUNNER_SYSTEM_PROMPT_V1 = """
68 69 70
你是一个擅长根据工具执行结果回答用户问题的助手。

工作流程:
tinywell committed
71 72
1. 对获取的数据进行简要的分析和解读
2. 结合数据对用户问题进行回答
73

tinywell committed
74 75 76
注意:
- 不要展示原数据,只展示分析和解读后的结果
- 不要提到工具调用
77 78
"""

79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
# "这一版的问答效果相对来说比较好" -- wukai
RUNNER_SYSTEM_PROMPT_V0 = """
你是一个擅长根据工具执行结果回答用户问题的助手。

工作流程:
1. 分析用户的问题
2. 根据用户的问题,解读工具执行结果,进行简要的分析说明
3. 返回用户问题的答案

请遵循以下规则:
- 工具执行结果中的数据必须使用 markdown 表格展示
- 确保数据的完整性, 不要遗漏数据
- 表格中的数据只能来源于工具执行结果
"""

94 95 96 97 98 99 100 101

RUNNER_HUMAN_PROMPT = """
用户的问题是:{input}
工具的执行结果是:{result}
"""


class ToolRunner:
102
    def __init__(self, llm, tools: Dict[str, BaseTool],version: str = "v1"):
103 104
        self.tools = tools
        self.llm = llm
105
        self.logger = get_logger("ToolRunner")
106 107 108 109
        self.version = version
        system_prompt = RUNNER_SYSTEM_PROMPT_V1
        if version != "v1":
            system_prompt = RUNNER_SYSTEM_PROMPT_V0
110
        prompt = ChatPromptTemplate.from_messages([
111
            SystemMessagePromptTemplate.from_template(system_prompt),
112 113 114 115 116
            HumanMessagePromptTemplate.from_template(RUNNER_HUMAN_PROMPT)
        ])
        self.chain = prompt | self.llm 

    def run(self, input: str, tool_name: str, params: Dict):
tinywell committed
117 118
        start_time = time.time()
        self.logger.info(f"开始执行工具 '{tool_name}',参数: {params}")
119
        
120
        if tool_name not in self.tools:
tinywell committed
121
            self.logger.error(f"工具 {tool_name} 未找到")
122
            raise ValueError(f"Tool {tool_name} not found")
123 124
            
        try:
tinywell committed
125 126
            # 工具执行
            tool_start = time.time()
127
            tool = self.tools[tool_name]
tinywell committed
128
            self.logger.info(f"调用工具 {tool_name}")
129
            result = tool.invoke(params)
tinywell committed
130 131 132
            tool_time = time.time() - tool_start
            self.logger.info(f"工具执行完成,耗时: {tool_time:.2f}秒")
            self.logger.debug(f"工具执行结果: {result}")
133
            
tinywell committed
134
            # 提取 markdown 表格
135 136 137 138
            table = ""
            if "markdown" in result:
                table = result["markdown"]
                del result["markdown"]
tinywell committed
139 140 141 142
            
            # LLM 解释结果
            llm_start = time.time()
            self.logger.info("开始 LLM 结果解释")
143
            llm_result = self.chain.invoke({"input": input, "result": result})
tinywell committed
144 145
            llm_time = time.time() - llm_start
            self.logger.info(f"LLM 解释完成,耗时: {llm_time:.2f}秒")
146
            
tinywell committed
147
            # 组装响应
148 149 150 151
            response = {
                "output": llm_result.content,
                "table": table
            }
tinywell committed
152 153 154 155 156 157 158 159 160 161 162 163
            
            total_time = time.time() - start_time
            self.logger.info(f"工具执行完成,总耗时: {total_time:.2f}秒,其中工具执行: {tool_time:.2f}秒,LLM解释: {llm_time:.2f}秒")
            
            # 记录性能指标
            self.logger.info(
                "性能指标 - "
                f"总耗时: {total_time:.2f}秒, "
                f"工具执行: {tool_time:.2f}秒 ({(tool_time/total_time)*100:.1f}%), "
                f"LLM解释: {llm_time:.2f}秒 ({(llm_time/total_time)*100:.1f}%)"
            )
            
164 165 166
            return response
            
        except Exception as e:
tinywell committed
167 168
            error_time = time.time() - start_time
            self.logger.error(f"工具执行失败,耗时: {error_time:.2f}秒,错误: {str(e)}", exc_info=True)
169
            raise