import json
import time
from typing import List

from langchain_core.documents import Document
from langchain_core.output_parsers.list import ListOutputParser
from langchain_core.prompts import PromptTemplate

from src.config.prompts import PROMPT_QA_EXTEND_QUESTION, PROMPT_QUERY_EXTEND
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.rerank import BgeRerank, reciprocal_rank_fusion


class GetSimilarity:
    """Retrieve documents similar to a single question from a FAISS store,
    with optional reranking.

    On construction the similarity search is executed immediately; the raw
    hits and their joined text form are cached on the instance.
    """

    def __init__(self, _question, _faiss_db: VectorStore_FAISS):
        self.question = _question
        self.faiss_db = _faiss_db
        # Eagerly run the vector search and cache both the documents and
        # their joined textual representation.
        self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        self.rerank_docs = ""

    def get_rerank(self, reranker: BgeRerank, top_k=5):
        """Rerank the cached similarity hits and return the joined text of
        the top ``top_k`` documents.

        Side effect: stores the top-k reranked documents in ``self.rerank_docs``.
        """
        rerank_docs = reranker.compress_documents(self.similarity_docs, self.question)
        top_docs = rerank_docs[:top_k]
        self.rerank_docs = top_docs
        return self.faiss_db.join_document(top_docs)

    def get_similarity_doc(self):
        """Return the joined text of the raw similarity hits."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Return the raw similarity hit documents."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Return the documents kept by the last ``get_rerank`` call."""
        return self.rerank_docs


class GetSimilarityWithExt:
    """Similarity retrieval for an *expanded* question (a list of query
    variants), with optional location-based filtering and reranking.

    ``_question`` is a list of query strings; hits for all of them are
    merged and de-duplicated by page content.
    """

    def __init__(self, _question, _faiss_db: VectorStore_FAISS, _location=None):
        self.question = _question
        self.faiss_db = _faiss_db
        # Optional list of substrings; when set, only documents whose content
        # or filename contains one of them are kept.
        self.location = _location
        self.similarity_docs = self.get_text_similarity_with_ext()
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        self.rerank_docs = []

    def get_rerank(self, reranker: BgeRerank, top_k=5):
        """Rerank the merged similarity hits against the joined question
        variants; return the joined text of the top ``top_k`` documents.
        """
        question = '\n'.join(self.question)
        print(question)
        rerank_docs = reranker.compress_documents(self.similarity_docs, question)
        top_docs = rerank_docs[:top_k]
        self.rerank_docs = top_docs
        return self.faiss_db.join_document(top_docs)

    def join_document(self, docs: List[Document]) -> str:
        """Serialize documents to a JSON array string.

        Each entry carries ``page_content``, ``from_file`` (the ``filename``
        metadata, assumed present), and optional ``hash`` / ``page_number``
        metadata (defaulting to ``""`` / ``"0"``).

        Fix: the previous hand-rolled string concatenation produced invalid
        JSON whenever the content contained quotes or newlines; ``json.dumps``
        escapes correctly and still yields ``"[]"`` for an empty list.
        """
        entries = []
        for doc in docs:
            entries.append({
                "page_content": doc.page_content,
                "from_file": doc.metadata["filename"],
                "hash": doc.metadata["hash"] if "hash" in doc.metadata else "",
                "page_number": str(doc.metadata["page_number"]) if "page_number" in doc.metadata else "0",
            })
        return json.dumps(entries, ensure_ascii=False)

    def get_rerank_with_doc(self, reranker: BgeRerank, split_docs_list: list):
        """Rerank each pre-split document group plus the merged similarity
        hits, keeping the top 3 of every pass, and return them serialized
        via :meth:`join_document`.

        Side effect: stores the collected documents in ``self.rerank_docs``.
        Timing of each rerank pass is printed for diagnostics.
        """
        question = '\n'.join(self.question)
        print(question)
        result = []
        for split_doc in split_docs_list:
            start = time.time()
            rerank_docs1 = reranker.compress_documents(split_doc, question)
            result.extend(rerank_docs1[:3])
            end = time.time()
            print('重排1 time: %s Seconds' % (end - start))
        start = time.time()
        rerank_docs2 = reranker.compress_documents(self.similarity_docs, question)
        result.extend(rerank_docs2[:3])
        end = time.time()
        print('重排2 time: %s Seconds' % (end - start))
        self.rerank_docs = result
        return self.join_document(result)

    def get_similarity_doc(self):
        """Return the joined text of the merged similarity hits."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Return the merged, de-duplicated similarity hit documents."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Return the documents kept by the last rerank call."""
        return self.rerank_docs

    def get_text_similarity_with_ext(self):
        """Run the similarity search for every question variant, merge the
        hits, optionally filter by ``self.location``, and de-duplicate by
        page content.

        Fix: documents were previously appended only inside the
        ``location is not None`` branch, so the default ``_location=None``
        always produced an empty result. ``location`` now acts purely as an
        optional filter; de-duplication applies in both cases.
        """
        similarity_docs = []
        for q in self.question:
            similarity_docs.extend(self.faiss_db.get_text_similarity(q))
        content_set = set()
        unique_documents = []
        for doc in similarity_docs:
            if self.location is not None:
                matches = (
                    any(substring in doc.page_content for substring in self.location)
                    or any(substring in doc.metadata["filename"] for substring in self.location)
                )
                if not matches:
                    continue
            content = hash(doc.page_content)
            if content not in content_set:
                unique_documents.append(doc)
                content_set.add(content)
        print(len(unique_documents))
        return unique_documents

    def get_doc_nums(self, num: int) -> int:
        """Scale ``num`` by 3 and clamp the result to the range [5, 30]."""
        return max(5, min(30, num * 3))


class QAExt:
    """Query expansion: rewrite/extend a user question using chat history
    via the PROMPT_QUERY_EXTEND template."""

    llm = None

    def __init__(self, llm) -> None:
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_QUERY_EXTEND)
        self.query_extend = prompt | llm

    def extend_query(self, question, messages=None):
        """Extend ``question`` using a list of (question, answer) history
        tuples.

        question: str
        messages: list of tuple (str, str) eg: [("Q1", "A1"), ("Q2", "A2"), ...]

        Fix: the formatted ``history`` string was built but the raw
        ``messages`` list was passed to the prompt; the template receives
        the formatted string now (matching ``ChatExtend.new_questions``).
        """
        if not messages:
            messages = []
        history = ""
        for msg in messages:
            history += f"Q: {msg[0]}\nA: {msg[1]}\n"
        return self.query_extend.invoke(input={"histories": history, "query": question})

    def extend_query_with_str(self, question, messages):
        """Extend ``question`` with an already-formatted history string.

        question: str
        messages: pre-formatted history (passed through to the prompt as-is)
        """
        return self.query_extend.invoke(input={"histories": messages, "query": question})


class ChatExtend:
    """Generate follow-up questions from chat history via the
    PROMPT_QA_EXTEND_QUESTION template."""

    def __init__(self, llm) -> None:
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_QA_EXTEND_QUESTION)
        self.query_extend = prompt | llm

    def new_questions(self, messages):
        """Format (question, answer) tuples into a history string and ask
        the LLM for new questions."""
        history = ""
        for msg in messages:
            history += f"Q: {msg[0]}\nA: {msg[1]}\n"
        return self.query_extend.invoke(input={"histories": history})