2.13 KB
Newer Older
tinywell committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
from __future__ import annotations
from typing import Dict, Optional, Sequence
from langchain_core.documents import Document
from langchain.pydantic_v1 import Extra, root_validator

from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor

from sentence_transformers import CrossEncoder

class BgeRerank(BaseDocumentCompressor):
    """Document compressor that reranks documents with a CrossEncoder.

    Each document's text is paired with the query and scored by a
    sentence-transformers ``CrossEncoder``; the ``top_n`` highest-scoring
    documents are returned with the score stored in their metadata.
    """

    model_name: str = 'bge_reranker_large_model_path'
    """Model name to use for reranking."""
    top_n: int = 10
    """Number of documents to return."""
    model: Optional[CrossEncoder] = None
    """CrossEncoder instance to use for reranking."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def __init__(self, model_name: str, top_n: int = 10):
        """Validate the fields, then load the CrossEncoder for ``model_name``.

        Args:
            model_name: Name or path handed to ``CrossEncoder``.
            top_n: Maximum number of documents returned per query.
        """
        super().__init__(model_name=model_name, top_n=top_n)
        # The model is attached after pydantic initialisation because it is
        # built from the model name rather than passed in by the caller.
        self.model = CrossEncoder(model_name)

    def bge_rerank(self, query: str, docs: Sequence[str]) -> list:
        """Score every document text against ``query`` and keep the best.

        Args:
            query: The query text.
            docs: The document texts to score.

        Returns:
            Up to ``top_n`` ``(index, score)`` pairs, where ``index`` is the
            position of the text in ``docs``, ordered by descending score.
        """
        model_inputs = [[query, doc] for doc in docs]
        scores = self.model.predict(model_inputs)
        # enumerate() keeps the original position next to each score so the
        # caller can map a result back to its source document after sorting.
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        return ranked[:self.top_n]

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using BAAI/bge-reranker models.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
                Accepted for interface compatibility; not used here.

        Returns:
            A sequence of compressed documents: at most ``top_n`` of the
            input documents, best match first, each with its score written
            to ``metadata["relevance_score"]``.
        """
        if not documents:  # nothing to score, so skip the model call
            return []
        doc_list = list(documents)
        texts = [d.page_content for d in doc_list]
        final_results = []
        for index, score in self.bge_rerank(query, texts):
            doc = doc_list[index]
            # The score is recorded on the input document itself (in place).
            doc.metadata["relevance_score"] = score
            final_results.append(doc)
        return final_results