Commit f51470c9 by tinywell

rrf 测试

parent f9a79295
......@@ -40,7 +40,7 @@ Action:
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具。如果你知道答案,请直接回复。格式为 Action:```$JSON_BLOB```然后 Observation
"""
PROMPT_AGENT_CAHRT_SYS = """请尽量帮助人类并准确回答问题。您可以使用以下工具:
PROMPT_AGENT_CHART_SYS = """请尽量帮助人类并准确回答问题。您可以使用以下工具:
{tools}
......@@ -73,11 +73,11 @@ Action:
{{
"action": "Final Answer",
"action_input": "最终回复给人类",
"action_cache": "中间结果缓存"
"action_cache": "图表工具结果缓存"
}}
```
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具。如果你知道答案,请直接回复。如果有生成图表的需求,请使用图表生成工具 {chart_tool},并将结果存储到 $CACHE 中。
格式为 Action:```$JSON_BLOB```然后 Observation
开始!始终以有效的单个操作的 JSON 对象回复。如有必要,请使用工具,如果你知道答案,请直接回复。如果用户有生成图表的需求,请使用图表生成工具 {chart_tool},并将结果记录为 $CACHE 中。
你的回复格式为 Action:```$JSON_BLOB```然后 Observation
"""
......
......@@ -109,7 +109,7 @@ def create_chart_agent(
return agent
class Agent:
def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False):
def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False,**args):
self.llm = llm
self.tools = tools
self.prompt = prompt
......@@ -118,7 +118,7 @@ class Agent:
agent = create_react_agent(llm, tools,debug=verbose)
self.agent_executor = agent
else:
agent = create_chart_agent(llm, tools, prompt)
agent = create_chart_agent(llm, tools, prompt, **args)
self.agent_executor = AgentExecutor(agent=agent,tools=tools,verbose=verbose)
def exec(self, prompt_args: dict = {}, stream: bool = False):
......
......@@ -16,7 +16,7 @@ from langchain.agents import AgentExecutor, create_tool_calling_agent,create_str
from pydantic import BaseModel, Field
from src.server.agent import Agent, create_chart_agent
from src.config.prompts import PROMPT_AGENT_SYS, PROMPT_AGENT_HUMAN, PROMPT_AGENT_CAHRT_SYS
from src.config.prompts import PROMPT_AGENT_SYS, PROMPT_AGENT_HUMAN, PROMPT_AGENT_CHART_SYS
from src.agent.tool_divisions import AdministrativeDivision, CountryInfo
......@@ -54,7 +54,7 @@ llm = ChatOpenAI(
input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools', "chart_tool"]
input_types={'chat_history': List[Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}
messages=[
SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['tool_names', 'tools', "chart_tool"], template=PROMPT_AGENT_CAHRT_SYS)),
SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['tool_names', 'tools', "chart_tool"], template=PROMPT_AGENT_CHART_SYS)),
MessagesPlaceholder(variable_name='chat_history', optional=True),
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['agent_scratchpad', 'input'], template=PROMPT_AGENT_HUMAN))
]
......@@ -101,8 +101,18 @@ def test_chart_tool():
"x_label": "x axis",
"y_label": "y axis"
}
chart_image(chart_data)
img = chart_image(chart_data)
img.show()
def test_agent_chart():
from src.agent.tool_chart import Chart
tools = [Chart()]
agent = Agent(llm=llm, tools=tools, prompt=prompt, verbose=True, chart_tool="chart")
res = agent.exec(prompt_args={"input": "请告诉我海拔前十的高山有哪些"})
print(res)
if __name__ == "__main__":
# test_agent_division()
test_chart_tool()
\ No newline at end of file
# test_chart_tool()
test_agent_chart()
\ No newline at end of file
......@@ -36,5 +36,25 @@ def test_chatextend():
result = ext.new_questions(messages=message)
print(result.content)
def test_rrf():
from langchain_core.documents import Document
from src.server.rerank import reciprocal_rank_fusion
docs = [
Document(page_content="我需要查找海拔最高的十座高山的信息。这可能需要从一个数据库或在线资源中获取数据。我将使用一个假设的数据库来获取这些信息。",metadata={"font-size": 12, "page_number": 1}),
Document(page_content="2019年,中国成年男性的平均身高为170厘米,女性为160厘米。",metadata={"font-size": 12, "page_number": 2}),
Document(page_content="2020年,中国成年男性的平均身高为171厘米,女性为160厘米。",metadata={"font-size": 12, "page_number": 3}),
]
docs2 = [
Document(page_content="我们的 llm_engine 必须是一个可调用的函数,它接受一系列消息作为输入并返回文本。它还需要接受一个 stop_sequences 参数,该参数指示何时停止其生成。为了方便起见,我们直接使用软件包中提供的 HfEngine 类来获取一个调用我们的推理 API 的LLM引擎。",metadata={"font-size": 12, "page_number": 1}),
Document(page_content="由于我们将代理初始化为一个 ReactJsonAgent ,因此它会自动获得一个默认的系统提示,该提示告诉LLM引擎逐步处理并生成 JSON 块作为工具调用(您可以根据需要替换此提示模板)。",metadata={"font-size": 12, "page_number": 2}),
]
res = reciprocal_rank_fusion([(60,docs),(55,docs2)])
print(res)
if __name__ == "__main__":
test_chatextend()
\ No newline at end of file
# test_chatextend()
test_rrf()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment