# get_similarity.py
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.prompts import PROMPT_QUERY_EXTEND,PROMPT_QA_EXTEND_QUESTION
from .rerank import BgeRerank

from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers.list import ListOutputParser


class GetSimilarity:
    """Look up documents similar to one question and optionally rerank them.

    The similarity search runs once, in ``__init__``; its results are kept on
    the instance and exposed through the ``get_*`` accessors.
    """

    def __init__(self, _question, _faiss_db: "VectorStore_FAISS"):
        """Run the similarity search for ``_question``.

        Args:
            _question: Query handed unchanged to
                ``_faiss_db.get_text_similarity``.
            _faiss_db: Vector store; must provide ``get_text_similarity`` and
                ``join_document``.
        """
        self.question = _question
        self.faiss_db = _faiss_db
        self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Remains an empty string until get_rerank() has been called; kept as
        # "" (not a list) so existing callers that compare against it still work.
        self.rerank_docs = ""

    def get_rerank(self, reranker: "BgeRerank", top_k=5):
        """Rerank the similarity hits and return the joined text of the best ones.

        Args:
            reranker: Object whose ``compress_documents(docs, query)`` returns
                the documents in reranked order.
            top_k: Maximum number of reranked documents to keep.

        Returns:
            Whatever ``faiss_db.join_document`` produces for the kept
            documents. The kept documents themselves are stored in
            ``self.rerank_docs``.
        """
        reranked = reranker.compress_documents(self.similarity_docs, self.question)
        top_docs = reranked[:top_k]
        self.rerank_docs = top_docs
        # Hand join_document its own list so it never aliases self.rerank_docs.
        return self.faiss_db.join_document(list(top_docs))

    def get_similarity_doc(self):
        """Return the joined text of the similarity hits."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Return the similarity hits as returned by the vector store."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Return the documents kept by the last ``get_rerank`` call ("" if none)."""
        return self.rerank_docs

class GetSimilarityWithExt:
    """Look up documents similar to several questions and optionally rerank them.

    Each question is searched separately in ``__init__``; the hits are merged
    and de-duplicated by ``page_content``, then kept on the instance.
    """

    def __init__(self, _question, _faiss_db: "VectorStore_FAISS"):
        """Run the similarity search for every entry of ``_question``.

        Args:
            _question: Iterable of query strings; it is iterated for the
                search and newline-joined for reranking.
            _faiss_db: Vector store; must provide ``get_text_similarity`` and
                ``join_document``.
        """
        self.question = _question
        self.faiss_db = _faiss_db
        self.similarity_docs = self.get_text_similarity_with_ext()
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Remains an empty string until get_rerank() has been called.
        self.rerank_docs = ""

    def get_rerank(self, reranker: "BgeRerank", top_k=5):
        """Rerank the merged hits against all questions joined by newlines.

        Args:
            reranker: Object whose ``compress_documents(docs, query)`` returns
                the documents in reranked order.
            top_k: Maximum number of reranked documents to keep.

        Returns:
            Whatever ``faiss_db.join_document`` produces for the kept
            documents. The kept documents themselves are stored in
            ``self.rerank_docs``.
        """
        question = '\n'.join(self.question)
        print(question)
        reranked = reranker.compress_documents(self.similarity_docs, question)
        top_docs = reranked[:top_k]
        self.rerank_docs = top_docs
        # Hand join_document its own list so it never aliases self.rerank_docs.
        return self.faiss_db.join_document(list(top_docs))

    def get_similarity_doc(self):
        """Return the joined text of the merged similarity hits."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Return the merged, de-duplicated similarity hits."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Return the documents kept by the last ``get_rerank`` call ("" if none)."""
        return self.rerank_docs

    def get_text_similarity_with_ext(self):
        """Search every question and merge the hits, dropping duplicate content.

        Returns:
            List of documents in first-seen order, with at most one document
            per distinct ``page_content``.
        """
        seen_contents = set()
        unique_documents = []
        for q in self.question:
            for doc in self.faiss_db.get_text_similarity(q):
                # Key on the content itself rather than hash(content): two
                # different texts with colliding hashes must both be kept.
                content = doc.page_content
                if content not in seen_contents:
                    seen_contents.add(content)
                    unique_documents.append(doc)
        return unique_documents
class QAExt:
    """Run the ``PROMPT_QUERY_EXTEND`` prompt through an LLM for a question."""

    llm = None

    def __init__(self, llm) -> None:
        """Build the ``prompt | llm`` chain used by ``extend_query``.

        Args:
            llm: Runnable the prompt template is piped into.
        """
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_QUERY_EXTEND)
        self.query_extend = prompt | llm

    def extend_query(self, question, messages=None):
        """Invoke the chain with the question and the formatted chat history.

        Args:
            question: The current user question (str).
            messages: Optional list of ``(question, answer)`` tuples, e.g.
                ``[("Q1", "A1"), ("Q2", "A2")]``. ``None`` or an empty list
                means no history.

        Returns:
            The result of invoking the chain with ``histories`` (the history
            rendered as ``"Q: ...\\nA: ...\\n"`` lines) and ``query``.
        """
        if not messages:
            messages = []
        history = "".join(f"Q: {msg[0]}\nA: {msg[1]}\n" for msg in messages)
        # Pass the rendered history text, not the raw tuple list, so the
        # prompt receives readable Q/A lines.
        return self.query_extend.invoke(input={"histories": history, "query": question})

class ChatExtend:
    """Run the ``PROMPT_QA_EXTEND_QUESTION`` prompt through an LLM."""

    def __init__(self, llm) -> None:
        """Build the ``prompt | llm`` chain used by ``new_questions``.

        Args:
            llm: Runnable the prompt template is piped into.
        """
        self.llm = llm
        template = PromptTemplate.from_template(PROMPT_QA_EXTEND_QUESTION)
        self.query_extend = template | llm

    def new_questions(self, messages):
        """Invoke the chain with the chat history rendered as text.

        Args:
            messages: Iterable of ``(question, answer)`` pairs.

        Returns:
            The result of invoking the chain with ``histories`` set to the
            pairs rendered as ``"Q: ...\\nA: ...\\n"`` lines.
        """
        rendered_turns = [f"Q: {turn[0]}\nA: {turn[1]}\n" for turn in messages]
        payload = {"histories": "".join(rendered_turns)}
        return self.query_extend.invoke(input=payload)