import os import sys import re from os import path import copy from typing import List, OrderedDict, Any, Optional, Tuple, Dict from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore from langchain.vectorstores.faiss import FAISS from langchain.schema import Document from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore from langchain.embeddings.huggingface import ( HuggingFaceEmbeddings, ) import math import faiss from langchain.vectorstores.utils import DistanceStrategy from langchain.vectorstores.base import VectorStoreRetriever from langchain.callbacks.manager import ( AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) from src.loader import load from langchain.embeddings.base import Embeddings from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback import operator from langchain.vectorstores.utils import DistanceStrategy import numpy as np sys.path.append("../") def singleton(cls): instances = {} def get_instance(*args, **kwargs): if cls not in instances: instances[cls] = cls(*args, **kwargs) return instances[cls] return get_instance @singleton class EmbeddingFactory: def __init__(self, _path: str): self.path = _path self.embedding = HuggingFaceEmbeddings(model_name=_path) def get_embedding(self): return self.embedding def get_embding(_path: str) -> Embeddings: # return HuggingFaceEmbeddings(model_name=path) return EmbeddingFactory(_path).get_embedding() class RE_FAISS(FAISS): # 去重,并保留metadate @staticmethod def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]: deduplicated_dict = OrderedDict() print("--------------oedereddict type--------------", type(deduplicated_dict)) for doc, scores in tuple_input: page_content = doc.page_content metadata = doc.metadata if page_content not in deduplicated_dict: deduplicated_dict[page_content] = (metadata, scores) print("--------------------------du--------------------------\n", deduplicated_dict) deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in deduplicated_dict.items()] return deduplicated_documents # def similarity_search_with_score_by_vector( # self, # embedding: List[float], # k: int = 4, # filter: Optional[Dict[str, Any]] = None, # fetch_k: int = 20, # **kwargs: Any, # ) -> List[Tuple[Document, float]]: # faiss = dependable_faiss_import() # vector = np.array([embedding], dtype=np.float32) # if self._normalize_L2: # faiss.normalize_L2(vector) # scores, indices = self.index.search(vector, k if filter is None else fetch_k) # docs = [] # for j, i in enumerate(indices[0]): # if i == -1: # # This happens when not enough docs are returned. # continue # _id = self.index_to_docstore_id[i] # doc = self.docstore.search(_id) # if not isinstance(doc, Document): # raise ValueError(f"Could not find document for id {_id}, got {doc}") # if filter is not None: # filter = { # key: [value] if not isinstance(value, list) else value # for key, value in filter.items() # } # if all(doc.metadata.get(key) in value for key, value in filter.items()): # docs.append((doc, scores[0][j])) # else: # docs.append((doc, scores[0][j])) # docs = self._tuple_deduplication(docs) # score_threshold = kwargs.get("score_threshold") # if score_threshold is not None: # cmp = ( # operator.ge # if self.distance_strategy # in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD) # else operator.le # ) # docs = [ # (doc, similarity) # for doc, similarity in docs # if cmp(similarity, score_threshold) # ] # # if "doc_callback" in kwargs: # if hasattr(kwargs["doc_callback"], 'after_search'): # docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k) # return docs[:k] def max_marginal_relevance_search_by_vector( self, embedding: List[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, filter: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. Args: embedding: Embedding to look up documents similar to. k: Number of Documents to return. Defaults to 4. fetch_k: Number of Documents to fetch before filtering to pass to MMR algorithm. lambda_mult: Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5. Returns: List of Documents selected by maximal marginal relevance. """ docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter ) docs_and_scores = self._tuple_deduplication(docs_and_scores) if "doc_callback" in kwargs: if hasattr(kwargs["doc_callback"], 'after_search'): docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k) return [doc for doc, _ in docs_and_scores] def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index", is_pgsql: bool = True, reset: bool = False) -> RE_FAISS: embeddings = get_embding(_path=embedding_model_name) docstore1: PgSqlDocstore = None if is_pgsql: if info and "host" in info and "dbname" in info and "username" in info and "password" in info: docstore1 = PgSqlDocstore(info, reset=reset) else: docstore1 = InMemorySecondaryDocstore() if not path.exists(store_path): os.makedirs(store_path, exist_ok=True) if store_path is None or len(store_path) <= 0 or not path.exists( path.join(store_path, index_name + ".faiss")) or reset: print("create new faiss") index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0])) # 根据embeddings向量维度设置 return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore1, index_to_docstore_id={}) else: print("load_local faiss") _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings, allow_dangerous_deserialization=True) if docstore1 and is_pgsql: # 如果外部参数调整,更新docstore _faiss.docstore = docstore1 return _faiss class VectorStore_FAISS(FAISS): def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None, is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False, doc_callback: DocumentCallback = DefaultDocumentCallback()): self.info = info self.embedding_model_name = embedding_model_name self.store_path = path.join(store_path, index_name) if not path.exists(self.store_path): os.makedirs(self.store_path, exist_ok=True) self.index_name = index_name self.show_number = show_number self.search_number = self.show_number * 3 self.threshold = threshold self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name, is_pgsql=is_pgsql, reset=reset) self.doc_callback = doc_callback def get_text_similarity_with_score(self, text: str, **kwargs): score_threshold = (1 - self.threshold) * math.sqrt(2) docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number, score_threshold=score_threshold, doc_callback=self.doc_callback, **kwargs) return [doc for doc, similarity in docs][:self.show_number] def get_text_similarity(self, text: str, **kwargs): docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback, **kwargs) return docs[:self.show_number] # #去重,并保留metadate # def _tuple_deduplication(self, tuple_input:List[Document]) -> List[Document]: # deduplicated_dict = OrderedDict() # for doc in tuple_input: # page_content = doc.page_content # metadata = doc.metadata # if page_content not in deduplicated_dict: # deduplicated_dict[page_content] = metadata # deduplicated_documents = [Document(page_content=key,metadata=value) for key, value in deduplicated_dict.items()] # return deduplicated_documents @staticmethod def join_document(docs: List[Document]) -> str: print(docs) return "".join([doc.page_content for doc in docs]) @staticmethod def get_local_doc(docs: List[Document]): ans = [] for doc in docs: ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"], "filename": doc.metadata["filename"]}) return ans # def _join_document_location(self, docs:List[Document]) -> str: # 持久化到本地 def _save_local(self): self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name) # 添加文档 # Document { # page_content 段落 # metadata { # page 页码 # } # } def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'): list_of_documents: List[Document] = [] if self.doc_callback: new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs) if need_split: for doc in new_docs: words_list = re.split(pattern, doc.page_content) # 去掉重复项 words_list = set(words_list) words_list = [str(words) for words in words_list] for words in words_list: if not words.strip() == '': metadata = copy.deepcopy(doc.metadata) metadata["paragraph"] = doc.page_content list_of_documents.append(Document(page_content=words, metadata=metadata)) else: list_of_documents = new_docs self._faiss.add_documents(list_of_documents) def _add_documents_from_dir(self, filepaths=None, load_kwargs=None): if load_kwargs is None: load_kwargs = {"mode": "paged"} if filepaths is None: filepaths = [] self._add_documents(load.loads(filepaths, **load_kwargs)) def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever: """ Return VectorStoreRetriever initialized from this VectorStore. Args: search_type (Optional[str]): Defines the type of search that the Retriever should perform. Can be "similarity" (default), "mmr", or "similarity_score_threshold". search_kwargs (Optional[Dict]): Keyword arguments to pass to the search function. Can include things like: k: Amount of documents to return (Default: 4) score_threshold: Minimum relevance threshold for similarity_score_threshold fetch_k: Amount of documents to pass to MMR algorithm (Default: 20) lambda_mult: Diversity of results returned by MMR; 1 for minimum diversity and 0 for maximum. (Default: 0.5) filter: Filter by document metadata Returns: VectorStoreRetriever: Retriever class for VectorStore. Examples: . code-block:: python # Retrieve more documents with higher diversity # Useful if your dataset has many similar documents docsearch.as_retriever( search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25} ) # Fetch more documents for the MMR algorithm to consider # But only return the top 5 docsearch.as_retriever( search_type="mmr", search_kwargs={'k': 5, 'fetch_k': 50} ) # Only retrieve documents that have a relevance score # Above a certain threshold docsearch.as_retriever( search_type="similarity_score_threshold", search_kwargs={'score_threshold': 0.8} ) # Only get the single most similar document from the dataset docsearch.as_retriever(search_kwargs={'k': 1}) # Use a filter to only retrieve documents from a specific paper docsearch.as_retriever( search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}} ) """ if not kwargs or "similarity_score_threshold" != kwargs["search_type"]: default_kwargs = {'k': self.show_number} if "search_kwargs" in kwargs: default_kwargs.update(kwargs["search_kwargs"]) kwargs["search_kwargs"] = default_kwargs elif "similarity_score_threshold" == kwargs["search_type"]: default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number} if "search_kwargs" in kwargs: default_kwargs.update(kwargs["search_kwargs"]) kwargs["search_kwargs"] = default_kwargs kwargs["search_kwargs"]["doc_callback"] = self.doc_callback tags = kwargs.pop("tags", None) or [] tags.extend(self._faiss._get_retriever_tags()) print(kwargs) return VectorStoreRetriever_FAISS(vectorstore=self._faiss, **kwargs, tags=tags) class VectorStoreRetriever_FAISS(VectorStoreRetriever): search_k = 5 def __init__(self, **kwargs): super().__init__(**kwargs) if "k" in self.search_kwargs: self.search_k = self.search_kwargs["k"] self.search_kwargs["k"] = self.search_k * 2 def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: docs = super()._get_relevant_documents(query=query, run_manager=run_manager) return docs[:self.search_k] async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: docs = super()._aget_relevant_documents(query=query, run_manager=run_manager) return docs[:self.search_k]