from __future__ import annotations

import threading
from abc import ABC
from typing import ClassVar, Dict, Optional, Sequence

import numpy as np
from langchain.callbacks.manager import Callbacks
from langchain.pydantic_v1 import Extra, root_validator
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder


class BgeRerank(BaseDocumentCompressor):
    """Rerank documents against a query with a BAAI/bge-reranker CrossEncoder.

    Loaded models are cached on the class, keyed by ``model_name``: every
    instance that requests the same model shares one loaded copy, and an
    instance requesting a different model gets the model it asked for.
    """

    model_name: str = 'bge_reranker_large_model_path'
    """Model name (or local path) to use for reranking."""
    top_n: int = 10
    """Number of documents to return."""

    # Shared cache of loaded CrossEncoder instances, keyed by model name.
    # Declared ClassVar so pydantic never treats it as a field or private attr.
    _models: ClassVar[Dict[str, CrossEncoder]] = {}
    # Serializes first-time loading of a model into ``_models``.
    _lock: ClassVar[threading.Lock] = threading.Lock()

    def __init__(self, model_name: str, top_n: int = 10):
        super().__init__(model_name=model_name, top_n=top_n)
        # ``model`` is not a declared field; Config.extra = Extra.allow is what
        # permits attaching it to the instance.
        self.model = self._get_model(model_name)

    @classmethod
    def _get_model(cls, model_name: str) -> CrossEncoder:
        """Return the cached CrossEncoder for ``model_name``, loading it on first use."""
        model = cls._models.get(model_name)
        if model is None:
            with cls._lock:
                # Re-check under the lock: another thread may have finished
                # loading this model while we were waiting for the lock.
                model = cls._models.get(model_name)
                if model is None:
                    model = CrossEncoder(model_name)
                    cls._models[model_name] = model
        return model

    def bge_rerank(self, query: str, docs: Sequence[str]) -> list:
        """Score every doc against ``query`` and keep the best ``top_n``.

        Args:
            query: The query text.
            docs: Document texts to score.

        Returns:
            A list of up to ``top_n`` ``(index_into_docs, score)`` tuples,
            ordered from highest to lowest score.
        """
        model_inputs = [[query, doc] for doc in docs]
        if not model_inputs:
            # Nothing to rank; don't call the model with an empty batch.
            return []
        scores = self.model.predict(model_inputs)
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        return ranked[:self.top_n]

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.allow
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using BAAI/bge-reranker models.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process
                (accepted for interface compatibility; not used here).

        Returns:
            A sequence of compressed documents: at most ``top_n`` of the input
            documents, best match first. The returned objects are the input
            documents themselves; each gets ``metadata["relevance_score"]``
            set in place.
        """
        if not documents:  # to avoid an empty model call
            return []
        doc_list = list(documents)
        results = self.bge_rerank(query, [d.page_content for d in doc_list])
        final_results = []
        for index, score in results:
            doc = doc_list[index]
            # Store a plain float: numpy scalars are not JSON-serializable.
            doc.metadata["relevance_score"] = float(score)
            final_results.append(doc)
        return final_results


def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))``, elementwise for arrays.

    Computed as ``exp(-logaddexp(0, -x))``, which is mathematically identical
    but never evaluates ``exp`` of a large positive number, so very negative
    inputs return ~0 without an overflow warning.
    """
    return np.exp(-np.logaddexp(0, -x))


class Base(ABC):
    """Base class for models that score a query against a list of texts.

    Subclasses must override :meth:`similarity`.
    """

    def __init__(self, key, model_name):
        # The base class stores nothing; the parameters exist so subclasses
        # share a common constructor signature.
        pass

    def similarity(self, query: str, texts: list):
        """Score ``query`` against each of ``texts``; must be overridden."""
        raise NotImplementedError("Please implement similarity method!")