diff --git a/src/agent/rag_agent.py b/src/agent/rag_agent.py
index 81107d5..fa1dcff 100644
--- a/src/agent/rag_agent.py
+++ b/src/agent/rag_agent.py
@@ -42,8 +42,8 @@ class RAGQuery(BaseTool):
         self.llm_chain = _llm_chain

-    def get_similarity_with_ext_origin(self, _ext):
-        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db)
+    def get_similarity_with_ext_origin(self, _ext,_location):
+        return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db,_location=_location)

@@ -75,20 +75,20 @@ class RAGQuery(BaseTool):
         matches = re.findall(r'"([^"]+)"', result.content)
         print(matches)
-        similarity = self.get_similarity_with_ext_origin(matches)
+        similarity = self.get_similarity_with_ext_origin(matches,_location=location)
         # cur_similarity = similarity.get_rerank(self.rerank_model)
         cur_similarity = similarity.get_rerank_with_doc(self.rerank_model,split_docs)
-        docs = similarity.get_rerank_docs()
+        # docs = similarity.get_rerank_docs()
         # print(cur_similarity)
         # # geo_result = "以下是详细的水文气象地质资料:"+cur_similarity+"\n 以下是原问题"+question
         # # cur_question = self.prompt.format(history=history, context=cur_similarity, question=question)
         cur_answer = self.llm_chain.run(context=cur_similarity, question=question)
         # print(cur_answer)
         # return cur_answer
-        loc = location[0]
-        location = location[1:]
-        for i in location:
-            loc += (","+i)
+        # loc = location[0]
+        # location = location[1:]
+        # for i in location:
+        #     loc += (","+i)
         return {"详细信息":cur_answer,"参考文档": cur_similarity}
diff --git a/src/controller/web.py b/src/controller/web.py
index b576680..bdbddeb 100644
--- a/src/controller/web.py
+++ b/src/controller/web.py
@@ -226,8 +226,9 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
     # answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True)
     docs_json = []
     for step in res["intermediate_steps"]:
-        j = json.loads(step[1]["参考文档"], strict=False)
-        docs_json.extend(j)
+        if "rag_query" == step[0].tool:
+            j = json.loads(step[1]["参考文档"], strict=False)
+            docs_json.extend(j)
     print(len(docs_json))
     doc_hash = []
     for d in docs_json:
@@ -283,8 +284,9 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
     answer = res["output"]
     docs_json = []
     for step in res["intermediate_steps"]:
-        j = json.loads(step[1]["参考文档"], strict=False)
-        docs_json.extend(j)
+        if "rag_query" == step[0].tool:
+            j = json.loads(step[1]["参考文档"], strict=False)
+            docs_json.extend(j)
     doc_hash = []
     for d in docs_json:
diff --git a/src/server/get_similarity.py b/src/server/get_similarity.py
index 05c5f20..7ae1b71 100644
--- a/src/server/get_similarity.py
+++ b/src/server/get_similarity.py
@@ -36,13 +36,16 @@ class GetSimilarity:


 class GetSimilarityWithExt:
-    def __init__(self, _question, _faiss_db: VectorStore_FAISS):
+    def __init__(self, _question, _faiss_db: VectorStore_FAISS, _location=None):
         self.question = _question
         self.faiss_db = _faiss_db
+        self.location = _location
         self.similarity_docs = self.get_text_similarity_with_ext()
         self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
         self.rerank_docs = []
+
+

     def get_rerank(self, reranker: BgeRerank, top_k=5):
         question = '\n'.join(self.question)
         print(question)
@@ -139,10 +142,13 @@ class GetSimilarityWithExt:
         content_set = set()
         unique_documents = []
         for doc in similarity_docs:
-            content = hash(doc.page_content)
-            if content not in content_set:
-                unique_documents.append(doc)
-                content_set.add(content)
+            if self.location is not None:
+                if any(substring in doc.page_content for substring in self.location) or any(substring in doc.metadata["filename"] for substring in self.location):
+                    content = hash(doc.page_content)
+                    if content not in content_set:
+                        unique_documents.append(doc)
+                        content_set.add(content)
+        print(len(unique_documents))
         return unique_documents

 class QAExt:
diff --git a/test/rag_agent_test.py b/test/rag_agent_test.py
index 029d098..543a19c 100644
--- a/test/rag_agent_test.py
+++ b/test/rag_agent_test.py
@@ -106,10 +106,13 @@ print(type(res))
 print(res["output"])
 docs_json = []
 for step in res["intermediate_steps"]:
-    print(type(step[1]["参考文档"]))
-    print(step[1]["参考文档"])
-    j = json.loads(step[1]["参考文档"],strict=False)
-    docs_json.extend(j)
+    print(type(step[0].tool))
+    print(step[0].tool)
+    if "rag_query" ==step[0].tool:
+        print(True)
+        print(step[1]["参考文档"])
+        j = json.loads(step[1]["参考文档"], strict=False)
+        docs_json.extend(j)
 print(docs_json)
 print(len(docs_json))
\ No newline at end of file
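For reference, two minimal standalone sketches of the behaviours this patch introduces, written against hypothetical stand-ins rather than the project's real classes. The first mirrors the location filter added to GetSimilarityWithExt in src/server/get_similarity.py: keep only documents whose page_content or metadata["filename"] mentions one of the location strings, then deduplicate by a hash of the content. The Document dataclass is a stand-in for the LangChain-style documents the code works with, and the fallback branch (plain dedup when no location is supplied) is an assumption for illustration; in the patch itself the dedup block only runs when a location is present.

from dataclasses import dataclass, field


@dataclass
class Document:
    # Stand-in for a LangChain-style document: text plus metadata with a "filename" key.
    page_content: str
    metadata: dict = field(default_factory=dict)


def filter_and_dedup(similarity_docs, location=None):
    # Keep documents that mention one of the location strings in their text or
    # filename, deduplicated by a hash of their content.
    content_set = set()
    unique_documents = []
    for doc in similarity_docs:
        if location is not None:
            mentions_location = (
                any(s in doc.page_content for s in location)
                or any(s in doc.metadata.get("filename", "") for s in location)
            )
            if not mentions_location:
                continue
        # Assumed fallback: without a location, fall through to plain dedup.
        content = hash(doc.page_content)
        if content not in content_set:
            unique_documents.append(doc)
            content_set.add(content)
    return unique_documents


docs = [
    Document("Yangtze basin hydrology and rainfall summary", {"filename": "yangtze_2023.pdf"}),
    Document("Yellow River basin hydrology summary", {"filename": "yellow_river_2023.pdf"}),
    Document("Yangtze basin hydrology and rainfall summary", {"filename": "yangtze_copy.pdf"}),
]
print(len(filter_and_dedup(docs, location=["Yangtze"])))  # 1: duplicate dropped, non-matching doc filtered out

The second sketch mirrors the guard added in src/controller/web.py and test/rag_agent_test.py: only intermediate steps produced by the rag_query tool carry a JSON "参考文档" payload, so observations from other tools are skipped instead of being fed to json.loads. The Action class, the tool name "weather_lookup", and the sample steps are hypothetical.

import json


class Action:
    # Stand-in for the agent action object: only the tool name matters here.
    def __init__(self, tool):
        self.tool = tool


def collect_reference_docs(intermediate_steps):
    # Only rag_query steps return a dict with a JSON list under "参考文档";
    # other observations are skipped rather than parsed.
    docs_json = []
    for step in intermediate_steps:
        if step[0].tool == "rag_query":
            docs_json.extend(json.loads(step[1]["参考文档"], strict=False))
    return docs_json


steps = [
    (Action("rag_query"), {"详细信息": "answer text", "参考文档": '[{"filename": "a.pdf"}]'}),
    (Action("weather_lookup"), "plain-text observation"),
]
print(len(collect_reference_docs(steps)))  # 1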