Commit dd12b019 by 文靖昊

rag agent流程优化,输出相关文档

parent 75467574
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Type, Any
from typing import Type, Any,List
import re
from src.server.get_similarity import GetSimilarityWithExt
import time
from src.server.rerank import BgeRerank
from langchain_core.documents import Document
import json
......@@ -17,12 +17,12 @@ from src.config.consts import (
class IssuanceArgs(BaseModel):
    """Input schema for the ``rag_query`` tool.

    The merged diff left two conflicting ``location`` declarations; the
    updated description (administrative divisions extracted from *question*)
    is kept.
    """
    question: str = Field(description="对话问题")
    history: str = Field(description="历史对话记录")
    location: list = Field(description="question参数中的行政区划名称")
class RAGQuery(BaseTool):
    """County-level hydro-meteorological/geological knowledge-base query tool.

    The merged diff carried both the old and new tool descriptions; the new
    one (which instructs the agent to post-process results with the
    ``rag_analysis`` tool) is kept.
    """
    name = "rag_query"
    description = """这是一个区(县)级的水文气象地质知识库,当咨询区(县)的水文气象地质等相关信息的时候,你可以提供数据和资料。注意,这个查询只能获取单个区(县)的水文气象地质等相关信息,当需要查询省市的详细信息时,需要获取改省市下的具体行政规划,并一一获取具体的区(县)的水文气象地质等相关信息。这个知识库中信息并不全面,有可能缺失。这个工具生成的结果需要再调用rag_analysis这个工具来进行解析,每次调用完成之后,一定要调用rag_analysis去解析结果"""
    args_schema: Type[BaseModel] = IssuanceArgs
    rerank: Any  # TODO: replace Any with the concrete reranker-chain type
    rerank_model: Any  # TODO: replace Any with the concrete rerank-model type
......@@ -45,36 +45,68 @@ class RAGQuery(BaseTool):
def get_similarity_with_ext_origin(self, _ext):
    # Factory: wrap the expanded query strings in a GetSimilarityWithExt
    # bound to this tool's FAISS vector store.
    return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db)
def _run(self, question: str, history: str, location: list):
    """Query the county-level knowledge base.

    Steps: fetch location-matched raw docs from the DB, expand the question
    with the history via the rerank chain, rerank DB docs together with the
    FAISS similarity docs, then answer with the LLM chain.

    Returns a dict with keys ``详细信息`` (LLM answer) and ``参考文档``
    (JSON string of the reranked reference documents).
    NOTE(review): the diff also showed an alternate return
    ``{"参考资料": docs, "原问题": question}`` and the driver script reads
    ``参考文档`` from intermediate steps — confirm the intended contract.
    """
    print(location)
    # Fetch raw (page_content, metadata_json) rows matching the locations.
    start = time.time()
    answer = self.db.find_like_doc(location)
    end = time.time()
    print('find_like_doc time: %s Seconds' % (end - start))
    print(len(answer) if answer else 0)
    # find_like_doc may return None; treat that as "no documents".
    split_docs = []
    for a in answer if answer else []:
        d = Document(page_content=a[0], metadata=json.loads(a[1]))
        split_docs.append(d)
    print(len(split_docs))
    # Expand the question using chat history (returns quoted sub-queries).
    start = time.time()
    result = self.rerank.extend_query_with_str(question, history)
    end = time.time()
    print('extend_query_with_str time: %s Seconds' % (end - start))
    print(result)
    matches = re.findall(r'"([^"]+)"', result.content)
    print(matches)
    similarity = self.get_similarity_with_ext_origin(matches)
    # Rerank DB docs and FAISS similarity docs together (RRF fusion inside).
    cur_similarity = similarity.get_rerank_with_doc(self.rerank_model, split_docs)
    cur_answer = self.llm_chain.run(context=cur_similarity, question=question)
    return {"详细信息": cur_answer, "参考文档": cur_similarity}
class RAGAnalysisArgs(BaseModel):
    """Input schema for the ``rag_analysis`` tool."""
    question: str = Field(description="rag_query附带的问题")
    doc: str = Field(description="rag_query获取的县级水文气象地质参考资料")
class RAGAnalysisQuery(BaseTool):
    """Analysis companion tool for ``rag_query``.

    Takes the reference documents produced by ``rag_query`` plus the original
    question and asks the LLM chain to extract the answer.
    """
    name = "rag_analysis"
    description = """这是一个区(县)级的水文气象地质知识库解析工具,从rag_query查询到的资料,需要结合原始问题,用这个工具来解析出想要的答案,调用rag_query工具做查询之后一定要调用这个工具"""
    args_schema: Type[BaseModel] = RAGAnalysisArgs
    llm_chain: Any  # LLMChain used to analyse the retrieved documents

    def __init__(self, _llm_chain):
        super().__init__()
        self.llm_chain = _llm_chain

    def _run(self, question: str, doc: str):
        """Run the LLM chain over *doc* as context for *question*.

        ``doc`` is annotated ``str`` to match RAGAnalysisArgs (the old code
        said ``list``, contradicting the declared schema).
        """
        cur_answer = self.llm_chain.run(context=doc, question=question)
        return cur_answer
\ No newline at end of file
......@@ -81,6 +81,7 @@ Action:
"""
PROMPT_AGENT_CHART_SYS_VARS = [ "tool_names", "tools"]
PROMPT_AGENT_SYS_VARS = [ "tool_names", "tools"]
PROMPT_AGENT_HUMAN = """{input}\n\n{agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复)"""
......
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.prompts import PROMPT_QUERY_EXTEND,PROMPT_QA_EXTEND_QUESTION
from src.server.rerank import BgeRerank,reciprocal_rank_fusion
import time
from typing import List
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers.list import ListOutputParser
......@@ -39,7 +41,7 @@ class GetSimilarityWithExt:
self.faiss_db = _faiss_db
self.similarity_docs = self.get_text_similarity_with_ext()
self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
self.rerank_docs = ""
self.rerank_docs = []
def get_rerank(self, reranker: BgeRerank, top_k=5):
question = '\n'.join(self.question)
......@@ -51,14 +53,38 @@ class GetSimilarityWithExt:
self.rerank_docs = rerank_docs[:top_k]
return self.faiss_db.join_document(d_list)
def join_document(self, docs: "List[Document]") -> str:
    """Serialize reranked documents to a JSON-array string.

    Each element carries ``page_content``, ``filename`` and ``page_number``
    (defaulting to "0" when the metadata key is missing).

    Fixes over the old hand-built string: the old code never appended the
    closing ``]``, put no comma between elements, did not escape quotes or
    newlines inside ``page_content``, and crashed on non-string
    ``page_number`` values — its output was never valid JSON. Building the
    list and using ``json.dumps`` fixes all four.
    """
    items = []
    for doc in docs:
        # str() tolerates integer page numbers stored in metadata.
        page_number = str(doc.metadata.get("page_number", "0"))
        items.append({
            "page_content": doc.page_content,
            "filename": doc.metadata["filename"],
            "page_number": page_number,
        })
    # ensure_ascii=False keeps the Chinese text readable in the output.
    return json.dumps(items, ensure_ascii=False)
def get_rerank_with_doc(self, reranker: BgeRerank,split_doc:list, top_k=5):
question = '\n'.join(self.question)
print(question)
start = time.time()
rerank_docs1 = reranker.compress_documents(split_doc, question)
end = time.time()
print('重排1 time: %s Seconds' % (end - start))
start = time.time()
rerank_docs2 = reranker.compress_documents(self.similarity_docs, question)
end = time.time()
print('重排2 time: %s Seconds' % (end - start))
rerank_docs1_hash = []
rerank_docs2_hash = []
m = {}
start = time.time()
for doc in rerank_docs1:
m[hash(doc.page_content)] = doc
rerank_docs1_hash.append(hash(doc.page_content))
......@@ -66,19 +92,28 @@ class GetSimilarityWithExt:
for doc in rerank_docs2:
m[hash(doc.page_content)] = doc
rerank_docs2_hash.append(hash(doc.page_content))
end = time.time()
step_time1 = end - start
result = []
result.append((60,rerank_docs1_hash))
result.append((55,rerank_docs2_hash))
print(len(rerank_docs1_hash))
print(len(rerank_docs2_hash))
start = time.time()
rrf_doc = reciprocal_rank_fusion(result)
end = time.time()
print('混排 time: %s Seconds' % (end - start))
print(rrf_doc)
d_list = []
start = time.time()
for key in rrf_doc:
d_list.append(m[key])
end = time.time()
step_time2 = end - start
print('文档去重 time: %s Seconds' % (step_time1 + step_time2))
self.rerank_docs = d_list[:top_k]
return self.faiss_db.join_document(d_list[:top_k])
return self.join_document(d_list[:top_k])
def get_similarity_doc(self):
    # Accessor for the joined text of the FAISS similarity documents
    # (presumably precomputed in __init__ via join_document — the
    # constructor is partly outside this view; confirm).
    return self.similarity_doc_txt
......
import json
import sys
sys.path.append('../')
......@@ -16,7 +17,7 @@ from src.server.agent import create_chart_agent
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.txt_doc_table import TxtDoc
from src.agent.tool_divisions import AdministrativeDivision
from src.agent.rag_agent import RAGQuery
from src.agent.rag_agent import RAGQuery,RAGAnalysisQuery
from src.config.consts import (
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
......@@ -31,7 +32,7 @@ from src.config.consts import (
prompt1
)
from src.config.prompts import PROMPT_AGENT_CHART_SYS_VARS,PROMPT_AGENT_CHART_SYS,PROMPT_AGENT_CHAT_HUMAN,PROMPT_AGENT_CHAT_HUMAN_VARS
from src.config.prompts import PROMPT_AGENT_SYS_VARS,PROMPT_AGENT_SYS,PROMPT_AGENT_CHAT_HUMAN,PROMPT_AGENT_CHAT_HUMAN_VARS
base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
......@@ -62,11 +63,11 @@ tools = [AdministrativeDivision(),RAGQuery(vecstore_faiss,ext,PromptTemplate(inp
# input_variables=['agent_scratchpad', 'input', 'tool_names', 'tools','chart_tool']
input_variables=[]
input_variables.extend(PROMPT_AGENT_CHAT_HUMAN_VARS)
input_variables.extend(PROMPT_AGENT_CHART_SYS_VARS)
input_variables.extend(PROMPT_AGENT_SYS_VARS)
input_types={'chat_history': List[Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}
messages=[
# SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=['tool_names', 'tools'], template='Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n{tools}\n\nUse a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).\n\nValid "action" values: "Final Answer" or {tool_names}\n\nProvide only ONE action per $JSON_BLOB, as shown:\n\n```\n{{\n "action": $TOOL_NAME,\n "action_input": $INPUT\n}}\n```\n\nFollow this format:\n\nQuestion: input question to answer\nThought: consider previous and subsequent steps\nAction:\n```\n$JSON_BLOB\n```\nObservation: action result\n... (repeat Thought/Action/Observation N times)\nThought: I know what to respond\nAction:\n```\n{{\n "action": "Final Answer",\n "action_input": "Final response to human"\n}}\n\nBegin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation')),
SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_CHART_SYS_VARS, template=PROMPT_AGENT_CHART_SYS)),
SystemMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_SYS_VARS, template=PROMPT_AGENT_SYS)),
MessagesPlaceholder(variable_name='chat_history', optional=True),
# HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['agent_scratchpad', 'input'], template='{input}\n\n{agent_scratchpad}\n (reminder to respond in a JSON blob no matter what)'))
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=PROMPT_AGENT_CHAT_HUMAN_VARS, template=PROMPT_AGENT_CHAT_HUMAN))
......@@ -82,7 +83,7 @@ prompt = ChatPromptTemplate(
# agent = create_structured_chat_agent(llm=base_llm, tools=tools, prompt=prompt)
agent = create_chart_agent(base_llm, tools, prompt, chart_tool="chart")
agent_executor = AgentExecutor(agent=agent, tools=tools,verbose=True,handle_parsing_errors=True)
agent_executor = AgentExecutor(agent=agent, tools=tools,verbose=True,handle_parsing_errors=True,return_intermediate_steps=True)
history = []
h1 = []
h1.append("攸县年降雨量")
......@@ -101,3 +102,7 @@ res = agent_executor.invoke({"input":"攸县、长沙县、化隆县和大通县
print("====== result: ======")
print(res)
print(type(res))
print(res["output"])
for step in res["intermediate_steps"]:
print(step[1]["参考文档"])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment