import json
import time
from typing import List

from langchain_core.documents import Document
from langchain_core.output_parsers.list import ListOutputParser
from langchain_core.prompts import PromptTemplate

from src.config.prompts import PROMPT_QUERY_EXTEND, PROMPT_QA_EXTEND_QUESTION
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.rerank import BgeRerank, reciprocal_rank_fusion


class GetSimilarity:
    """Retrieve documents similar to a single question from a FAISS vector
    store, with optional cross-encoder reranking of the results."""

    def __init__(self, _question, _faiss_db: "VectorStore_FAISS"):
        """
        :param _question: query text used for the similarity search
        :param _faiss_db: vector store providing ``get_text_similarity`` and
                          ``join_document``
        """
        self.question = _question
        self.faiss_db = _faiss_db
        # Retrieval happens eagerly at construction time.
        self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Populated by get_rerank(); initialized as a list for type
        # consistency (it previously started as "" but was assigned a list).
        self.rerank_docs = []

    def get_rerank(self, reranker: "BgeRerank", top_k=5):
        """Rerank the retrieved documents and return the top_k joined as text.

        Side effect: stores the kept documents in ``self.rerank_docs``.

        :param reranker: reranker providing ``compress_documents(docs, query)``
        :param top_k: number of top-ranked documents to keep
        :return: the joined text of the kept documents
        """
        rerank_docs = reranker.compress_documents(self.similarity_docs, self.question)
        # One slice instead of copying element by element into a second list.
        top_docs = list(rerank_docs[:top_k])
        self.rerank_docs = top_docs
        return self.faiss_db.join_document(top_docs)

    def get_similarity_doc(self):
        """Return the joined text of the raw similarity-search results."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Return the raw similarity-search result documents."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Return the documents kept by the last ``get_rerank`` call."""
        return self.rerank_docs
class GetSimilarityWithExt:
    """Similarity retrieval for a *list* of questions (e.g. the original
    query plus LLM-extended variants), with optional location filtering,
    de-duplication, and reranking."""

    def __init__(self, _question, _faiss_db: "VectorStore_FAISS", _location=None):
        """
        :param _question: iterable of query strings
        :param _faiss_db: vector store providing ``get_text_similarity`` and
                          ``join_document``
        :param _location: optional list of location substrings; when given,
                          only documents whose content or source filename
                          contains one of them are kept
        """
        self.question = _question
        self.faiss_db = _faiss_db
        self.location = _location
        # Retrieval happens eagerly at construction time.
        self.similarity_docs = self.get_text_similarity_with_ext()
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Populated by get_rerank() / get_rerank_with_doc().
        self.rerank_docs = []

    def get_rerank(self, reranker: "BgeRerank", top_k=5):
        """Rerank the retrieved documents against the newline-joined
        questions and return the top_k joined as text.

        Side effect: stores the kept documents in ``self.rerank_docs``.
        """
        question = '\n'.join(self.question)
        print(question)
        rerank_docs = reranker.compress_documents(self.similarity_docs, question)
        top_docs = list(rerank_docs[:top_k])
        self.rerank_docs = top_docs
        return self.faiss_db.join_document(top_docs)

    def join_document(self, docs: "List[Document]") -> str:
        """Serialize documents to a JSON array string.

        Each entry carries page_content, from_file (source filename), hash
        and page_number (as a string, "0" when absent). Uses json.dumps so
        quotes/newlines inside the content are escaped correctly — the
        previous hand-rolled concatenation emitted invalid JSON for such
        content.
        """
        entries = []
        for doc in docs:
            entries.append({
                "page_content": doc.page_content,
                "from_file": doc.metadata["filename"],
                "hash": doc.metadata.get("hash", ""),
                "page_number": str(doc.metadata.get("page_number", 0)),
            })
        return json.dumps(entries, ensure_ascii=False)

    def get_rerank_with_doc(self, reranker: "BgeRerank", split_docs_list: list):
        """Rerank each per-group document list plus the globally retrieved
        documents, keep the top 3 of each, de-duplicate by content, and
        return the unique documents serialized via ``join_document``.

        Side effect: stores the unique documents in ``self.rerank_docs``.
        """
        question = '\n'.join(self.question)
        print(question)
        result = []
        for split_doc in split_docs_list:
            start = time.time()
            rerank_docs1 = reranker.compress_documents(split_doc, question)
            result.extend(rerank_docs1[:3])
            end = time.time()
            print('重排1 time: %s Seconds' % (end - start))
        start = time.time()
        rerank_docs2 = reranker.compress_documents(self.similarity_docs, question)
        result.extend(rerank_docs2[:3])
        end = time.time()
        print('重排2 time: %s Seconds' % (end - start))
        print(len(result))
        # De-duplicate by page content while preserving rank order.
        seen = set()
        unique_documents = []
        for doc in result:
            key = hash(doc.page_content)
            if key not in seen:
                unique_documents.append(doc)
                seen.add(key)
        print(len(unique_documents))
        self.rerank_docs = unique_documents
        return self.join_document(unique_documents)

    def get_similarity_doc(self):
        """Return the joined text of the raw similarity-search results."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Return the raw (filtered, de-duplicated) similarity results."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Return the documents kept by the last rerank call."""
        return self.rerank_docs

    def get_text_similarity_with_ext(self):
        """Run a similarity search for every question, optionally filter by
        location, and de-duplicate by page content (order preserved).

        Bug fix: previously the collect/de-dup step only ran inside the
        location branch, so the method always returned [] when no location
        was given.
        """
        similarity_docs = []
        for q in self.question:
            similarity_docs.extend(self.faiss_db.get_text_similarity(q))
        seen = set()
        unique_documents = []
        for doc in similarity_docs:
            if self.location is not None:
                # Keep only documents mentioning one of the locations in
                # their content or source filename.
                if not (any(s in doc.page_content for s in self.location)
                        or any(s in doc.metadata["filename"] for s in self.location)):
                    continue
            key = hash(doc.page_content)
            if key not in seen:
                unique_documents.append(doc)
                seen.add(key)
        print(len(unique_documents))
        return unique_documents

    def get_doc_nums(self, num: int) -> int:
        """Clamp num*3 into the inclusive range [5, 30]."""
        return max(5, min(num * 3, 30))

class QAExt:
    """Expand a user query into related queries via an LLM chain built from
    PROMPT_QUERY_EXTEND."""

    # Class-level default; shadowed by the instance attribute set in __init__.
    llm = None

    def __init__(self, llm) -> None:
        """
        :param llm: LangChain-compatible runnable piped after the prompt
        """
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_QUERY_EXTEND)
        self.query_extend = prompt | llm

    def extend_query(self, question, messages=None):
        """Invoke the query-extension chain.

        :param question: the current user question (str)
        :param messages: optional history as a list of (question, answer)
                         tuples, e.g. [("Q1", "A1"), ("Q2", "A2")]; passed
                         through to the prompt as-is
        :return: the chain's raw ``invoke`` result

        NOTE(review): the original also built a formatted "Q:/A:" history
        string here but never used it — the raw tuples are what the prompt
        receives. Behavior is preserved (dead code removed); confirm whether
        the formatted string was the actual intent.
        """
        if not messages:
            messages = []
        return self.query_extend.invoke(input={"histories": messages, "query": question})

    def extend_query_with_str(self, question, messages):
        """Same as :meth:`extend_query`, but ``messages`` is already a
        formatted history string passed straight through to the prompt.

        :param question: the current user question (str)
        :param messages: pre-formatted history string
        :return: the chain's raw ``invoke`` result
        """
        return self.query_extend.invoke(input={"histories": messages, "query": question})

class ChatExtend:
    """Generate follow-up questions from a chat history via an LLM chain
    built from PROMPT_QA_EXTEND_QUESTION."""

    def __init__(self, llm) -> None:
        self.llm = llm
        self.query_extend = PromptTemplate.from_template(PROMPT_QA_EXTEND_QUESTION) | llm

    def new_questions(self, messages):
        """Format (question, answer) pairs as a "Q:/A:" transcript and
        invoke the extension chain with it."""
        history = "".join(f"Q: {q}\nA: {a}\n" for q, a in messages)
        return self.query_extend.invoke(input={"histories": history})