Commit cc83b288 by 文靖昊

多个agent tool结合

parent 47159811
......@@ -15,6 +15,7 @@ from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.txt_doc_table import TxtDoc
from langchain_core.documents import Document
import json
from src.agent.tool_divisions import AdministrativeDivision
from src.config.consts import (
RERANK_MODEL_PATH,
CHAT_DB_USER,
......@@ -1555,12 +1556,12 @@ def jieba_split(text: str) -> str:
class IssuanceArgs(BaseModel):
question: str = Field(description="对话问题")
history: list = Field(description="历史对话记录")
history: str = Field(description="历史对话记录")
class RAGQuery(BaseTool):
name = "rag_query"
description = """Query the geological information of corresponding provinces, cities, and counties. Users can query geological information related to specific provinces, cities, and counties"""
description = """你是一个区(县)级的水文气象地质知识库,当咨询区(县)的水文气象地质等相关信息的时候,你可以提供数据和资料。注意,这个查询只能获取单个区(县)的水文气象地质等相关信息,当需要查询省市的详细信息时,需要获取改省市下的具体行政规划,并获取具体的区(县)的水文气象地质等相关信息"""
args_schema: Type[BaseModel] = IssuanceArgs
rerank: Any # 替换 Any 为适当的类型
rerank_model: Any # 替换 Any 为适当的类型
......@@ -1581,7 +1582,7 @@ class RAGQuery(BaseTool):
def get_similarity_with_ext_origin(self, _ext):
return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db)
def _run(self, question: str, history: list) -> str:
def _run(self, question: str, history: str) -> str:
split_str = jieba_split(question)
split_list = []
for l in split_str:
......@@ -1593,18 +1594,15 @@ class RAGQuery(BaseTool):
d = Document(page_content=a[0], metadata=json.loads(a[1]))
split_docs.append(d)
print(split_docs)
result = self.rerank.extend_query(question, history)
result = self.rerank.extend_query_with_str(question, history)
matches = re.findall(r'"([^"]+)"', result.content)
if len(matches) > 3:
matches = matches[:3]
print(matches)
prompt = ""
for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
similarity = self.get_similarity_with_ext_origin(matches)
# cur_similarity = similarity.get_rerank(self.rerank_model)
cur_similarity = similarity.get_rerank_with_doc(self.rerank_model,split_docs)
cur_question = self.prompt.format(history=prompt, context=cur_similarity, question=question)
cur_question = self.prompt.format(history=history, context=cur_similarity, question=question)
return cur_question
......@@ -1631,7 +1629,7 @@ k_db.connect()
tools = [RAGQuery(vecstore_faiss,ext,PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template),_db=TxtDoc(k_db))]
tools = [AdministrativeDivision(),RAGQuery(vecstore_faiss,ext,PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template),_db=TxtDoc(k_db))]
input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools']
input_types={'chat_history': List[Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}
metadata={'lc_hub_owner': 'hwchase17', 'lc_hub_repo': 'structured-chat-agent', 'lc_hub_commit_hash': 'ea510f70a5872eb0f41a4e3b7bb004d5711dc127adee08329c664c6c8be5f13c'}
......@@ -1651,7 +1649,17 @@ prompt = ChatPromptTemplate(
agent = create_structured_chat_agent(llm=base_llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools,verbose=True,handle_parsing_errors=True)
res = agent_executor.invoke({"input":"大通县"})
history = []
h1 = []
h1.append("大通县年降雨量")
h1.append("大通县年雨量平均为30ml")
history.append(h1)
prompt = ""
for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
print(prompt)
res = agent_executor.invoke({"input":"以下历史对话记录: "+prompt+"以下是问题:"+"西宁市年平均降雨量"})
print("====== result: ======")
print(res)
......@@ -120,6 +120,20 @@ class QAExt:
history += f"Q: {msg[0]}\nA: {msg[1]}\n"
return self.query_extend.invoke(input={"histories":messages, "query":question})
def extend_query_with_str(self, question, messages):
"""
question: str
messages: list of tuple (str,str)
eg:
[
("Q1","A1"),
("Q2","A2"),
...
]
"""
return self.query_extend.invoke(input={"histories": messages, "query": question})
class ChatExtend:
def __init__(self, llm) -> None:
self.llm = llm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment