Commit b538b447 by tinywell

引入相似文档重排

parent 59ed5f14
......@@ -22,6 +22,12 @@ CHAT_DB_PORT = '5433'
EMBEEDING_MODEL_PATH = '/app/bge-large-zh-v1.5'
# =============================
# 重排序模型路径配置
# =============================
RERANK_MODEL_PATH = '/app/bge-reranker-large'
# RERANK_MODEL_PATH = 'BAAI/bge-reranker-large'
# =============================
# 模型服务URL配置
# =============================
LLM_SERVER_URL = '192.168.10.102:8002'
......
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from .rerank import BgeRerank
class GetSimilarity:
def __init__(self, _question, _faiss_db: VectorStore_FAISS):
......@@ -8,9 +8,17 @@ class GetSimilarity:
self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
def get_rerank(self,reranker:BgeRerank ,top_k = 5):
rerank_docs = reranker.bge_rerank(self.question,self.similarity_docs)
self.rerank_docs = rerank_docs
return self.faiss_db.join_document([d[1] for d in rerank_docs[:top_k]])
def get_similarity_doc(self):
return self.similarity_doc_txt
def get_similarity_docs(self):
return self.similarity_docs
def get_rerank_docs(self):
return self.rerank_docs
......@@ -19,6 +19,7 @@ from src.config.consts import (
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD,
EMBEEDING_MODEL_PATH,
RERANK_MODEL_PATH,
FAISS_STORE_PATH,
INDEX_NAME,
VEC_DB_HOST,
......@@ -28,6 +29,7 @@ from src.config.consts import (
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER
)
from .rerank import BgeRerank
sys.path.append("../..")
prompt1 = """'''
......@@ -43,7 +45,7 @@ BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]
class QA:
def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db):
def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db,rerank:bool=False):
self.prompt = _prompt
self.base_llm = _base_llm
self.llm_kwargs = _llm_kwargs
......@@ -58,6 +60,10 @@ class QA:
self.cur_question = ""
self.cur_similarity = ""
self.cur_oquestion = ""
self.rerank = rerank
if rerank:
self.rerank_model = BgeRerank(RERANK_MODEL_PATH)
# 检查是否包含敏感信息
def contains_blocked_keywords(self, text):
......@@ -78,8 +84,16 @@ class QA:
else:
# self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
if self.rerank:
self.cur_similarity = similarity.get_rerank(self.rerank_model)
else:
self.cur_similarity = similarity.get_similarity_doc()
similarity_docs = similarity.get_similarity_docs()
rerank_docs = similarity.get_rerank_docs()
print("============== similarity ==============")
print(similarity_docs)
print("============== rerank ==============")
print(rerank_docs)
self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
if not _question:
return ""
......
from __future__ import annotations
from typing import Dict, Optional, Sequence
from langchain_core.documents import Document
from langchain.pydantic_v1 import Extra, root_validator
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from sentence_transformers import CrossEncoder
class BgeRerank(BaseDocumentCompressor):
model_name:str = 'bge_reranker_large_model_path'
"""Model name to use for reranking."""
top_n: int = 10
"""Number of documents to return."""
model:CrossEncoder
"""CrossEncoder instance to use for reranking."""
def __init__(self, model_name: str, top_n: int = 10):
self.model_name = model_name
self.model = CrossEncoder(model_name)
self.top_n = top_n
def bge_rerank(self,query,docs):
model_inputs = [[query, doc] for doc in docs]
scores = self.model.predict(model_inputs)
results = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
return results[:self.top_n]
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def compress_documents(
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
) -> Sequence[Document]:
"""
Compress documents using BAAI/bge-reranker models.
Args:
documents: A sequence of documents to compress.
query: The query to use for compressing the documents.
callbacks: Callbacks to run during the compression process.
Returns:
A sequence of compressed documents.
"""
if len(documents) == 0: # to avoid empty api call
return []
doc_list = list(documents)
_docs = [d.page_content for d in doc_list]
results = self.bge_rerank(query, _docs)
final_results = []
for r in results:
doc = doc_list[r[0]]
doc.metadata["relevance_score"] = r[1]
final_results.append(doc)
return final_results
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment