Commit e85ee1bd by 文靖昊

Merge remote-tracking branch 'origin/geo' into geo

parents 1d7bb00f ec2eeec0
from __future__ import annotations
import threading
from typing import Dict, Optional, Sequence
from langchain_core.documents import Document
from langchain.pydantic_v1 import Extra, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
......@@ -13,12 +15,18 @@ class BgeRerank(BaseDocumentCompressor):
"""Model name to use for reranking."""
top_n: int = 10
"""Number of documents to return."""
model:CrossEncoder = None
_model:CrossEncoder = None
"""CrossEncoder instance to use for reranking."""
_lock = threading.Lock()
"""Lock to ensure thread safety."""
def __init__(self, model_name: str, top_n: int = 10):
super().__init__(model_name=model_name, top_n=top_n)
self.model = CrossEncoder(model_name)
if not BgeRerank._model:
with BgeRerank._lock:
if not BgeRerank._model:
BgeRerank._model = CrossEncoder(model_name)
self.model = BgeRerank._model
def bge_rerank(self,query,docs):
model_inputs = [[query, doc] for doc in docs]
......@@ -30,7 +38,7 @@ class BgeRerank(BaseDocumentCompressor):
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
extra = Extra.allow
arbitrary_types_allowed = True
def compress_documents(
......@@ -60,4 +68,17 @@ class BgeRerank(BaseDocumentCompressor):
doc = doc_list[r[0]]
doc.metadata["relevance_score"] = r[1]
final_results.append(doc)
return final_results
\ No newline at end of file
return final_results
from abc import ABC
import numpy as np
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x*, element-wise.

    The naive formula overflows in ``exp`` for large negative inputs, so the
    exponential is only ever taken of a non-positive number here.

    Args:
        x: A real scalar or array-like of real numbers.

    Returns:
        A NumPy scalar for scalar input, otherwise an ndarray with the same
        shape as *x*. Floating inputs keep their dtype; other real inputs
        are computed in float64.
    """
    values = np.asarray(x)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float64)

    # exp(-|x|) lies in (0, 1], so it can underflow to 0 but never overflow.
    decay = np.exp(-np.abs(values))

    # x >= 0: 1 / (1 + e^-x); x < 0: e^x / (1 + e^x) -- algebraically equal.
    result = np.where(values >= 0, 1 / (1 + decay), decay / (1 + decay))

    # Indexing with () turns a 0-d array back into a scalar and leaves
    # higher-rank arrays untouched.
    return result[()]
class Base(ABC):
    """Base class for models that expose a ``similarity(query, texts)`` method.

    Subclasses are expected to override :meth:`similarity`; the base
    implementation only raises ``NotImplementedError``.
    """

    def __init__(self, key, model_name):
        """Accept the common constructor arguments.

        The base class stores neither value.

        Args:
            key: Not used by the base class.
            model_name: Not used by the base class.
        """
        pass

    def similarity(self, query: str, texts: list):
        """Compute similarity between *query* and each entry of *texts*.

        Args:
            query: The query string.
            texts: The texts to compare against the query.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # The message names the method that actually has to be overridden.
        raise NotImplementedError("Please implement similarity method!")
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment