import os
import sys
import re
sys.path.append("../")
from os import path

import copy
from typing import List, Any, Optional, Tuple, Dict
from collections import OrderedDict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore, PgSqlDocstore
from langchain.vectorstores.faiss import FAISS, dependable_faiss_import
from langchain.schema import Document
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
import math
import faiss
from langchain.vectorstores.utils import DistanceStrategy
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from src.loader import load
from langchain.embeddings.base import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback
import operator
import numpy as np


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class EmbeddingFactory:
    def __init__(self, _path: str):
        self.path = _path
        self.embedding = HuggingFaceEmbeddings(model_name=_path)

    def get_embedding(self):
        return self.embedding


def get_embding(_path: str) -> Embeddings:
    return EmbeddingFactory(_path).get_embedding()
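
# Usage sketch (the model path is a placeholder): thanks to @singleton,
# repeated calls share one loaded model. Note the cache is keyed by class,
# so a different _path on a later call is ignored and the first model is
# returned.
#   emb_a = get_embding("/models/my-embedding-model")
#   emb_b = get_embding("/models/my-embedding-model")
#   assert emb_a is emb_b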


class RE_FAISS(FAISS):
    # Deduplicate hits by page_content while keeping each document's metadata;
    # the first (best-ranked) score for a duplicated page_content is retained.
    @staticmethod
    def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        deduplicated_dict = OrderedDict()
        for doc, scores in tuple_input:
            page_content = doc.page_content
            metadata = doc.metadata
            if page_content not in deduplicated_dict:
                deduplicated_dict[page_content] = (metadata, scores)
        deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in
                                  deduplicated_dict.items()]
        return deduplicated_documents

    def similarity_search_with_score_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            filter: Optional[Dict[str, Any]] = None,
            fetch_k: int = 20,
            **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        faiss = dependable_faiss_import()
        vector = np.array([embedding], dtype=np.float32)
        if self._normalize_L2:
            faiss.normalize_L2(vector)
        scores, indices = self.index.search(vector, k if filter is None else fetch_k)
        docs = []
        for j, i in enumerate(indices[0]):
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            _id = self.index_to_docstore_id[i]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            if filter is not None:
                filter = {
                    key: [value] if not isinstance(value, list) else value
                    for key, value in filter.items()
                }
                if all(doc.metadata.get(key) in value for key, value in filter.items()):
                    docs.append((doc, scores[0][j]))
            else:
                docs.append((doc, scores[0][j]))
        docs = self._tuple_deduplication(docs)
        score_threshold = kwargs.get("score_threshold")
        if score_threshold is not None:
            cmp = (
                operator.ge
                if self.distance_strategy
                   in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
                else operator.le
            )
            docs = [
                (doc, similarity)
                for doc, similarity in docs
                if cmp(similarity, score_threshold)
            ]

        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
        return docs[:k]
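
    # Illustrative call (hypothetical vector): filter values may be scalars or
    # lists; with L2 distances, score_threshold keeps only docs whose distance
    # satisfies operator.le against the threshold.
    #   store.similarity_search_with_score_by_vector(
    #       query_vec, k=4, filter={"filename": "a.pdf"}, score_threshold=0.6)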

    def max_marginal_relevance_search_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            fetch_k: int = 20,
            lambda_mult: float = 0.5,
            filter: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
        )
        docs_and_scores = self._tuple_deduplication(docs_and_scores)
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
        return [doc for doc, _ in docs_and_scores]
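
    # MMR sketch (hypothetical vector): lambda_mult near 0 maximizes diversity
    # among results, near 1 maximizes similarity to the query.
    #   docs = store.max_marginal_relevance_search_by_vector(
    #       query_vec, k=4, fetch_k=20, lambda_mult=0.25)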


def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
             is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
    embeddings = get_embding(_path=embedding_model_name)
    docstore1 = None
    if is_pgsql:
        if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
            docstore1 = PgSqlDocstore(info, reset=reset)
    else:
        docstore1 = InMemorySecondaryDocstore()
    if docstore1 is None:
        # fail fast instead of handing FAISS a None docstore later
        raise ValueError("is_pgsql=True requires info containing host, dbname, username and password")
    if not path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
    if not path.exists(path.join(store_path, index_name + ".faiss")) or reset:
        print("create new faiss")
        index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))  # dimension inferred from the embedding model
        return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore1,
                        index_to_docstore_id={})
    else:
        print("load_local faiss")
        _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings)
        if docstore1 and is_pgsql:  # connection parameters changed externally: swap in the fresh docstore
            _faiss.docstore = docstore1
        return _faiss
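
# Construction sketch (connection values are placeholders):
#   info = {"host": "127.0.0.1", "dbname": "knowledge",
#           "username": "postgres", "password": "secret"}
#   store = getFAISS("/models/my-embedding-model", "./faiss_store",
#                    info=info, is_pgsql=True, reset=False)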


class VectorStore_FAISS(FAISS):
    def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None,
                 is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False,
                 doc_callback: DocumentCallback = DefaultDocumentCallback()):
        self.info = info
        self.embedding_model_name = embedding_model_name
        self.store_path = path.join(store_path, index_name)
        if not path.exists(self.store_path):
            os.makedirs(self.store_path, exist_ok=True)
        self.index_name = index_name
        self.show_number = show_number
        self.search_number = self.show_number * 3
        self.threshold = threshold
        self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
                               is_pgsql=is_pgsql, reset=reset)
        self.doc_callback = doc_callback
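
    # Instantiation sketch (paths are placeholders):
    #   vs = VectorStore_FAISS("/models/my-embedding-model", "./store",
    #                          index_name="index", is_pgsql=False,
    #                          show_number=5, threshold=0.8)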

    def get_text_similarity_with_score(self, text: str, **kwargs):
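        # Map the similarity threshold onto an L2 distance bound. For
        # L2-normalized vectors the exact relation is d = sqrt(2 - 2s); the
        # linear form sqrt(2) * (1 - s) below approximates that bound.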
        score_threshold = (1 - self.threshold) * math.sqrt(2)
        docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
                                                        score_threshold=score_threshold, doc_callback=self.doc_callback,
                                                        **kwargs)
        return [doc for doc, similarity in docs][:self.show_number]

    def get_text_similarity(self, text: str, **kwargs):
        docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback, **kwargs)
        return docs[:self.show_number]

    @staticmethod
    def _join_document(docs: List[Document]) -> str:
        return "".join([doc.page_content for doc in docs])

    @staticmethod
    def get_local_doc(docs: List[Document]):
        ans = []
        for doc in docs:
            ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
                        "filename": doc.metadata["filename"]})
        return ans

    # Persist the index to local disk
    def _save_local(self):
        self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)

    # Add documents.
    # Document {
    #   page_content: paragraph text
    #   metadata {
    #     page: page number
    #   }
    # }
    def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
        list_of_documents: List[Document] = []
        if self.doc_callback:
            new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
        if need_split:
            for doc in new_docs:
                words_list = re.split(pattern, doc.page_content)
                # drop duplicate sentences (note: set() does not preserve order)
                words_list = [str(words) for words in set(words_list)]
                for words in words_list:
                    if not words.strip() == '':
                        metadata = copy.deepcopy(doc.metadata)
                        metadata["paragraph"] = doc.page_content
                        list_of_documents.append(Document(page_content=words, metadata=metadata))
        else:
            list_of_documents = new_docs
        self._faiss.add_documents(list_of_documents)
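
    # Split sketch: with need_split=True a paged Document is cut into sentences
    # on '?', '。', ';' and newlines; each sentence is indexed separately with
    # the full paragraph kept in metadata["paragraph"]. Hypothetical call:
    #   doc = Document(page_content="第一句。第二句。", metadata={"page_number": 1})
    #   vs._add_documents([doc])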

    def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
        if load_kwargs is None:
            load_kwargs = {"mode": "paged"}
        if filepaths is None:
            filepaths = []
        self._add_documents(load.loads(filepaths, **load_kwargs))

    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """
        Return VectorStoreRetriever initialized from this VectorStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "similarity" (default), "mmr", or
                "similarity_score_threshold".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Amount of documents to return (Default: 4)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
                    filter: Filter by document metadata

        Returns:
            VectorStoreRetriever: Retriever class for VectorStore.

        Examples:

        .. code-block:: python

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

            # Use a filter to only retrieve documents from a specific paper
            docsearch.as_retriever(
                search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
            )
    """
        if not kwargs or "similarity_score_threshold" != kwargs["search_type"]:
            default_kwargs = {'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        elif "similarity_score_threshold" == kwargs["search_type"]:
            default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._faiss._get_retriever_tags())
        return VectorStoreRetriever_FAISS(vectorstore=self._faiss, **kwargs, tags=tags)
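
    # Retriever sketch (hypothetical usage): defaults come from show_number and
    # threshold, and doc_callback is injected into search_kwargs above.
    #   retriever = vs.as_retriever(search_type="similarity_score_threshold")
    #   docs = retriever.get_relevant_documents("query text")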


class VectorStoreRetriever_FAISS(VectorStoreRetriever):
    search_k = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "k" in self.search_kwargs:
            self.search_k = self.search_kwargs["k"]
            self.search_kwargs["k"] = self.search_k * 2

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]

    async def _aget_relevant_documents(
            self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = await super()._aget_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]