# get_similarity.py
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers.list import ListOutputParser

from src.config.prompts import PROMPT_QUERY_EXTEND, PROMPT_QA_EXTEND_QUESTION
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.rerank import BgeRerank, reciprocal_rank_fusion


class GetSimilarity:
    def __init__(self, _question, _faiss_db: VectorStore_FAISS):
        self.question = _question
        self.faiss_db = _faiss_db
        # Retrieve similar documents for the question and keep a joined text copy.
        self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Populated by get_rerank(); empty until a rerank has been run.
        self.rerank_docs = []

    def get_rerank(self, reranker: BgeRerank, top_k=5):
        # Rerank the retrieved documents against the question and keep the top_k.
        rerank_docs = reranker.compress_documents(self.similarity_docs, self.question)
        top_docs = list(rerank_docs[:top_k])
        self.rerank_docs = top_docs
        return self.faiss_db.join_document(top_docs)

    def get_similarity_doc(self):
        return self.similarity_doc_txt
    
    def get_similarity_docs(self):
        return self.similarity_docs
    
    def get_rerank_docs(self):
        return self.rerank_docs
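
# Example usage (illustrative sketch; the VectorStore_FAISS and BgeRerank
# constructor arguments are placeholders, not the real signatures):
#
#   faiss_db = VectorStore_FAISS(...)            # an already-populated FAISS store
#   reranker = BgeRerank(...)                    # BGE reranker wrapper
#   sim = GetSimilarity("What is vector retrieval?", faiss_db)
#   context = sim.get_rerank(reranker, top_k=5)  # reranked, joined context text
#   top_docs = sim.get_rerank_docs()             # the top_k Document objects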


class GetSimilarityWithExt:
    def __init__(self, _question, _faiss_db: VectorStore_FAISS):
        # _question is a list of query variants (e.g. the original question plus
        # rephrasings); documents are retrieved for every variant and de-duplicated.
        self.question = _question
        self.faiss_db = _faiss_db
        self.similarity_docs = self.get_text_similarity_with_ext()
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Populated by the rerank methods below.
        self.rerank_docs = []

    def get_rerank(self, reranker: BgeRerank, top_k=5):
        # Rerank against all query variants joined into a single query string.
        question = '\n'.join(self.question)
        print(question)
        rerank_docs = reranker.compress_documents(self.similarity_docs, question)
        top_docs = list(rerank_docs[:top_k])
        self.rerank_docs = top_docs
        return self.faiss_db.join_document(top_docs)

    def get_rerank_with_doc(self, reranker: BgeRerank, split_doc: list, top_k=5):
        # Rerank two candidate sets (the caller-supplied split_doc chunks and the
        # documents retrieved by this class), then merge the two rankings with
        # reciprocal rank fusion and return the joined text of the fused top_k.
        question = '\n'.join(self.question)
        print(question)
        rerank_docs1 = reranker.compress_documents(split_doc, question)
        rerank_docs2 = reranker.compress_documents(self.similarity_docs, question)
        # Key each document by the hash of its content so the fused ranking can be
        # mapped back to the original Document objects.
        rerank_docs1_hash = []
        rerank_docs2_hash = []
        m = {}
        for doc in rerank_docs1:
            m[hash(doc.page_content)] = doc
            rerank_docs1_hash.append(hash(doc.page_content))
        for doc in rerank_docs2:
            m[hash(doc.page_content)] = doc
            rerank_docs2_hash.append(hash(doc.page_content))
        # Bundle both ranked hash lists, each paired with the constant expected by
        # reciprocal_rank_fusion, and fuse them into a single ranking.
        result = [(60, rerank_docs1_hash), (55, rerank_docs2_hash)]
        print(len(rerank_docs1_hash))
        print(len(rerank_docs2_hash))
        rrf_doc = reciprocal_rank_fusion(result)
        print(rrf_doc)
        # Map the fused keys back to documents and keep the top_k.
        d_list = [m[key] for key in rrf_doc]
        self.rerank_docs = d_list[:top_k]
        return self.faiss_db.join_document(d_list[:top_k])
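
    # Background note (general definition of reciprocal rank fusion, not taken from
    # this repository's reciprocal_rank_fusion implementation): RRF scores each
    # document as
    #     score(d) = sum over ranked lists L of 1 / (k + rank_L(d)),
    # so a document that ranks well in several lists rises to the top even if no
    # single list ranks it first. The integers 60 and 55 above are assumed to be
    # the per-list constants fed into that formula.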

    def get_similarity_doc(self):
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        return self.similarity_docs

    def get_rerank_docs(self):
        return self.rerank_docs

    def get_text_similarity_with_ext(self):
        # Retrieve similar documents for every query variant, then de-duplicate
        # them by content hash while preserving retrieval order.
        similarity_docs = []
        for q in self.question:
            similarity_docs.extend(self.faiss_db.get_text_similarity(q))
        content_set = set()
        unique_documents = []
        for doc in similarity_docs:
            content = hash(doc.page_content)
            if content not in content_set:
                unique_documents.append(doc)
                content_set.add(content)
        return unique_documents
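
# Example usage (illustrative sketch; everything except GetSimilarityWithExt itself
# is assumed here: `faiss_db`, `reranker`, and `extra_chunks` stand for objects the
# caller already has):
#
#   questions = ["original question", "variant 1", "variant 2"]  # e.g. LLM rephrasings
#   sim_ext = GetSimilarityWithExt(questions, faiss_db)
#   context = sim_ext.get_rerank_with_doc(reranker, split_doc=extra_chunks, top_k=5)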

class QAExt:
    """Uses an LLM to expand a user query into related query variants."""

    def __init__(self, llm) -> None:
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_QUERY_EXTEND)
        # parser = ListOutputParser()
        self.query_extend = prompt | llm

    def extend_query(self, question, messages=None):
        """
            question: str
            messages: list of tuple (str, str)
                e.g.:
                [
                    ("Q1", "A1"),
                    ("Q2", "A2"),
                    ...
                ]
        """
        if not messages:
            messages = []
        # Flatten the chat history into the Q/A text block the prompt expects.
        history = ""
        for msg in messages:
            history += f"Q: {msg[0]}\nA: {msg[1]}\n"
        return self.query_extend.invoke(input={"histories": history, "query": question})

    def extend_query_with_str(self, question, messages):
        """
            question: str
            messages: chat history already formatted as a single string; it is passed
                to the prompt's "histories" field as-is.
        """
        return self.query_extend.invoke(input={"histories": messages, "query": question})

class ChatExtend:
    def __init__(self, llm) -> None:
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_QA_EXTEND_QUESTION)
        self.query_extend = prompt | llm

    def new_questions(self, messages):
        # Flatten the chat history and ask the LLM for follow-up questions.
        history = ""
        for msg in messages:
            history += f"Q: {msg[0]}\nA: {msg[1]}\n"
        return self.query_extend.invoke(input={"histories": history})
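
# Example usage (illustrative sketch; `llm` stands for any LangChain-compatible
# model object that works in the `prompt | llm` pipelines above):
#
#   ext = QAExt(llm)
#   variants = ext.extend_query("What is vector retrieval?",
#                               messages=[("Q1", "A1"), ("Q2", "A2")])
#
#   chat = ChatExtend(llm)
#   follow_ups = chat.new_questions([("Q1", "A1"), ("Q2", "A2")])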