import os
import sys
import re
from os import path

import copy
from typing import List, Any, Optional, Tuple, Dict
from collections import OrderedDict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore, PgSqlDocstore
from langchain.vectorstores.faiss import FAISS, dependable_faiss_import
from langchain.schema import Document
from langchain.embeddings.huggingface import (
    HuggingFaceEmbeddings,
)
import math
import faiss
from langchain.vectorstores.utils import DistanceStrategy
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from src.loader import load
from langchain.embeddings.base import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback
import operator
import numpy as np

sys.path.append("../")


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class EmbeddingFactory:
    def __init__(self, _path: str):
        self.path = _path
        self.embedding = HuggingFaceEmbeddings(model_name=_path)

    def get_embedding(self):
        return self.embedding


def get_embding(_path: str) -> Embeddings:
    # Reuse the cached embedding instance instead of re-loading the model.
    return EmbeddingFactory(_path).get_embedding()
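
# Usage sketch (hypothetical model path): because EmbeddingFactory is a
# @singleton, repeated calls reuse one loaded model. Note the cache is keyed
# by class, not by path, so the first path passed wins for the whole process.
#     emb_a = get_embding("BAAI/bge-base-zh")
#     emb_b = get_embding("BAAI/bge-base-zh")
#     assert emb_a is emb_b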


class RE_FAISS(FAISS):
    # Deduplicate hits by page_content while keeping each document's metadata.
    @staticmethod
    def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        deduplicated_dict = OrderedDict()
        for doc, scores in tuple_input:
            page_content = doc.page_content
            metadata = doc.metadata
            if page_content not in deduplicated_dict:
                deduplicated_dict[page_content] = (metadata, scores)
        deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in
                                  deduplicated_dict.items()]
        return deduplicated_documents

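    # Illustrative example: two hits with identical page_content, e.g.
    #     [(Document(page_content="今天晴", metadata=m1), 0.12),
    #      (Document(page_content="今天晴", metadata=m2), 0.30)]
    # collapse to the first occurrence and its score.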
    def similarity_search_with_score_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            filter: Optional[Dict[str, Any]] = None,
            fetch_k: int = 20,
            **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        faiss = dependable_faiss_import()
        vector = np.array([embedding], dtype=np.float32)
        if self._normalize_L2:
            faiss.normalize_L2(vector)
        scores, indices = self.index.search(vector, k if filter is None else fetch_k)
        docs = []
        for j, i in enumerate(indices[0]):
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            _id = self.index_to_docstore_id[i]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            if filter is not None:
                filter = {
                    key: [value] if not isinstance(value, list) else value
                    for key, value in filter.items()
                }
                if all(doc.metadata.get(key) in value for key, value in filter.items()):
                    docs.append((doc, scores[0][j]))
            else:
                docs.append((doc, scores[0][j]))
        docs = self._tuple_deduplication(docs)
        score_threshold = kwargs.get("score_threshold")
        if score_threshold is not None:
            cmp = (
                operator.ge
                if self.distance_strategy
                   in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
                else operator.le
            )
            docs = [
                (doc, similarity)
                for doc, similarity in docs
                if cmp(similarity, score_threshold)
            ]

        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
        return docs[:k]
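    # Usage sketch (illustrative values): with the default L2 strategy smaller
    # scores are better, so score_threshold acts as a distance ceiling.
    #     vec = store.embedding_function("查询文本")
    #     hits = store.similarity_search_with_score_by_vector(
    #         vec, k=5, filter={"filename": "report.docx"}, score_threshold=0.3)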

    def max_marginal_relevance_search_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            fetch_k: int = 20,
            lambda_mult: float = 0.5,
            filter: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> List[Document]:
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
        )
        docs_and_scores = self._tuple_deduplication(docs_and_scores)
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
        return [doc for doc, _ in docs_and_scores]


def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
             is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
    embeddings = get_embding(_path=embedding_model_name)
    docstore1 = None
    if is_pgsql:
        if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
            docstore1 = PgSqlDocstore(info, reset=reset)
    else:
        docstore1 = InMemorySecondaryDocstore()
    # Guard against a missing store_path before touching the filesystem.
    if store_path and not path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
    if not store_path or not path.exists(path.join(store_path, index_name + ".faiss")) or reset:
        print("create new faiss")
        # Size a flat L2 index to the embedding dimension.
        index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))
        return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore1,
                        index_to_docstore_id={})
    else:
        print("load_local faiss")
        _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings)
        if docstore1 and is_pgsql:  # external settings changed: swap in the configured docstore
            _faiss.docstore = docstore1
        return _faiss
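
# Usage sketch (hypothetical model name and credentials):
#     store = getFAISS("BAAI/bge-base-zh", "./faiss_store",
#                      info={"host": "127.0.0.1", "dbname": "knowledge",
#                            "username": "postgres", "password": "..."})
# Pass is_pgsql=False to use the in-memory docstore instead.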


class VectorStore_FAISS(FAISS):
    def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None,
                 is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False,
                 doc_callback: DocumentCallback = DefaultDocumentCallback()):
        self.info = info
        self.embedding_model_name = embedding_model_name
        self.store_path = path.join(store_path, index_name)
        if not path.exists(self.store_path):
            os.makedirs(self.store_path, exist_ok=True)
        self.index_name = index_name
        self.show_number = show_number
        # Over-fetch so enough hits survive deduplication and filtering.
        self.search_number = self.show_number * 3
        self.threshold = threshold
        self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
                               is_pgsql=is_pgsql, reset=reset)
        self.doc_callback = doc_callback

    def get_text_similarity_with_score(self, text: str, **kwargs):
        # Convert the [0, 1] similarity threshold into an L2 distance ceiling
        # (sqrt(2) is the distance between orthogonal unit vectors).
        score_threshold = (1 - self.threshold) * math.sqrt(2)
        docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
                                                        score_threshold=score_threshold, doc_callback=self.doc_callback,
                                                        **kwargs)
        return [doc for doc, similarity in docs][:self.show_number]

    def get_text_similarity(self, text: str, **kwargs):
        docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback, **kwargs)
        return docs[:self.show_number]

    @staticmethod
    def _join_document(docs: List[Document]) -> str:
231 232
        print(docs)
        return "".join([doc.page_content for doc in docs])

    @staticmethod
    def get_local_doc(docs: List[Document]):
        ans = []
        for doc in docs:
            ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
                        "filename": doc.metadata["filename"]})
        return ans

    # Persist the index to local disk.
    def _save_local(self):
        self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)

    # Add documents. Each Document carries:
    #     page_content: the paragraph text
    #     metadata:
    #         page: page number
    def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
        list_of_documents: List[Document] = []
        if self.doc_callback:
            new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
        if need_split:
            for doc in new_docs:
                words_list = re.split(pattern, doc.page_content)
                # Drop duplicate sentence fragments.
                words_list = set(words_list)
                words_list = [str(words) for words in words_list]
                for words in words_list:
                    if not words.strip() == '':
                        metadata = copy.deepcopy(doc.metadata)
                        # Keep the full paragraph so a hit can show its context.
                        metadata["paragraph"] = doc.page_content
                        list_of_documents.append(Document(page_content=words, metadata=metadata))
        else:
            list_of_documents = new_docs
        self._faiss.add_documents(list_of_documents)
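    # Illustrative example: with need_split=True, a paragraph "今天晴。风很大"
    # is stored as two sentence-level Documents ("今天晴" and "风很大"), each
    # keeping the whole paragraph in metadata["paragraph"].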

    def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
        if load_kwargs is None:
            load_kwargs = {"mode": "paged"}
        if filepaths is None:
            filepaths = []
        self._add_documents(load.loads(filepaths, **load_kwargs))

    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """
        Return VectorStoreRetriever initialized from this VectorStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "similarity" (default), "mmr", or
                "similarity_score_threshold".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Amount of documents to return (Default: 4)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
                    filter: Filter by document metadata

        Returns:
            VectorStoreRetriever: Retriever class for VectorStore.

        Examples:

        .. code-block:: python

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

            # Use a filter to only retrieve documents from a specific paper
            docsearch.as_retriever(
                search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
            )
    """
        if kwargs.get("search_type") != "similarity_score_threshold":
            default_kwargs = {'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        else:
            default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._faiss._get_retriever_tags())
        return VectorStoreRetriever_FAISS(vectorstore=self._faiss, **kwargs, tags=tags)
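
    # Usage sketch (hypothetical wrapper instance):
    #     retriever = store.as_retriever(search_type="similarity_score_threshold")
    #     docs = retriever.get_relevant_documents("查询文本")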


class VectorStoreRetriever_FAISS(VectorStoreRetriever):
    search_k = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "k" in self.search_kwargs:
            # Fetch twice the requested k so deduplicated results can still
            # fill the final top-k cut applied below.
            self.search_k = self.search_kwargs["k"]
            self.search_kwargs["k"] = self.search_k * 2

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]

    async def _aget_relevant_documents(
            self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        # The parent coroutine must be awaited before slicing the result.
        docs = await super()._aget_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]
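

# Minimal end-to-end sketch (hypothetical model name and paths):
#     store = VectorStore_FAISS("BAAI/bge-base-zh", "./faiss_store", is_pgsql=False)
#     store._add_documents([Document(page_content="今天天气晴。风很大。",
#                                    metadata={"page_number": 1, "filename": "demo.txt"})])
#     store._save_local()
#     print(store.get_text_similarity("天气"))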