Commit ec2eeec0 by tinywell

rerank 模型实例只加载一次

parent eacf477d
from __future__ import annotations from __future__ import annotations
import threading
from typing import Dict, Optional, Sequence from typing import Dict, Optional, Sequence
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain.pydantic_v1 import Extra, root_validator from langchain.pydantic_v1 import Extra, root_validator
from langchain.callbacks.manager import Callbacks from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
...@@ -13,12 +15,18 @@ class BgeRerank(BaseDocumentCompressor): ...@@ -13,12 +15,18 @@ class BgeRerank(BaseDocumentCompressor):
"""Model name to use for reranking.""" """Model name to use for reranking."""
top_n: int = 10 top_n: int = 10
"""Number of documents to return.""" """Number of documents to return."""
model:CrossEncoder = None _model:CrossEncoder = None
"""CrossEncoder instance to use for reranking.""" """CrossEncoder instance to use for reranking."""
_lock = threading.Lock()
"""Lock to ensure thread safety."""
def __init__(self, model_name: str, top_n: int = 10): def __init__(self, model_name: str, top_n: int = 10):
super().__init__(model_name=model_name, top_n=top_n) super().__init__(model_name=model_name, top_n=top_n)
self.model = CrossEncoder(model_name) if not BgeRerank._model:
with BgeRerank._lock:
if not BgeRerank._model:
BgeRerank._model = CrossEncoder(model_name)
self.model = BgeRerank._model
def bge_rerank(self,query,docs): def bge_rerank(self,query,docs):
model_inputs = [[query, doc] for doc in docs] model_inputs = [[query, doc] for doc in docs]
...@@ -30,7 +38,7 @@ class BgeRerank(BaseDocumentCompressor): ...@@ -30,7 +38,7 @@ class BgeRerank(BaseDocumentCompressor):
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = Extra.forbid extra = Extra.allow
arbitrary_types_allowed = True arbitrary_types_allowed = True
def compress_documents( def compress_documents(
...@@ -60,4 +68,17 @@ class BgeRerank(BaseDocumentCompressor): ...@@ -60,4 +68,17 @@ class BgeRerank(BaseDocumentCompressor):
doc = doc_list[r[0]] doc = doc_list[r[0]]
doc.metadata["relevance_score"] = r[1] doc.metadata["relevance_score"] = r[1]
final_results.append(doc) final_results.append(doc)
return final_results return final_results
\ No newline at end of file
from abc import ABC
import numpy as np
def sigmoid(x):
return 1 / (1 + np.exp(-x))
class Base(ABC):
def __init__(self, key, model_name):
pass
def similarity(self, query: str, texts: list):
raise NotImplementedError("Please implement encode method!")
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment