agent_rate.py 7.45 KB
Newer Older
1
from typing import Any, List, Sequence, Union
tinywell committed
2
from datetime import datetime
3 4 5 6 7 8 9 10 11 12 13 14 15
import langchain_core
from langchain.tools import BaseTool
from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessagePromptTemplate,MessagesPlaceholder,HumanMessagePromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.agents import AgentExecutor, Agent, create_tool_calling_agent,create_openai_functions_agent,create_structured_chat_agent
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents.format_scratchpad.openai_tools import (
    format_to_openai_tool_messages,
)
from langchain import hub


tinywell committed
16
from src.agent.tool_rate import RegionRateTool,RankingRateTool
17
from src.agent.tool_monitor import MonitorPointTool
18
from src.agent.tool_warn import WarningTool
19
from src.server.tool_picker import ToolPicker, ToolRunner
20

tinywell committed
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None, 
#     tools_renderer: ToolsRenderer = render_text_description_and_args,
#     verbose: bool = False,**args):
#     missing_vars = {"tools", "tool_names", "agent_scratchpad"}.difference(
#         prompt.input_variables + list(prompt.partial_variables)
#     )
#     if missing_vars:
#         raise ValueError(f"Prompt missing required variables: {missing_vars}")

#     prompt = prompt.partial(
#         tools=tools_renderer(list(tools)),
#         tool_names=", ".join([t.name for t in tools]),
#     )
#     if stop_sequence:
#         stop = ["\nObservation"] if stop_sequence is True else stop_sequence
#         llm_with_stop = llm.bind(stop=stop)
#     else:
#         llm_with_stop = llm

#     agent = (
#         RunnablePassthrough.assign(
#             agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
#         )
#         | prompt
#         | llm_with_stop
#     )
#     return agent
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67


class RateAgent:
    def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False,**args):
        # if not prompt:
        #     raise ValueError("PromptTemplate is required")

        prompt = hub.pull("hwchase17/openai-tools-agent")
        agent = create_tool_calling_agent(llm, tools, prompt)

        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose)    
        
    def exec(self, prompt_args: dict = {}, stream: bool = False):
        return self.agent_executor.invoke(input=prompt_args)

    def stream(self, prompt_args: dict = {}):
        for step in self.agent_executor.stream(prompt_args):
            yield step


68
# 适配 structured_chat_agent 的 prompt
tinywell committed
69
ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及监测设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。
70

tinywell committed
71 72 73 74 75 76 77 78 79 80 81 82 83
核心工作流程:
1. 仔细分析用户问题,准确识别所需的分析类型和必要参数
2. 一次性收集所有必要参数 - 如参数不完整,立即要求用户补充,不进行任何工具调用
3. 确认参数完整后,仅调用一次最合适的分析工具
4. 收到工具返回数据后,必须使用 markdown 表格格式完整展示所有数据,不遗漏任何字段
5. 在表格后提供简洁的数据解读

严格遵守的规则:
- 禁止重复调用工具获取相同维度的数据
- 工具返回的数据必须且只能用 markdown 表格展示,不使用其他格式
- 表格中的数据必须完整展示,不省略任何字段
- 时间格式:YYYY-MM-DD
- 地区名称必须包含行政级别(如:福建省、厦门市)
84 85
- 百分比数据保留两位小数

tinywell committed
86 87 88 89
异常处理:
- 参数缺失:立即提示用户补充具体参数,不进行工具调用
- 数据异常:仅使用工具返回的实际数据,不做假设或补充

90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
您可以使用以下工具:

{tools}

使用 JSON 对象指定工具,提供一个 action 键(工具名称)和一个 action_input 键(工具输入) 。

有效的 "action" 值: "Final Answer" 或 {tool_names}

每个 $JSON_BLOB 只提供一个操作,如下所示:
    ``` 
    {{
    "action": $TOOL_NAME,
    "action_input": $INPUT,
    }}
    ```

按照以下格式:

Question: 输入要回答的问题
Thought: 考虑前后步骤
Action:
    ```
    $JSON_BLOB
    ```
Observation: 操作结果
...(重复 Thought/Action/Observation N 次)
Thought: 我知道如何回复
Action:
    ```
    {{
    "action": "Final Answer",
    "action_input": "最终回复给人类",
    }}
    ```
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具。如果你知道答案,请直接回复。
你的回复格式为 Action:```$JSON_BLOB```然后 Observation。
"""

tinywell committed
128

129
PROMPT_AGENT_HUMAN = """{input}\n\n {agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复。工具返回的数据必须使用表格展示,包含在最终输出中,并且要保证数据的完整性)"""
130 131 132 133
PROMPT_AGENT_SYS_VARS = [ "tool_names", "tools"]

class RateAgentV2:
    def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False,**args):
tinywell committed
134 135
        date_now = datetime.now().strftime("%Y-%m-%d")
        prompt_human = f"{PROMPT_AGENT_HUMAN}\n\n今天是{date_now}"
136
        prompt = ChatPromptTemplate.from_messages([
137 138
            SystemMessagePromptTemplate.from_template(ONLINE_RATE_SYSTEM_PROMPT),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
tinywell committed
139
            HumanMessagePromptTemplate.from_template(prompt_human)
140 141 142 143
        ])

        agent = create_structured_chat_agent(llm, tools, prompt)

文靖昊 committed
144
        self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose, return_intermediate_steps=True,handle_parsing_errors=True)
145 146 147 148 149 150 151 152 153 154
        
    def exec(self, prompt_args: dict = {}, stream: bool = False):
        return self.agent_executor.invoke(input=prompt_args)

    def stream(self, prompt_args: dict = {}):
        for step in self.agent_executor.stream(prompt_args):
            yield step

def new_rate_agent(llm, verbose: bool = False,**args):

tinywell committed
155 156 157 158 159
    if args['tool_base_url']:
        tool_base_url = args['tool_base_url']
    else:
        tool_base_url = const_base_url

160
    tools = [
tinywell committed
161 162 163
        RegionRateTool(base_url=tool_base_url),
        RankingRateTool(base_url=tool_base_url),
        MonitorPointTool(base_url=tool_base_url)
164 165 166 167 168 169 170 171 172
    ]

    # 使用 LangChain 的工具调用代理
    agent = RateAgentV2(llm=llm, tools=tools, verbose=verbose, **args)
    return agent



    
173
class RateAgentV3:
174
    def __init__(self, llm, tool_base_url: str,version: str = "v1"):
175 176 177
        tools = [
            RegionRateTool(base_url=tool_base_url),
            RankingRateTool(base_url=tool_base_url),
178 179
            MonitorPointTool(base_url=tool_base_url),
            WarningTool(base_url=tool_base_url)
180 181 182 183 184
        ]
        self.picker = ToolPicker(llm, tools)
        tools_dict = {}
        for t in tools:
            tools_dict[t.name] = t
185 186
        self.runner = ToolRunner(llm, tools_dict,version)
        self.version = version
187 188 189

    def run(self, input: str):
        picker_result = self.picker.pick(input)
190
        res = self.runner.run(input, picker_result["tool"], picker_result["params"])
191 192 193 194 195
        if self.version != "v1": 
            output = res['output']
        else: # v1 版本, 使用表格单独展示数据
            output = f"相关数据如下:\n{res['table']}\n\n{res['output']}"
            
196 197
        return {
            "input": input,
198
            "output": output
199
        }
200