Commit 376be994 by tinywell

表格生成两种方式:大模型根据问题生成;工具基于原数据生成,大模型只回答问题;

parent 886f28bf
......@@ -79,7 +79,7 @@ class AgentManager:
verbose=True
)
self.agent = new_rate_agent(self.llm, verbose=True, tool_base_url=tool_base_url)
self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url)
self.rate_agent = RateAgentV3(self.llm, tool_base_url=tool_base_url,version="v0")
self.router_llm = new_router_llm(self.llm)
self.re_rewriter_llm = new_re_rewriter_llm(self.llm)
......
......@@ -170,7 +170,7 @@ def new_rate_agent(llm, verbose: bool = False,**args):
class RateAgentV3:
def __init__(self, llm, tool_base_url: str):
def __init__(self, llm, tool_base_url: str,version: str = "v1"):
tools = [
RegionRateTool(base_url=tool_base_url),
RankingRateTool(base_url=tool_base_url),
......@@ -180,12 +180,17 @@ class RateAgentV3:
tools_dict = {}
for t in tools:
tools_dict[t.name] = t
self.runner = ToolRunner(llm, tools_dict)
self.runner = ToolRunner(llm, tools_dict,version)
self.version = version
def run(self, input: str):
picker_result = self.picker.pick(input)
res = self.runner.run(input, picker_result["tool"], picker_result["params"])
output = f"相关数据如下:\n{res['table']}\n\n{res['output']}"
if self.version != "v1":
output = res['output']
else: # v1 版本, 使用表格单独展示数据
output = f"相关数据如下:\n{res['table']}\n\n{res['output']}"
return {
"input": input,
"output": output
......
......@@ -62,7 +62,9 @@ class ToolPicker:
self.logger.error(f"Error picking tool: {str(e)}")
raise
RUNNER_SYSTEM_PROMPT = """
RUNNER_SYSTEM_PROMPT_V1 = """
你是一个擅长根据工具执行结果回答用户问题的助手。
工作流程:
......@@ -74,6 +76,21 @@ RUNNER_SYSTEM_PROMPT = """
- 不要提到工具调用
"""
# "这一版的问答效果相对来说比较好" -- wukai
RUNNER_SYSTEM_PROMPT_V0 = """
你是一个擅长根据工具执行结果回答用户问题的助手。
工作流程:
1. 分析用户的问题
2. 根据用户的问题,解读工具执行结果,进行简要的分析说明
3. 返回用户问题的答案
请遵循以下规则:
- 工具执行结果中的数据必须使用 markdown 表格展示
- 确保数据的完整性, 不要遗漏数据
- 表格中的数据只能来源于工具执行结果
"""
RUNNER_HUMAN_PROMPT = """
用户的问题是:{input}
......@@ -82,13 +99,16 @@ RUNNER_HUMAN_PROMPT = """
class ToolRunner:
def __init__(self, llm, tools: Dict[str, BaseTool]):
def __init__(self, llm, tools: Dict[str, BaseTool],version: str = "v1"):
self.tools = tools
self.llm = llm
self.logger = get_logger("ToolRunner")
self.version = version
system_prompt = RUNNER_SYSTEM_PROMPT_V1
if version != "v1":
system_prompt = RUNNER_SYSTEM_PROMPT_V0
prompt = ChatPromptTemplate.from_messages([
SystemMessagePromptTemplate.from_template(RUNNER_SYSTEM_PROMPT),
SystemMessagePromptTemplate.from_template(system_prompt),
HumanMessagePromptTemplate.from_template(RUNNER_HUMAN_PROMPT)
])
self.chain = prompt | self.llm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment