Commit ec3277e8 by 文靖昊

向量相似性搜索之后过滤不含地名的结果

parent 130b6514
...@@ -42,8 +42,8 @@ class RAGQuery(BaseTool): ...@@ -42,8 +42,8 @@ class RAGQuery(BaseTool):
self.llm_chain = _llm_chain self.llm_chain = _llm_chain
def get_similarity_with_ext_origin(self, _ext): def get_similarity_with_ext_origin(self, _ext,_location):
return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db) return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db,_location=_location)
...@@ -75,20 +75,20 @@ class RAGQuery(BaseTool): ...@@ -75,20 +75,20 @@ class RAGQuery(BaseTool):
matches = re.findall(r'"([^"]+)"', result.content) matches = re.findall(r'"([^"]+)"', result.content)
print(matches) print(matches)
similarity = self.get_similarity_with_ext_origin(matches) similarity = self.get_similarity_with_ext_origin(matches,_location=location)
# cur_similarity = similarity.get_rerank(self.rerank_model) # cur_similarity = similarity.get_rerank(self.rerank_model)
cur_similarity = similarity.get_rerank_with_doc(self.rerank_model,split_docs) cur_similarity = similarity.get_rerank_with_doc(self.rerank_model,split_docs)
docs = similarity.get_rerank_docs() # docs = similarity.get_rerank_docs()
# print(cur_similarity) # print(cur_similarity)
# # geo_result = "以下是详细的水文气象地质资料:"+cur_similarity+"\n 以下是原问题"+question # # geo_result = "以下是详细的水文气象地质资料:"+cur_similarity+"\n 以下是原问题"+question
# # cur_question = self.prompt.format(history=history, context=cur_similarity, question=question) # # cur_question = self.prompt.format(history=history, context=cur_similarity, question=question)
cur_answer = self.llm_chain.run(context=cur_similarity, question=question) cur_answer = self.llm_chain.run(context=cur_similarity, question=question)
# print(cur_answer) # print(cur_answer)
# return cur_answer # return cur_answer
loc = location[0] # loc = location[0]
location = location[1:] # location = location[1:]
for i in location: # for i in location:
loc += (","+i) # loc += (","+i)
return {"详细信息":cur_answer,"参考文档": cur_similarity} return {"详细信息":cur_answer,"参考文档": cur_similarity}
......
...@@ -226,6 +226,7 @@ def question(chat_request: ChatRequest, token: str = Header(None)): ...@@ -226,6 +226,7 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
# answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True) # answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True)
docs_json = [] docs_json = []
for step in res["intermediate_steps"]: for step in res["intermediate_steps"]:
if "rag_query" == step[0].tool:
j = json.loads(step[1]["参考文档"], strict=False) j = json.loads(step[1]["参考文档"], strict=False)
docs_json.extend(j) docs_json.extend(j)
print(len(docs_json)) print(len(docs_json))
...@@ -283,6 +284,7 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)): ...@@ -283,6 +284,7 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
answer = res["output"] answer = res["output"]
docs_json = [] docs_json = []
for step in res["intermediate_steps"]: for step in res["intermediate_steps"]:
if "rag_query" == step[0].tool:
j = json.loads(step[1]["参考文档"], strict=False) j = json.loads(step[1]["参考文档"], strict=False)
docs_json.extend(j) docs_json.extend(j)
......
...@@ -36,13 +36,16 @@ class GetSimilarity: ...@@ -36,13 +36,16 @@ class GetSimilarity:
class GetSimilarityWithExt: class GetSimilarityWithExt:
def __init__(self, _question, _faiss_db: VectorStore_FAISS): def __init__(self, _question, _faiss_db: VectorStore_FAISS, _location=None):
self.question = _question self.question = _question
self.faiss_db = _faiss_db self.faiss_db = _faiss_db
self.location = _location
self.similarity_docs = self.get_text_similarity_with_ext() self.similarity_docs = self.get_text_similarity_with_ext()
self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs) self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
self.rerank_docs = [] self.rerank_docs = []
def get_rerank(self, reranker: BgeRerank, top_k=5): def get_rerank(self, reranker: BgeRerank, top_k=5):
question = '\n'.join(self.question) question = '\n'.join(self.question)
print(question) print(question)
...@@ -139,10 +142,13 @@ class GetSimilarityWithExt: ...@@ -139,10 +142,13 @@ class GetSimilarityWithExt:
content_set = set() content_set = set()
unique_documents = [] unique_documents = []
for doc in similarity_docs: for doc in similarity_docs:
if self.location is not None:
if any(substring in doc.page_content for substring in self.location) or any(substring in doc.metadata["filename"] for substring in self.location) :
content = hash(doc.page_content) content = hash(doc.page_content)
if content not in content_set: if content not in content_set:
unique_documents.append(doc) unique_documents.append(doc)
content_set.add(content) content_set.add(content)
print(len(unique_documents))
return unique_documents return unique_documents
class QAExt: class QAExt:
......
...@@ -106,9 +106,12 @@ print(type(res)) ...@@ -106,9 +106,12 @@ print(type(res))
print(res["output"]) print(res["output"])
docs_json = [] docs_json = []
for step in res["intermediate_steps"]: for step in res["intermediate_steps"]:
print(type(step[1]["参考文档"])) print(type(step[0].tool))
print(step[0].tool)
if "rag_query" ==step[0].tool:
print(True)
print(step[1]["参考文档"]) print(step[1]["参考文档"])
j = json.loads(step[1]["参考文档"],strict=False) j = json.loads(step[1]["参考文档"], strict=False)
docs_json.extend(j) docs_json.extend(j)
print(docs_json) print(docs_json)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment