Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
f9a79295
Commit
f9a79295
authored
Jul 15, 2024
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
chore: Update agent tool descriptions to support chart tool
parent
8751d2b2
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
254 additions
and
10 deletions
+254
-10
tool_chart.py
src/agent/tool_chart.py
+81
-0
tool_divisions.py
src/agent/tool_divisions.py
+1
-1
prompts.py
src/config/prompts.py
+45
-1
agent.py
src/server/agent.py
+105
-2
agent_test.py
test/agent_test.py
+22
-6
No files found.
src/agent/tool_chart.py
0 → 100644
View file @
f9a79295
from
typing
import
Type
from
pydantic
import
BaseModel
,
Field
from
langchain_core.tools
import
BaseTool
class
ChartArgs
(
BaseModel
):
name
:
str
=
Field
(
...
,
description
=
"图表名称,用于显示在图表上方"
)
chart_type
:
str
=
Field
(
...
,
description
=
"图表类型,如 line, bar, scatter, pie 等"
)
x
:
list
=
Field
(
...
,
description
=
"x 轴数据,列表形式"
)
y
:
list
=
Field
(
...
,
description
=
"y 轴数据,列表形式"
)
x_label
:
str
=
Field
(
...
,
description
=
"x 轴标签"
)
y_label
:
str
=
Field
(
...
,
description
=
"y 轴标签"
)
class
Chart
(
BaseTool
):
name
=
"chart"
description
=
"组装生成图表的中间数据"
args_schema
:
Type
[
BaseModel
]
=
ChartArgs
def
_run
(
self
,
name
:
str
,
x
:
list
,
y
:
list
,
x_label
:
str
,
y_label
:
str
)
->
str
:
"""Use the tool."""
result
=
{
"name"
:
name
,
"x"
:
x
,
"y"
:
y
,
"x_label"
:
x_label
,
"y_label"
:
y_label
}
return
result
# 生成图表
def
chart_image
(
chart_data
):
"""
生成图表
Args:
chart_data: dict 图表数据
{
"name": str, 图表名称
"chart_type": str, 图表类型,如 line, bar, scatter, pie 等
"x": list, x 轴数据,列表形式
"y": list, y 轴数据,列表形式
"x_label": str, x 轴标签
"y_label": str, y 轴标签
}
Returns:
PIL Image 图表图片
"""
import
matplotlib.pyplot
as
plt
plt
.
figure
(
figsize
=
(
10
,
6
))
match
chart_data
[
"chart_type"
]:
case
"line"
:
plt
.
plot
(
chart_data
[
"x"
],
chart_data
[
"y"
])
case
"bar"
:
plt
.
bar
(
chart_data
[
"x"
],
chart_data
[
"y"
])
case
"scatter"
:
plt
.
scatter
(
chart_data
[
"x"
],
chart_data
[
"y"
])
case
"pie"
:
plt
.
pie
(
chart_data
[
"y"
],
labels
=
chart_data
[
"x"
],
autopct
=
"
%1.1
f
%%
"
)
case
_
:
raise
ValueError
(
"Invalid chart type"
)
plt
.
xlabel
(
chart_data
[
"x_label"
])
plt
.
ylabel
(
chart_data
[
"y_label"
])
plt
.
title
(
chart_data
[
"name"
])
# plt.show()
from
io
import
BytesIO
buf
=
BytesIO
()
plt
.
savefig
(
buf
,
format
=
"png"
)
from
PIL
import
Image
image
=
Image
.
open
(
buf
)
# image.show()
return
image
\ No newline at end of file
src/agent/tool_divisions.py
View file @
f9a79295
...
...
@@ -132,7 +132,7 @@ class AdministrativeDivisionArgs(BaseModel):
class
AdministrativeDivision
(
BaseTool
):
name
=
"administrative_division"
description
=
"根据
输入补全
行政区划信息,明确具体的省、市、县信息。比如输入县,补全所属省市,输入市则补全省级以及下辖所有县区"
description
=
"根据
用户提问中涉及到的地区信息补全其
行政区划信息,明确具体的省、市、县信息。比如输入县,补全所属省市,输入市则补全省级以及下辖所有县区"
args_schema
:
Type
[
BaseModel
]
=
AdministrativeDivisionArgs
def
_run
(
self
,
input_text
:
str
)
->
str
:
...
...
src/config/prompts.py
View file @
f9a79295
...
...
@@ -37,7 +37,49 @@ Action:
}}
```
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具。如果合适,请直接回复。格式为 Action:```$JSON_BLOB```然后 Observation
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具。如果你知道答案,请直接回复。格式为 Action:```$JSON_BLOB```然后 Observation
"""
PROMPT_AGENT_CAHRT_SYS
=
"""请尽量帮助人类并准确回答问题。您可以使用以下工具:
{tools}
使用 JSON 对象指定工具,提供一个 action 键(工具名称)和一个 action_input 键(工具输入), 以及 action_cache 键(必要时存储工具中间结果) 。
有效的 "action" 值: "Final Answer" 或 {tool_names}
每个 $JSON_BLOB 只提供一个操作,如下所示:
```
{{
"action": $TOOL_NAME,
"action_input": $INPUT,
"action_cache": $CACHE
}}
```
按照以下格式:
Question: 输入要回答的问题
Thought: 考虑前后步骤
Action:
```
$JSON_BLOB
```
Observation: 操作结果
...(重复 Thought/Action/Observation N 次)
Thought: 我知道如何回复
Action:
```
{{
"action": "Final Answer",
"action_input": "最终回复给人类",
"action_cache": "中间结果缓存"
}}
```
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具。如果你知道答案,请直接回复。如果有生成图表的需求,请使用图表生成工具 {chart_tool},并将结果存储到 $CACHE 中。
格式为 Action:```$JSON_BLOB```然后 Observation
"""
PROMPT_AGENT_HUMAN
=
"""{input}
\n\n
{agent_scratchpad}
\n
(请注意,无论如何都要以 JSON 对象回复)"""
...
...
@@ -128,6 +170,8 @@ A: Laf 是一个云函数开发平台。
检索词: """
# 结合历史问答对话,生成新的提问,引导用户继续对话
PROMPT_QA_EXTEND_QUESTION
=
"""
作为一个问答助手,你的任务是结合历史记录,生成三个新的问题,引导用户继续对话。生成的问题要求与对话内容相关且指向对象清晰明确,并与“原问题语言相同”。例如:
...
...
src/server/agent.py
View file @
f9a79295
from
typing
import
Any
,
List
from
typing
import
Any
,
List
,
Sequence
,
Union
from
langchain_core.prompts
import
PromptTemplate
from
langchain.agents
import
AgentExecutor
,
create_tool_calling_agent
,
create_structured_chat_agent
from
langchain.tools
import
BaseTool
from
langgraph.prebuilt
import
create_react_agent
from
langchain_core.language_models
import
BaseLanguageModel
from
langchain_core.prompts.chat
import
ChatPromptTemplate
from
langchain_core.runnables
import
Runnable
,
RunnablePassthrough
from
langchain_core.output_parsers.json
import
parse_json_markdown
from
langchain_core.agents
import
AgentAction
,
AgentFinish
from
langchain_core.exceptions
import
OutputParserException
from
langchain.tools.render
import
ToolsRenderer
,
render_text_description_and_args
from
langchain.agents.format_scratchpad
import
format_log_to_str
from
langchain.agents.agent
import
AgentOutputParser
class
ChartAgentOutputParser
(
AgentOutputParser
):
"""Parses tool invocations and final answers in JSON format.
Expects output to be in one of two formats.
If the output signals that an action should be taken,
should be in the below format. This will result in an AgentAction
being returned.
```
{
"action": "search",
"action_input": "2+2",
"action_cache": ""
}
```
If the output signals that a final answer should be given,
should be in the below format. This will result in an AgentFinish
being returned.
```
{
"action": "Final Answer",
"action_input": "4",
"action_cache": ""
}
```
"""
def
parse
(
self
,
text
:
str
)
->
Union
[
AgentAction
,
AgentFinish
]:
try
:
response
=
parse_json_markdown
(
text
)
if
isinstance
(
response
,
list
):
# gpt turbo frequently ignores the directive to emit a single action
# logger.warning("Got multiple action responses: %s", response)
response
=
response
[
0
]
if
response
[
"action"
]
==
"Final Answer"
:
if
"action_cache"
in
response
:
return
AgentFinish
({
"output"
:
response
[
"action_input"
],
"cache"
:
response
[
"action_cache"
]},
text
)
else
:
return
AgentFinish
({
"output"
:
response
[
"action_input"
]},
text
)
else
:
return
AgentAction
(
response
[
"action"
],
response
.
get
(
"action_input"
,
{}),
text
)
except
Exception
as
e
:
raise
OutputParserException
(
f
"Could not parse LLM output: {text}"
)
from
e
@property
def
_type
(
self
)
->
str
:
return
"chart-agent"
def
create_chart_agent
(
llm
:
BaseLanguageModel
,
tools
:
Sequence
[
BaseTool
],
prompt
:
ChatPromptTemplate
,
chart_tool
:
str
,
tools_renderer
:
ToolsRenderer
=
render_text_description_and_args
,
*
,
stop_sequence
:
Union
[
bool
,
List
[
str
]]
=
True
,
)
->
Runnable
:
"""Create an agent aimed at supporting chart tools with multiple inputs.
"""
missing_vars
=
{
"tools"
,
"tool_names"
,
"agent_scratchpad"
,
"chart_tool"
}
.
difference
(
prompt
.
input_variables
+
list
(
prompt
.
partial_variables
)
)
if
missing_vars
:
raise
ValueError
(
f
"Prompt missing required variables: {missing_vars}"
)
prompt
=
prompt
.
partial
(
tools
=
tools_renderer
(
list
(
tools
)),
tool_names
=
", "
.
join
([
t
.
name
for
t
in
tools
]),
chart_tool
=
chart_tool
,
)
if
stop_sequence
:
stop
=
[
"
\n
Observation"
]
if
stop_sequence
is
True
else
stop_sequence
llm_with_stop
=
llm
.
bind
(
stop
=
stop
)
else
:
llm_with_stop
=
llm
agent
=
(
RunnablePassthrough
.
assign
(
agent_scratchpad
=
lambda
x
:
format_log_to_str
(
x
[
"intermediate_steps"
]),
)
|
prompt
|
llm_with_stop
|
ChartAgentOutputParser
()
)
return
agent
class
Agent
:
def
__init__
(
self
,
llm
,
tools
:
List
[
BaseTool
],
prompt
:
PromptTemplate
=
None
,
verbose
:
bool
=
False
):
self
.
llm
=
llm
...
...
@@ -15,7 +118,7 @@ class Agent:
agent
=
create_react_agent
(
llm
,
tools
,
debug
=
verbose
)
self
.
agent_executor
=
agent
else
:
agent
=
create_
structured_cha
t_agent
(
llm
,
tools
,
prompt
)
agent
=
create_
char
t_agent
(
llm
,
tools
,
prompt
)
self
.
agent_executor
=
AgentExecutor
(
agent
=
agent
,
tools
=
tools
,
verbose
=
verbose
)
def
exec
(
self
,
prompt_args
:
dict
=
{},
stream
:
bool
=
False
):
...
...
test/agent_test.py
View file @
f9a79295
...
...
@@ -15,8 +15,8 @@ from langchain.agents import AgentExecutor, create_tool_calling_agent,create_str
from
pydantic
import
BaseModel
,
Field
from
src.server.agent
import
Agent
from
src.config.prompts
import
PROMPT_AGENT_SYS
,
PROMPT_AGENT_HUMAN
from
src.server.agent
import
Agent
,
create_chart_agent
from
src.config.prompts
import
PROMPT_AGENT_SYS
,
PROMPT_AGENT_HUMAN
,
PROMPT_AGENT_CAHRT_SYS
from
src.agent.tool_divisions
import
AdministrativeDivision
,
CountryInfo
...
...
@@ -51,10 +51,10 @@ llm = ChatOpenAI(
)
# prompt = hub.pull("hwchase17/openai-functions-agent")
input_variables
=
[
'agent_scratchpad'
,
'input'
,
'tool_names'
,
'tools'
]
input_variables
=
[
'agent_scratchpad'
,
'input'
,
'tool_names'
,
'tools'
,
"chart_tool"
]
input_types
=
{
'chat_history'
:
List
[
Union
[
langchain_core
.
messages
.
ai
.
AIMessage
,
langchain_core
.
messages
.
human
.
HumanMessage
,
langchain_core
.
messages
.
chat
.
ChatMessage
,
langchain_core
.
messages
.
system
.
SystemMessage
,
langchain_core
.
messages
.
function
.
FunctionMessage
,
langchain_core
.
messages
.
tool
.
ToolMessage
]]}
messages
=
[
SystemMessagePromptTemplate
(
prompt
=
PromptTemplate
(
input_variables
=
[
'tool_names'
,
'tools'
],
template
=
PROMPT_AGEN
T_SYS
)),
SystemMessagePromptTemplate
(
prompt
=
PromptTemplate
(
input_variables
=
[
'tool_names'
,
'tools'
,
"chart_tool"
],
template
=
PROMPT_AGENT_CAHR
T_SYS
)),
MessagesPlaceholder
(
variable_name
=
'chat_history'
,
optional
=
True
),
HumanMessagePromptTemplate
(
prompt
=
PromptTemplate
(
input_variables
=
[
'agent_scratchpad'
,
'input'
],
template
=
PROMPT_AGENT_HUMAN
))
]
...
...
@@ -68,6 +68,7 @@ prompt = ChatPromptTemplate(
def
test_add
():
tools
=
[
Calc
()]
agent
=
Agent
(
llm
=
llm
,
tools
=
tools
,
prompt
=
prompt
,
verbose
=
True
)
agent
=
create_chart_agent
(
llm
,
tools
,
prompt
,
chart_tool
=
"chart"
)
res
=
agent
.
exec
(
prompt_args
=
{
"input"
:
"what is 1 + 1?"
})
...
...
@@ -88,5 +89,20 @@ def test_agent_division():
res
=
agent
.
exec
(
prompt_args
=
{
"input"
:
"我想知道陇南市西和县和文县的降雨量谁的多"
})
print
(
res
)
def
test_chart_tool
():
from
src.agent.tool_chart
import
chart_image
x
=
[
1
,
2
,
3
,
4
,
5
]
y
=
[
1
,
4
,
9
,
16
,
25
]
chart_data
=
{
"name"
:
"test"
,
"chart_type"
:
"bar"
,
"x"
:
x
,
"y"
:
y
,
"x_label"
:
"x axis"
,
"y_label"
:
"y axis"
}
chart_image
(
chart_data
)
if
__name__
==
"__main__"
:
test_agent_division
()
\ No newline at end of file
# test_agent_division()
test_chart_tool
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment