Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
e85ee1bd
Commit
e85ee1bd
authored
Jul 01, 2024
by
文靖昊
Browse files
Options
Browse Files
Download
Plain Diff
Merge remote-tracking branch 'origin/geo' into geo
parents
1d7bb00f
ec2eeec0
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
4 deletions
+26
-4
rerank.py
src/server/rerank.py
+26
-4
No files found.
src/server/rerank.py
View file @
e85ee1bd
from
__future__
import
annotations
import
threading
from
typing
import
Dict
,
Optional
,
Sequence
from
langchain_core.documents
import
Document
from
langchain.pydantic_v1
import
Extra
,
root_validator
from
langchain.callbacks.manager
import
Callbacks
from
langchain.retrievers.document_compressors.base
import
BaseDocumentCompressor
...
...
@@ -13,12 +15,18 @@ class BgeRerank(BaseDocumentCompressor):
"""Model name to use for reranking."""
top_n
:
int
=
10
"""Number of documents to return."""
model
:
CrossEncoder
=
None
_
model
:
CrossEncoder
=
None
"""CrossEncoder instance to use for reranking."""
_lock
=
threading
.
Lock
()
"""Lock to ensure thread safety."""
def __init__(self, model_name: str, top_n: int = 10):
    """Initialise the reranker and attach the shared CrossEncoder.

    Args:
        model_name: Name passed to ``CrossEncoder`` when the shared model
            is first created.
        top_n: Number of documents to return.
    """
    super().__init__(model_name=model_name, top_n=top_n)
    # Do not build a per-instance CrossEncoder here: ``self.model`` is set
    # from the class-level cache below, so an extra instance would be
    # loaded only to be discarded.
    # NOTE(review): the cache is not keyed by model_name, so the first
    # model_name used wins for every later instance -- confirm intended.
    if BgeRerank._model is None:
        with BgeRerank._lock:
            # Re-check while holding the lock: another thread may have
            # created the model while this one was waiting.
            if BgeRerank._model is None:
                BgeRerank._model = CrossEncoder(model_name)
    self.model = BgeRerank._model
def
bge_rerank
(
self
,
query
,
docs
):
model_inputs
=
[[
query
,
doc
]
for
doc
in
docs
]
...
...
@@ -30,7 +38,7 @@ class BgeRerank(BaseDocumentCompressor):
class Config:
    """Configuration for this pydantic object."""

    # Single assignment: a preceding ``extra = Extra.forbid`` was a dead
    # store because it was overwritten by this value on the next statement.
    extra = Extra.allow
    arbitrary_types_allowed = True
def
compress_documents
(
...
...
@@ -61,3 +69,16 @@ class BgeRerank(BaseDocumentCompressor):
doc
.
metadata
[
"relevance_score"
]
=
r
[
1
]
final_results
.
append
(
doc
)
return
final_results
from
abc
import
ABC
import
numpy
as
np
def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of *x*.

    Computed in a numerically stable way: the exponential is only ever
    taken of a non-positive value, so large-magnitude inputs cannot
    overflow.

    Args:
        x: A scalar or array-like of real numbers.

    Returns:
        A NumPy scalar for scalar input, otherwise an array with the same
        shape as *x*.
    """
    values = np.asarray(x)
    # exp(-|x|) lies in (0, 1], so it never overflows.
    decay = np.exp(-np.abs(values))
    # x >= 0: 1 / (1 + exp(-x));  x < 0: exp(x) / (1 + exp(x)).
    result = np.where(values >= 0, 1 / (1 + decay), decay / (1 + decay))
    # Indexing with an empty tuple turns a 0-d array back into a scalar and
    # leaves n-d arrays untouched.
    return result[()]
class Base(ABC):
    """Base class for rerank/similarity backends.

    Subclasses are expected to override :meth:`similarity`.
    """

    def __init__(self, key, model_name):
        """Accept the constructor arguments; the base class stores nothing.

        Args:
            key: Unused by the base class.
            model_name: Unused by the base class.
        """
        pass

    def similarity(self, query: str, texts: list):
        """Score *texts* against *query*; must be overridden.

        Args:
            query: The query string.
            texts: The candidate texts to compare with the query.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        # The message names the method that actually has to be overridden.
        raise NotImplementedError("Please implement similarity method!")
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment