Commit f9a79295 by tinywell

chore: Update agent tool descriptions to support chart tool

parent 8751d2b2
from typing import Type

from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool


class ChartArgs(BaseModel):
    name: str = Field(..., description="Chart title, displayed above the chart")
    chart_type: str = Field(..., description="Chart type, e.g. line, bar, scatter, pie")
    x: list = Field(..., description="X-axis data, as a list")
    y: list = Field(..., description="Y-axis data, as a list")
    x_label: str = Field(..., description="X-axis label")
    y_label: str = Field(..., description="Y-axis label")


class Chart(BaseTool):
    name = "chart"
    description = "Assemble the intermediate data used to render a chart"
    args_schema: Type[BaseModel] = ChartArgs

    def _run(
        self, name: str, chart_type: str, x: list, y: list, x_label: str, y_label: str
    ) -> dict:
        """Use the tool."""
        # chart_type must be passed through as well: chart_image() dispatches on it.
        result = {
            "name": name,
            "chart_type": chart_type,
            "x": x,
            "y": y,
            "x_label": x_label,
            "y_label": y_label,
        }
        return result
# Render the chart
def chart_image(chart_data):
    """
    Render a chart.

    Args:
        chart_data: dict of chart data
            {
                "name": str, chart title
                "chart_type": str, chart type, e.g. line, bar, scatter, pie
                "x": list, x-axis data
                "y": list, y-axis data
                "x_label": str, x-axis label
                "y_label": str, y-axis label
            }

    Returns:
        PIL Image of the chart
    """
    import matplotlib.pyplot as plt
    from io import BytesIO
    from PIL import Image

    plt.figure(figsize=(10, 6))
    match chart_data["chart_type"]:
        case "line":
            plt.plot(chart_data["x"], chart_data["y"])
        case "bar":
            plt.bar(chart_data["x"], chart_data["y"])
        case "scatter":
            plt.scatter(chart_data["x"], chart_data["y"])
        case "pie":
            plt.pie(chart_data["y"], labels=chart_data["x"], autopct="%1.1f%%")
        case _:
            raise ValueError("Invalid chart type")
    plt.xlabel(chart_data["x_label"])
    plt.ylabel(chart_data["y_label"])
    plt.title(chart_data["name"])
    # plt.show()

    buf = BytesIO()
    plt.savefig(buf, format="png")
    plt.close()   # release the figure
    buf.seek(0)   # rewind so PIL reads from the start of the buffer
    image = Image.open(buf)
    # image.show()
    return image
\ No newline at end of file
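For reference, a minimal end-to-end sketch of the new tool (values are illustrative; it assumes the fixed `_run` signature above, which includes `chart_type`):

```python
from src.agent.tool_chart import Chart, chart_image

# Call the tool directly to assemble the intermediate chart data...
data = Chart()._run(
    name="squares",
    chart_type="bar",
    x=[1, 2, 3, 4, 5],
    y=[1, 4, 9, 16, 25],
    x_label="n",
    y_label="n^2",
)
# ...then hand that dict to the renderer, which returns a PIL Image.
image = chart_image(data)
image.save("squares.png")
```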
@@ -132,7 +132,7 @@ class AdministrativeDivisionArgs(BaseModel):
class AdministrativeDivision(BaseTool):
    name = "administrative_division"
-    description = "Complete the administrative-division information for the input, identifying the specific province, city, and county. For example, given a county, fill in its province and city; given a city, fill in its province and all counties and districts under it"
+    description = "Complete the administrative-division information for any region mentioned in the user's question, identifying the specific province, city, and county. For example, given a county, fill in its province and city; given a city, fill in its province and all counties and districts under it"
    args_schema: Type[BaseModel] = AdministrativeDivisionArgs
    def _run(self, input_text: str) -> str:
...
@@ -37,7 +37,49 @@ Action:
}}
```
-Begin! Always reply with a valid JSON object for a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation
+Begin! Always reply with a valid JSON object for a single action. Use tools if necessary. If you know the answer, respond directly. Format is Action:```$JSON_BLOB```then Observation
"""
PROMPT_AGENT_CAHRT_SYS = """Help the human as much as you can and answer questions accurately. You have access to the following tools:
{tools}
Specify a tool with a JSON object, providing an action key (tool name), an action_input key (tool input), and an action_cache key (to store intermediate tool results when necessary).
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{
  "action": $TOOL_NAME,
  "action_input": $INPUT,
  "action_cache": $CACHE
}}
```
Follow this format:
Question: the input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: the action result
... (repeat Thought/Action/Observation N times)
Thought: I know how to respond
Action:
```
{{
  "action": "Final Answer",
  "action_input": "final response to the human",
  "action_cache": "intermediate result cache"
}}
```
Begin! Always reply with a valid JSON object for a single action. Use tools if necessary. If you know the answer, respond directly. If a chart needs to be generated, use the chart tool {chart_tool} and store the result in $CACHE.
Format is Action:```$JSON_BLOB```then Observation
""" """
PROMPT_AGENT_HUMAN = """{input}\n\n{agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复)""" PROMPT_AGENT_HUMAN = """{input}\n\n{agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复)"""
@@ -128,6 +170,8 @@ A: Laf is a cloud function development platform.
Search terms: """
# Combine the conversation history to generate new questions that guide the user to continue the dialogue
PROMPT_QA_EXTEND_QUESTION = """
As a Q&A assistant, your task is to combine the conversation history and generate three new questions that guide the user to continue the conversation. The questions must be relevant to the conversation, point to a clear and specific subject, and be in "the same language as the original question". For example:
...
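To make the protocol concrete: under PROMPT_AGENT_CAHRT_SYS, a chart invocation is a single JSON blob shaped like the dict below (values are illustrative; the ChartAgentOutputParser added in agent.py consumes exactly this shape):

```python
# Illustrative action blob for the chart tool; field names follow ChartArgs.
chart_action = {
    "action": "chart",        # a tool name from {tool_names}
    "action_input": {         # the tool's input, matching ChartArgs
        "name": "squares",
        "chart_type": "bar",
        "x": [1, 2, 3],
        "y": [1, 4, 9],
        "x_label": "n",
        "y_label": "n^2",
    },
    "action_cache": "",       # intermediate results, when needed
}
```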
-from typing import Any, List
+from typing import Any, List, Sequence, Union
from langchain_core.prompts import PromptTemplate
from langchain.agents import AgentExecutor, create_tool_calling_agent, create_structured_chat_agent
from langchain.tools import BaseTool
from langgraph.prebuilt import create_react_agent
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.output_parsers.json import parse_json_markdown
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langchain.tools.render import ToolsRenderer, render_text_description_and_args
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.agent import AgentOutputParser
class ChartAgentOutputParser(AgentOutputParser):
    """Parses tool invocations and final answers in JSON format.

    Expects output to be in one of two formats.

    If the output signals that an action should be taken, it should be
    in the format below. This will result in an AgentAction being returned.

    ```
    {
        "action": "search",
        "action_input": "2+2",
        "action_cache": ""
    }
    ```

    If the output signals that a final answer should be given, it should be
    in the format below. This will result in an AgentFinish being returned.

    ```
    {
        "action": "Final Answer",
        "action_input": "4",
        "action_cache": ""
    }
    ```
    """
    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        try:
            response = parse_json_markdown(text)
            if isinstance(response, list):
                # gpt turbo frequently ignores the directive to emit a single action
                # logger.warning("Got multiple action responses: %s", response)
                response = response[0]
            if response["action"] == "Final Answer":
                if "action_cache" in response:
                    return AgentFinish(
                        {"output": response["action_input"], "cache": response["action_cache"]}, text
                    )
                else:
                    return AgentFinish({"output": response["action_input"]}, text)
            else:
                return AgentAction(
                    response["action"], response.get("action_input", {}), text
                )
        except Exception as e:
            raise OutputParserException(f"Could not parse LLM output: {text}") from e

    @property
    def _type(self) -> str:
        return "chart-agent"
def create_chart_agent(
    llm: BaseLanguageModel,
    tools: Sequence[BaseTool],
    prompt: ChatPromptTemplate,
    chart_tool: str,
    tools_renderer: ToolsRenderer = render_text_description_and_args,
    *,
    stop_sequence: Union[bool, List[str]] = True,
) -> Runnable:
    """Create an agent aimed at supporting chart tools with multiple inputs."""
    missing_vars = {"tools", "tool_names", "agent_scratchpad", "chart_tool"}.difference(
        prompt.input_variables + list(prompt.partial_variables)
    )
    if missing_vars:
        raise ValueError(f"Prompt missing required variables: {missing_vars}")

    prompt = prompt.partial(
        tools=tools_renderer(list(tools)),
        tool_names=", ".join([t.name for t in tools]),
        chart_tool=chart_tool,
    )
    if stop_sequence:
        stop = ["\nObservation"] if stop_sequence is True else stop_sequence
        llm_with_stop = llm.bind(stop=stop)
    else:
        llm_with_stop = llm

    agent = (
        RunnablePassthrough.assign(
            agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),
        )
        | prompt
        | llm_with_stop
        | ChartAgentOutputParser()
    )
    return agent
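A minimal wiring sketch, assuming llm, tools, and a prompt that declares the tools/tool_names/agent_scratchpad/chart_tool variables (as in the test file below):

```python
from langchain.agents import AgentExecutor

agent = create_chart_agent(llm, tools, prompt, chart_tool="chart")
executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
result = executor.invoke({"input": "Plot the squares of 1 through 5 as a bar chart"})
# result["output"] holds the final answer; when the model filled action_cache,
# the parser also surfaces it as result["cache"].
```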
class Agent:
    def __init__(self, llm, tools: List[BaseTool], prompt: PromptTemplate = None, verbose: bool = False):
        self.llm = llm
@@ -15,7 +118,7 @@ class Agent:
            agent = create_react_agent(llm, tools, debug=verbose)
            self.agent_executor = agent
        else:
-            agent = create_structured_chat_agent(llm, tools, prompt)
+            agent = create_chart_agent(llm, tools, prompt, chart_tool="chart")  # chart_tool is required; "chart" matches the new tool's name
            self.agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=verbose)
    def exec(self, prompt_args: dict = {}, stream: bool = False):
...
@@ -15,8 +15,8 @@ from langchain.agents import AgentExecutor, create_tool_calling_agent, create_str
from pydantic import BaseModel, Field
-from src.server.agent import Agent
+from src.server.agent import Agent, create_chart_agent
-from src.config.prompts import PROMPT_AGENT_SYS, PROMPT_AGENT_HUMAN
+from src.config.prompts import PROMPT_AGENT_SYS, PROMPT_AGENT_HUMAN, PROMPT_AGENT_CAHRT_SYS
from src.agent.tool_divisions import AdministrativeDivision, CountryInfo
@@ -51,10 +51,10 @@ llm = ChatOpenAI(
)
# prompt = hub.pull("hwchase17/openai-functions-agent")
-input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools']
+input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools', 'chart_tool']
input_types={'chat_history': List[Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}
messages=[
-    SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['tool_names', 'tools'], template=PROMPT_AGENT_SYS)),
+    SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['tool_names', 'tools', 'chart_tool'], template=PROMPT_AGENT_CAHRT_SYS)),
    MessagesPlaceholder(variable_name='chat_history', optional=True),
    HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['agent_scratchpad', 'input'], template=PROMPT_AGENT_HUMAN))
]
@@ -68,6 +68,7 @@ prompt = ChatPromptTemplate(
def test_add():
    tools = [Calc()]
    agent = Agent(llm=llm, tools=tools, prompt=prompt, verbose=True)
+    agent = create_chart_agent(llm, tools, prompt, chart_tool="chart")
    res = agent.exec(prompt_args={"input": "what is 1 + 1?"})
@@ -88,5 +89,20 @@ def test_agent_division():
    res = agent.exec(prompt_args={"input": "I want to know which gets more rainfall, Xihe county or Wen county in Longnan city"})
    print(res)
def test_chart_tool():
    from src.agent.tool_chart import chart_image
    x = [1, 2, 3, 4, 5]
    y = [1, 4, 9, 16, 25]
    chart_data = {
        "name": "test",
        "chart_type": "bar",
        "x": x,
        "y": y,
        "x_label": "x axis",
        "y_label": "y axis",
    }
    chart_image(chart_data)
if __name__ == "__main__": if __name__ == "__main__":
test_agent_division() # test_agent_division()
\ No newline at end of file test_chart_tool()
\ No newline at end of file