Commit 135a0fe7 by 文靖昊

Update agent tool descriptions, fetch similar documents batch by batch (one district/county at a time), and adjust the number of documents returned

parent 909d09a7
@@ -22,7 +22,7 @@ class IssuanceArgs(BaseModel):
 class RAGQuery(BaseTool):
     name = "rag_query"
-    description = """这是一个区(县)级的水文气象地质知识库,当咨询区(县)的水文气象地质等相关信息的时候,你可以提供数据和资料。注意,这个查询只能获取单个区(县)的水文气象地质等相关信息,当需要查询省市的详细信息时,需要获取改省市下的具体行政规划,并一一获取具体的区(县)的水文气象地质等相关信息。这个知识库中信息并不全面,有可能缺失。"""
+    description = """这是一个区(县)级的水文气象地质知识库,当咨询区(县)的水文气象地质等相关信息的时候,你可以提供数据和资料。注意,这个查询只能获取一个区(县)的水文气象地质等相关信息。如果问题中有多个区县,请拆解出来,并一个区县一个区县地查询。当需要查询省市的详细信息时,需要获取该省市下的具体行政区划,并一一获取具体的区(县)的水文气象地质等相关信息。这个知识库中信息并不全面,有可能缺失。"""
     args_schema: Type[BaseModel] = IssuanceArgs
     rerank: Any  # 替换 Any 为适当的类型
     rerank_model: Any  # 替换 Any 为适当的类型
@@ -53,17 +53,19 @@ class RAGQuery(BaseTool):
         # split_list = []
         # for l in split_str:
         #     split_list.append(l)
-        start = time.time()
-        answer = self.db.find_like_doc(location)
-        end = time.time()
-        print('find_like_doc time: %s Seconds' % (end - start))
-        print(len(answer) if answer else 0)
-        split_docs = []
-        for a in answer if answer else []:
-            d = Document(page_content=a[0], metadata=json.loads(a[1]))
-            split_docs.append(d)
-        print(len(split_docs))
+        split_docs_list = []
+        for l in location:
+            start = time.time()
+            answer = self.db.find_like_doc(l)
+            end = time.time()
+            print('find_like_doc time: %s Seconds' % (end - start))
+            print(len(answer) if answer else 0)
+            split_docs = []
+            for a in answer if answer else []:
+                d = Document(page_content=a[0], metadata=json.loads(a[1]))
+                split_docs.append(d)
+            print(len(split_docs))
+            split_docs_list.append(split_docs)
         # if len(split_docs)>10:
         #     split_docs= split_docs[:10]
@@ -77,7 +79,7 @@ class RAGQuery(BaseTool):
         print(matches)
         similarity = self.get_similarity_with_ext_origin(matches,_location=location)
         # cur_similarity = similarity.get_rerank(self.rerank_model)
-        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model,split_docs)
+        cur_similarity = similarity.get_rerank_with_doc(self.rerank_model,split_docs_list)
         # docs = similarity.get_rerank_docs()
         # print(cur_similarity)
         # # geo_result = "以下是详细的水文气象地质资料:"+cur_similarity+"\n 以下是原问题"+question
...
@@ -132,7 +132,7 @@ class AdministrativeDivisionArgs(BaseModel):
 class AdministrativeDivision(BaseTool):
     name = "administrative_division"
-    description = "根据用户提问中涉及到的地区信息补全其行政区划信息,明确具体的省、市、县信息。比如输入县,补全所属省市,输入市则补全省级以及下辖所有县区"
+    description = "根据用户提问中涉及到的地区信息补全其行政区划信息,明确具体的省、市、县信息。比如输入县,补全所属省市;输入市,则补全省级以及下辖所有县区。当问题中涉及区县的时候,一定要优先调用此工具。"
     args_schema: Type[BaseModel] = AdministrativeDivisionArgs

     def _run(self, input_text: str) -> str:
...
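
Note: combined with the rag_query description change above, the intended agent flow is "expand first, then query per county": administrative_division resolves a city into its districts/counties, then rag_query is called once per county. A toy orchestration sketch (both tools mocked; the division list is illustrative, not a complete roster of 西宁市):

    def administrative_division(city: str) -> list:
        # Mocked expansion; the real tool resolves province/city/county divisions.
        return {"西宁市": ["城东区", "城中区", "湟中区", "大通县"]}.get(city, [city])

    def rag_query(county: str) -> str:
        # Mocked single-county knowledge-base lookup.
        return "<hydro-met docs for %s>" % county

    for county in administrative_division("西宁市"):
        print(county, "->", rag_query(county))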
@@ -66,16 +66,9 @@ class TxtDoc:
         print("drop table txt_doc ok")

-    def find_like_doc(self,item:list):
-        print(item)
-        i0 = item[0]
-        if len(item)>=1:
-            item = item[1:]
-        print(item)
-        query = "select text,matadate FROM txt_doc WHERE matadate like '%"+i0+"%' or text like '%"+i0+"%' "
-        for i in item:
-            query+= "or matadate like '%"+i+"%' or text like '%"+i0+"%' "
-        print(query)
+    def find_like_doc(self,item:str):
+        query = "select text,matadate FROM txt_doc WHERE matadate like '%"+item+"%' or text like '%"+item+"%' "
         self.db.execute(query)
         answer = self.db.fetchall()
...
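
Note: the simplified find_like_doc still assembles its SQL by string concatenation, so a location name containing a quote would break the query (and % or _ in the input always act as wildcards). A hedged sketch of the same LIKE lookup with bound parameters, assuming a DB-API 2.0 cursor such as psycopg2's (txt_doc and the matadate column name come from the diff):

    def find_like_doc_param(cursor, item: str):
        """Same lookup, but parameterized: quotes in `item` can no longer
        break out of the SQL string (wildcards in `item` still apply)."""
        pattern = "%" + item + "%"
        cursor.execute(
            "SELECT text, matadate FROM txt_doc WHERE matadate LIKE %s OR text LIKE %s",
            (pattern, pattern),
        )
        return cursor.fetchall()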
@@ -80,33 +80,31 @@ class GetSimilarityWithExt:
         result += "]"
         return result

-    def get_rerank_with_doc(self, reranker: BgeRerank,split_doc:list, top_k=5):
+    def get_rerank_with_doc(self, reranker: BgeRerank,split_docs_list:list):
+        top_k = self.get_doc_nums(len(split_docs_list))
         question = '\n'.join(self.question)
         print(question)
-        start = time.time()
-        rerank_docs1 = reranker.compress_documents(split_doc, question)
-        end = time.time()
-        print('重排1 time: %s Seconds' % (end - start))
-        start = time.time()
-        rerank_docs2 = reranker.compress_documents(self.similarity_docs, question)
-        end = time.time()
-        print('重排2 time: %s Seconds' % (end - start))
         rerank_docs1_hash = []
         rerank_docs2_hash = []
         m = {}
+        result = []
+        for split_doc in split_docs_list:
+            start = time.time()
+            rerank_docs1 = reranker.compress_documents(split_doc, question)
+            end = time.time()
+            print('重排1 time: %s Seconds' % (end - start))
+            for doc in rerank_docs1:
+                m[hash(doc.page_content)] = doc
+                rerank_docs1_hash.append(hash(doc.page_content))
+        result.append((60, rerank_docs1_hash))
         start = time.time()
-        for doc in rerank_docs1:
-            m[hash(doc.page_content)] = doc
-            rerank_docs1_hash.append(hash(doc.page_content))
+        rerank_docs2 = reranker.compress_documents(self.similarity_docs, question)
+        end = time.time()
+        print('重排2 time: %s Seconds' % (end - start))
         for doc in rerank_docs2:
             m[hash(doc.page_content)] = doc
             rerank_docs2_hash.append(hash(doc.page_content))
-        end = time.time()
-        step_time1 = end - start
-        result = []
-        result.append((60,rerank_docs1_hash))
         result.append((55,rerank_docs2_hash))
         print(len(rerank_docs1_hash))
         print(len(rerank_docs2_hash))
@@ -114,14 +112,11 @@ class GetSimilarityWithExt:
         rrf_doc = reciprocal_rank_fusion(result)
         end = time.time()
         print('混排 time: %s Seconds' % (end - start))
-        print(rrf_doc)
+        print("混排文档数量:", len(rrf_doc))
         d_list = []
-        start = time.time()
         for key in rrf_doc:
             d_list.append(m[key])
-        end = time.time()
-        step_time2 = end - start
-        print('文档去重 time: %s Seconds' % (step_time1 + step_time2))
+        print("返回文档数量:",top_k)
         self.rerank_docs = d_list[:top_k]
         return self.join_document(d_list[:top_k])
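
Note: reciprocal_rank_fusion itself is not shown in this diff. Given that it is fed (k, ranked_ids) pairs, 60 for the batched keyword hits and 55 for the vector-store hits, a plausible reading is the standard RRF score 1/(k + rank) summed per document id. A minimal sketch under that assumption:

    def reciprocal_rank_fusion_sketch(ranked_lists):
        """Assumed reading: each entry is (k, ids in rank order);
        score(id) = sum over lists of 1 / (k + rank)."""
        scores = {}
        for k, ids in ranked_lists:
            for rank, doc_id in enumerate(ids, start=1):
                scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
        # Highest fused score first; the caller maps ids back to Documents via m.
        return [doc_id for doc_id, _ in sorted(scores.items(), key=lambda x: -x[1])]

    fused = reciprocal_rank_fusion_sketch([(60, ["a", "b", "c"]), (55, ["b", "d"])])
    print(fused)  # ['b', 'd', 'a', 'c']; 'b' wins by appearing in both lists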
@@ -151,6 +146,15 @@ class GetSimilarityWithExt:
         print(len(unique_documents))
         return unique_documents

+    def get_doc_nums(self,num :int)->int:
+        num = num*3
+        if num<5:
+            return 5
+        elif num>30:
+            return 30
+        else:
+            return num
+

 class QAExt:
     llm = None
...
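
Note: get_doc_nums scales top_k with the number of queried counties (3 documents per county) and clamps it to the range [5, 30]; it is equivalent to this one-line clamp:

    def get_doc_nums(num: int) -> int:
        # 3 docs per queried district/county, clamped to the range [5, 30].
        return min(max(num * 3, 5), 30)

    assert [get_doc_nums(n) for n in (1, 3, 11)] == [5, 9, 30]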
@@ -10,7 +10,8 @@ from langchain_core.prompts.chat import ChatPromptTemplate,HumanMessagePromptTem
 from langchain_core.prompts import PromptTemplate
 from langchain.chains import LLMChain
 import langchain_core
+from src.llm.ernie_with_sdk import ChatERNIESerLLM
+from qianfan import ChatCompletion
 from src.pgdb.knowledge.similarity import VectorStore_FAISS
 from src.server.get_similarity import QAExt
 from src.server.agent import create_chart_agent
@@ -41,7 +42,8 @@ base_llm = ChatOpenAI(
     verbose=True,
     temperature=0
 )
+# base_llm = ChatERNIESerLLM(
+#     chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))

 vecstore_faiss = VectorStore_FAISS(
     embedding_model_name=EMBEEDING_MODEL_PATH,
@@ -98,7 +100,7 @@ for h in history:
     prompt += "问:{}\n答:{}\n".format(h[0], h[1])
 print(prompt)
 # res = agent_executor.invoke({"input":"以下历史对话记录: "+prompt+"以下是问题:"+"攸县、长沙县、化隆县和大通县谁的年平均降雨量大"})
-res = agent_executor.invoke({"input":"攸县、长沙县、化隆县和大通县谁的年平均降雨量大","histories":history})
+res = agent_executor.invoke({"input":"西宁市各区县年平均降雨量","histories":history})
 print("====== result: ======")
 print(res)
...