import os
import sys
import re
from os import path

import copy
from collections import OrderedDict
from typing import List, Any, Optional, Tuple, Dict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore
from langchain_community.embeddings import HuggingFaceEmbeddings
import math
import faiss
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from src.loader import load
from langchain_core.embeddings import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback
import operator
import numpy as np
sys.path.append("../")


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
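
# Usage sketch for @singleton (comment only; ModelCache is a hypothetical name):
#
#     @singleton
#     class ModelCache:
#         def __init__(self):
#             self.data = {}
#
#     a = ModelCache()
#     b = ModelCache()
#     assert a is b  # every call returns the one cached instance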


@singleton
class EmbeddingFactory:
    def __init__(self, _path: str):
        self.path = _path
        self.embedding = HuggingFaceEmbeddings(model_name=_path)

    def get_embedding(self):
        return self.embedding


def get_embding(_path: str) -> Embeddings:
    # return HuggingFaceEmbeddings(model_name=path)
    return EmbeddingFactory(_path).get_embedding()
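
# Note: EmbeddingFactory is a singleton, so the model loaded on the first call
# wins; later calls reuse that instance even if a different path is passed.
# Sketch (paths are hypothetical):
#
#     emb1 = get_embding("/models/m3e-base")
#     emb2 = get_embding("/models/other")  # still the first model instance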


class RE_FAISS(FAISS):
    # Deduplicate results while keeping each document's metadata
    @staticmethod
    def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        deduplicated_dict = OrderedDict()
        print("--------------oedereddict type--------------", type(deduplicated_dict))
        for doc, scores in tuple_input:
            page_content = doc.page_content
            metadata = doc.metadata
            if page_content not in deduplicated_dict:
                deduplicated_dict[page_content] = (metadata, scores)
        print("--------------------------du--------------------------\n", deduplicated_dict)
        deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in
                                  deduplicated_dict.items()]
        return deduplicated_documents
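
    # Behavior sketch for _tuple_deduplication (hypothetical values): given
    #     [(Document("a"), 0.1), (Document("a"), 0.3), (Document("b"), 0.2)]
    # only the first occurrence of each page_content is kept, yielding
    #     [(Document("a"), 0.1), (Document("b"), 0.2)]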

    # def similarity_search_with_score_by_vector(
    #         self,
    #         embedding: List[float],
    #         k: int = 4,
    #         filter: Optional[Dict[str, Any]] = None,
    #         fetch_k: int = 20,
    #         **kwargs: Any,
    # ) -> List[Tuple[Document, float]]:
    #     faiss = dependable_faiss_import()
    #     vector = np.array([embedding], dtype=np.float32)
    #     if self._normalize_L2:
    #         faiss.normalize_L2(vector)
    #     scores, indices = self.index.search(vector, k if filter is None else fetch_k)
    #     docs = []
    #     for j, i in enumerate(indices[0]):
    #         if i == -1:
    #             # This happens when not enough docs are returned.
    #             continue
    #         _id = self.index_to_docstore_id[i]
    #         doc = self.docstore.search(_id)
    #         if not isinstance(doc, Document):
    #             raise ValueError(f"Could not find document for id {_id}, got {doc}")
    #         if filter is not None:
    #             filter = {
    #                 key: [value] if not isinstance(value, list) else value
    #                 for key, value in filter.items()
    #             }
    #             if all(doc.metadata.get(key) in value for key, value in filter.items()):
    #                 docs.append((doc, scores[0][j]))
    #         else:
    #             docs.append((doc, scores[0][j]))
    #     docs = self._tuple_deduplication(docs)
    #     score_threshold = kwargs.get("score_threshold")
    #     if score_threshold is not None:
    #         cmp = (
    #             operator.ge
    #             if self.distance_strategy
    #                in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
    #             else operator.le
    #         )
    #         docs = [
    #             (doc, similarity)
    #             for doc, similarity in docs
    #             if cmp(similarity, score_threshold)
    #         ]
    #
    #     if "doc_callback" in kwargs:
    #         if hasattr(kwargs["doc_callback"], 'after_search'):
    #             docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
    #     return docs[:k]

    def max_marginal_relevance_search_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            fetch_k: int = 20,
            lambda_mult: float = 0.5,
            filter: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
        )
        docs_and_scores = self._tuple_deduplication(docs_and_scores)
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
        return [doc for doc, _ in docs_and_scores]
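
    # This override differs from the stock FAISS MMR search in two ways: hits are
    # deduplicated by page_content, and an optional doc_callback kwarg may rewrite
    # the hit list after search (e.g. mapping sentence-level hits back to the
    # paragraphs stored in metadata["paragraph"]).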


def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
             is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
    embeddings = get_embding(_path=embedding_model_name)
    docstore1 = None  # will hold a PgSqlDocstore or an InMemorySecondaryDocstore
    if is_pgsql:
        if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
            docstore1 = PgSqlDocstore(info, reset=reset)
    else:
        docstore1 = InMemorySecondaryDocstore()
    if not path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
    if store_path is None or len(store_path) <= 0 or not path.exists(
            path.join(store_path, index_name + ".faiss")) or reset:
        print("create new faiss")
        index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))  # sized to the embedding dimension
        return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore1,
                        index_to_docstore_id={})
    else:
        print("load_local faiss")
        _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings, allow_dangerous_deserialization=True)
        if docstore1 and is_pgsql:  # connection parameters changed externally, so swap in the new docstore
            _faiss.docstore = docstore1
        return _faiss
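
# Usage sketch for getFAISS (model name and connection values are hypothetical):
#
#     store = getFAISS(
#         embedding_model_name="moka-ai/m3e-base",
#         store_path="./faiss_store",
#         info={"host": "127.0.0.1", "dbname": "knowledge",
#               "username": "postgres", "password": "postgres"},
#         is_pgsql=True,
#     )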


class VectorStore_FAISS(FAISS):
    def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None,
                 is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False,
                 doc_callback: DocumentCallback = DefaultDocumentCallback()):
        self.info = info
        self.embedding_model_name = embedding_model_name
        self.store_path = path.join(store_path, index_name)
        if not path.exists(self.store_path):
            os.makedirs(self.store_path, exist_ok=True)
        self.index_name = index_name
        self.show_number = show_number
        self.search_number = self.show_number * 3
        self.threshold = threshold
        self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
                               is_pgsql=is_pgsql, reset=reset)
        self.doc_callback = doc_callback

    def get_text_similarity_with_score(self, text: str, **kwargs):
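        # Heuristic mapping from the cosine-style self.threshold onto an L2
        # distance cutoff: for L2-normalized vectors distance = sqrt(2 * (1 - cos)),
        # so (1 - threshold) * sqrt(2) acts as a slightly stricter linear stand-in.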
        score_threshold = (1 - self.threshold) * math.sqrt(2)
        docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
                                                        score_threshold=score_threshold, doc_callback=self.doc_callback,
                                                        **kwargs)
        return [doc for doc, similarity in docs][:self.show_number]

    def get_text_similarity(self, text: str, **kwargs):
        docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback, **kwargs)
        return docs[:self.show_number]

    # # Deduplicate while keeping metadata
    # def _tuple_deduplication(self, tuple_input:List[Document]) -> List[Document]:
    #     deduplicated_dict = OrderedDict()
    #     for doc in tuple_input:
    #         page_content = doc.page_content
    #         metadata = doc.metadata
    #         if page_content not in deduplicated_dict:
    #             deduplicated_dict[page_content] = metadata

    #     deduplicated_documents = [Document(page_content=key,metadata=value) for key, value in deduplicated_dict.items()]
    #     return deduplicated_documents

    @staticmethod
    def join_document(docs: List[Document]) -> str:
        return "".join([doc.page_content for doc in docs])

    @staticmethod
    def get_local_doc(docs: List[Document]):
        ans = []
        for doc in docs:
            ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
                        "filename": doc.metadata["filename"]})
        return ans

    # def _join_document_location(self, docs:List[Document]) -> str:

    # Persist the index to local disk
    def _save_local(self):
        self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)

    # Add documents.
    # Document {
    #     page_content: paragraph text
    #     metadata {
    #         page: page number
    #     }
    # }
    def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
        list_of_documents: List[Document] = []
        if self.doc_callback:
            new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
        if need_split:
            for doc in new_docs:
                words_list = re.split(pattern, doc.page_content)
                # drop duplicates produced by the split
                words_list = list(set(words_list))
                for words in words_list:
                    if not words.strip() == '':
                        metadata = copy.deepcopy(doc.metadata)
                        metadata["paragraph"] = doc.page_content
                        list_of_documents.append(Document(page_content=words, metadata=metadata))
        else:
            list_of_documents = new_docs
        self._faiss.add_documents(list_of_documents)
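
    # Splitting sketch (illustrative): with the default pattern r'[?。;\n]', a
    # paragraph "第一句。第二句;第三句" is indexed as three sentence-level
    # Documents, each keeping the full paragraph in metadata["paragraph"].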

    def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
        if load_kwargs is None:
            load_kwargs = {"mode": "paged"}
        if filepaths is None:
            filepaths = []
        self._add_documents(load.loads(filepaths, **load_kwargs))
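
    # Usage sketch (file paths are hypothetical):
    #     vs._add_documents_from_dir(["./docs/a.pdf", "./docs/b.docx"])
    # loads the files via src.loader.load in "paged" mode and indexes the pages.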

    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """
        Return VectorStoreRetriever initialized from this VectorStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "similarity" (default), "mmr", or
                "similarity_score_threshold".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Amount of documents to return (Default: 4)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
                    filter: Filter by document metadata

        Returns:
            VectorStoreRetriever: Retriever class for VectorStore.

        Examples:

        .. code-block:: python

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

            # Use a filter to only retrieve documents from a specific paper
            docsearch.as_retriever(
                search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
            )
    """
        if kwargs.get("search_type") != "similarity_score_threshold":
            default_kwargs = {'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        elif "similarity_score_threshold" == kwargs["search_type"]:
            default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._faiss._get_retriever_tags())
        return VectorStoreRetriever_FAISS(vectorstore=self._faiss, **kwargs, tags=tags)


class VectorStoreRetriever_FAISS(VectorStoreRetriever):
    search_k: int = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
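        # Over-fetch: request 2x the caller's k from the vector store, then trim
        # back to k after retrieval, so post-search deduplication or callbacks
        # are less likely to leave fewer than k results.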
        if "k" in self.search_kwargs:
            self.search_k = self.search_kwargs["k"]
            self.search_kwargs["k"] = self.search_k * 2

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]

    async def _aget_relevant_documents(
            self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = await super()._aget_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]
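

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumes a local sentence-embedding model and the
    # src.* modules are importable; the model name and texts are hypothetical).
    vs = VectorStore_FAISS(
        embedding_model_name="moka-ai/m3e-base",
        store_path="./faiss_store",
        is_pgsql=False,
        reset=True,
    )
    vs._add_documents([
        Document(page_content="第一句。第二句",
                 metadata={"filename": "demo.txt", "page_number": 1}),
    ])
    vs._save_local()
    print(vs.join_document(vs.get_text_similarity("第一句")))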