Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
ec2eeec0
Commit
ec2eeec0
authored
Jul 01, 2024
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
rerank 模型实例只加载一次
parent
eacf477d
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
6 deletions
+27
-6
rerank.py
src/server/rerank.py
+27
-6
No files found.
src/server/rerank.py
View file @
ec2eeec0
from
__future__
import
annotations
from
__future__
import
annotations
import
threading
from
typing
import
Dict
,
Optional
,
Sequence
from
typing
import
Dict
,
Optional
,
Sequence
from
langchain_core.documents
import
Document
from
langchain_core.documents
import
Document
from
langchain.pydantic_v1
import
Extra
,
root_validator
from
langchain.pydantic_v1
import
Extra
,
root_validator
from
langchain.callbacks.manager
import
Callbacks
from
langchain.callbacks.manager
import
Callbacks
from
langchain.retrievers.document_compressors.base
import
BaseDocumentCompressor
from
langchain.retrievers.document_compressors.base
import
BaseDocumentCompressor
...
@@ -13,12 +15,18 @@ class BgeRerank(BaseDocumentCompressor):
...
@@ -13,12 +15,18 @@ class BgeRerank(BaseDocumentCompressor):
"""Model name to use for reranking."""
"""Model name to use for reranking."""
top_n
:
int
=
10
top_n
:
int
=
10
"""Number of documents to return."""
"""Number of documents to return."""
model
:
CrossEncoder
=
None
_
model
:
CrossEncoder
=
None
"""CrossEncoder instance to use for reranking."""
"""CrossEncoder instance to use for reranking."""
_lock
=
threading
.
Lock
()
"""Lock to ensure thread safety."""
def
__init__
(
self
,
model_name
:
str
,
top_n
:
int
=
10
):
def
__init__
(
self
,
model_name
:
str
,
top_n
:
int
=
10
):
super
()
.
__init__
(
model_name
=
model_name
,
top_n
=
top_n
)
super
()
.
__init__
(
model_name
=
model_name
,
top_n
=
top_n
)
self
.
model
=
CrossEncoder
(
model_name
)
if
not
BgeRerank
.
_model
:
with
BgeRerank
.
_lock
:
if
not
BgeRerank
.
_model
:
BgeRerank
.
_model
=
CrossEncoder
(
model_name
)
self
.
model
=
BgeRerank
.
_model
def
bge_rerank
(
self
,
query
,
docs
):
def
bge_rerank
(
self
,
query
,
docs
):
model_inputs
=
[[
query
,
doc
]
for
doc
in
docs
]
model_inputs
=
[[
query
,
doc
]
for
doc
in
docs
]
...
@@ -30,7 +38,7 @@ class BgeRerank(BaseDocumentCompressor):
...
@@ -30,7 +38,7 @@ class BgeRerank(BaseDocumentCompressor):
class
Config
:
class
Config
:
"""Configuration for this pydantic object."""
"""Configuration for this pydantic object."""
extra
=
Extra
.
forbid
extra
=
Extra
.
allow
arbitrary_types_allowed
=
True
arbitrary_types_allowed
=
True
def
compress_documents
(
def
compress_documents
(
...
@@ -60,4 +68,17 @@ class BgeRerank(BaseDocumentCompressor):
...
@@ -60,4 +68,17 @@ class BgeRerank(BaseDocumentCompressor):
doc
=
doc_list
[
r
[
0
]]
doc
=
doc_list
[
r
[
0
]]
doc
.
metadata
[
"relevance_score"
]
=
r
[
1
]
doc
.
metadata
[
"relevance_score"
]
=
r
[
1
]
final_results
.
append
(
doc
)
final_results
.
append
(
doc
)
return
final_results
return
final_results
\ No newline at end of file
from
abc
import
ABC
import
numpy
as
np
def
sigmoid
(
x
):
return
1
/
(
1
+
np
.
exp
(
-
x
))
class
Base
(
ABC
):
def
__init__
(
self
,
key
,
model_name
):
pass
def
similarity
(
self
,
query
:
str
,
texts
:
list
):
raise
NotImplementedError
(
"Please implement encode method!"
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment