Commit 4f95b54c by 文靖昊

修改重排bug,新增历史对话功能

parent 582deb2e
# ============================= # =============================
# 资料存储数据库配置 # 资料存储数据库配置
# ============================= # =============================
VEC_DB_HOST = '192.168.10.93' VEC_DB_HOST = '192.168.10.189'
VEC_DB_DBNAME = 'lae' VEC_DB_DBNAME = 'lae'
VEC_DB_USER = 'postgres' VEC_DB_USER = 'postgres'
VEC_DB_PASSWORD = '111111' VEC_DB_PASSWORD = '111111'
...@@ -10,7 +10,7 @@ VEC_DB_PORT = '5433' ...@@ -10,7 +10,7 @@ VEC_DB_PORT = '5433'
# ============================= # =============================
# 聊天相关数据库配置 # 聊天相关数据库配置
# ============================= # =============================
CHAT_DB_HOST = '192.168.10.93' CHAT_DB_HOST = '192.168.10.189'
CHAT_DB_DBNAME = 'lae' CHAT_DB_DBNAME = 'lae'
CHAT_DB_USER = 'postgres' CHAT_DB_USER = 'postgres'
CHAT_DB_PASSWORD = '111111' CHAT_DB_PASSWORD = '111111'
...@@ -19,12 +19,12 @@ CHAT_DB_PORT = '5433' ...@@ -19,12 +19,12 @@ CHAT_DB_PORT = '5433'
# ============================= # =============================
# 向量化模型路径配置 # 向量化模型路径配置
# ============================= # =============================
EMBEEDING_MODEL_PATH = '/app/bge-large-zh-v1.5' EMBEEDING_MODEL_PATH = 'D:\\work\\py\\LAE\\bge-large-zh-v1.5'
# ============================= # =============================
# 重排序模型路径配置 # 重排序模型路径配置
# ============================= # =============================
RERANK_MODEL_PATH = '/app/bge-reranker-large' RERANK_MODEL_PATH = 'D:\\work\\py\\LAE\\bge-reranker-large'
# RERANK_MODEL_PATH = 'BAAI/bge-reranker-large' # RERANK_MODEL_PATH = 'BAAI/bge-reranker-large'
# ============================= # =============================
...@@ -35,19 +35,19 @@ LLM_SERVER_URL = '192.168.10.102:8002' ...@@ -35,19 +35,19 @@ LLM_SERVER_URL = '192.168.10.102:8002'
# ============================= # =============================
# FAISS相似性查找配置 # FAISS相似性查找配置
# ============================= # =============================
SIMILARITY_SHOW_NUMBER = 5 SIMILARITY_SHOW_NUMBER = 10
SIMILARITY_THRESHOLD = 0.8 SIMILARITY_THRESHOLD = 0.8
# ============================= # =============================
# FAISS向量库文件存储路径配置 # FAISS向量库文件存储路径配置
# ============================= # =============================
FAISS_STORE_PATH = '/app/faiss' FAISS_STORE_PATH = 'D:\\work\\py\\LAE\\faiss'
INDEX_NAME = 'know' INDEX_NAME = 'know'
# ============================= # =============================
# 知识相关资料配置 # 知识相关资料配置
# ============================= # =============================
KNOWLEDGE_PATH = '/app/lae_data' KNOWLEDGE_PATH = 'D:\\work\\py\\LAE\\testdoc'
# ============================= # =============================
# gradio服务相关配置 # gradio服务相关配置
...@@ -64,6 +64,12 @@ prompt1 = """''' ...@@ -64,6 +64,12 @@ prompt1 = """'''
请你根据上述已知资料回答下面的问题,问题如下: 请你根据上述已知资料回答下面的问题,问题如下:
{question}""" {question}"""
prompt_enhancement_history_template = """{history}
上面是之前的对话,下面是可参考的内容,参考内容中如果和问题不符合,可以不用参考。
{context}
请结合上述内容回答以下问题,不要提无关内容:
{question}
"""
# ============================= # =============================
# NLP_BERT模型路径配置 # NLP_BERT模型路径配置
# ============================= # =============================
......
...@@ -32,7 +32,7 @@ from src.config.consts import ( ...@@ -32,7 +32,7 @@ from src.config.consts import (
VEC_DB_USER, VEC_DB_USER,
VEC_DB_DBNAME, VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER, SIMILARITY_SHOW_NUMBER,
prompt1 prompt_enhancement_history_template
) )
app = FastAPI() app = FastAPI()
app.add_middleware( app.add_middleware(
...@@ -62,8 +62,8 @@ base_llm = ChatOpenAI( ...@@ -62,8 +62,8 @@ base_llm = ChatOpenAI(
model_name='Qwen2-7B', model_name='Qwen2-7B',
verbose=True verbose=True
) )
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm, my_chat = QA(PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss) {"temperature": 0.9}, ['history','context', 'question'], _db=c_db, _faiss_db=vecstore_faiss,rerank=True)
@app.post('/api/login') @app.post('/api/login')
def login(phone_request: PhoneLoginRequest): def login(phone_request: PhoneLoginRequest):
...@@ -156,10 +156,13 @@ def question(chat_request: ChatRequest, token: str = Header(None)): ...@@ -156,10 +156,13 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
question = chat_request.question question = chat_request.question
crud = CRUD(_db=c_db) crud = CRUD(_db=c_db)
history = [] history = []
if session_id =="": if session_id !="":
history = crud.get_last_history(str(session_id)) history = crud.get_last_history(str(session_id))
# answer = my_chat.chat(question) # answer = my_chat.chat(question)
answer, docs = my_chat.chat(question, with_similarity=True) prompt = ""
for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history(question,history=prompt, with_similarity=True)
docs_json = [] docs_json = []
for d in docs: for d in docs:
j ={} j ={}
...@@ -198,7 +201,10 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)): ...@@ -198,7 +201,10 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
crud = CRUD(_db=c_db) crud = CRUD(_db=c_db)
last_turn_id = crud.get_last_turn_num(str(session_id)) last_turn_id = crud.get_last_turn_num(str(session_id))
history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id) history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id)
answer, docs = my_chat.chat(question, with_similarity=True) prompt = ""
for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history(question, history=prompt, with_similarity=True)
docs_json = [] docs_json = []
for d in docs: for d in docs:
j = {} j = {}
......
...@@ -227,7 +227,6 @@ class VectorStore_FAISS(FAISS): ...@@ -227,7 +227,6 @@ class VectorStore_FAISS(FAISS):
@staticmethod @staticmethod
def join_document(docs: List[Document]) -> str: def join_document(docs: List[Document]) -> str:
print(docs)
return "".join([doc.page_content for doc in docs]) return "".join([doc.page_content for doc in docs])
@staticmethod @staticmethod
......
...@@ -35,7 +35,7 @@ class TxtVector: ...@@ -35,7 +35,7 @@ class TxtVector:
query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s" query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
self.db.execute_args(query, [search]) self.db.execute_args(query, [search])
answer = self.db.fetchall() answer = self.db.fetchall()
print(answer) # print(answer)
return answer[0] return answer[0]
def create_table(self): def create_table(self):
......
...@@ -7,11 +7,15 @@ class GetSimilarity: ...@@ -7,11 +7,15 @@ class GetSimilarity:
self.faiss_db = _faiss_db self.faiss_db = _faiss_db
self.similarity_docs = self.faiss_db.get_text_similarity(self.question) self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs) self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
self.rerank_docs = ""
def get_rerank(self,reranker:BgeRerank ,top_k = 5): def get_rerank(self,reranker:BgeRerank ,top_k = 5):
rerank_docs = reranker.compress_documents(self.similarity_docs,self.question) rerank_docs = reranker.compress_documents(self.similarity_docs,self.question)
self.rerank_docs = rerank_docs d_list = []
return self.faiss_db.join_document([d[1] for d in rerank_docs[:top_k]]) for d in rerank_docs[:top_k]:
d_list.append(d)
self.rerank_docs = rerank_docs[:top_k]
return self.faiss_db.join_document(d_list)
def get_similarity_doc(self): def get_similarity_doc(self):
return self.similarity_doc_txt return self.similarity_doc_txt
......
...@@ -37,8 +37,14 @@ prompt1 = """''' ...@@ -37,8 +37,14 @@ prompt1 = """'''
''' '''
请你根据上述已知资料回答下面的问题,问题如下: 请你根据上述已知资料回答下面的问题,问题如下:
{question}""" {question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1) PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
# 预设的安全响应 # 预设的安全响应
SAFE_RESPONSE = "您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提是供定义和解释及建议。如果您有任何问题,请随时向我提问。" SAFE_RESPONSE = "您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提是供定义和解释及建议。如果您有任何问题,请随时向我提问。"
BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"] BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]
...@@ -90,10 +96,6 @@ class QA: ...@@ -90,10 +96,6 @@ class QA:
self.cur_similarity = similarity.get_similarity_doc() self.cur_similarity = similarity.get_similarity_doc()
similarity_docs = similarity.get_similarity_docs() similarity_docs = similarity.get_similarity_docs()
rerank_docs = similarity.get_rerank_docs() rerank_docs = similarity.get_rerank_docs()
print("============== similarity ==============")
print(similarity_docs)
print("============== rerank ==============")
print(rerank_docs)
self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion) self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
if not _question: if not _question:
return "" return ""
...@@ -103,9 +105,41 @@ class QA: ...@@ -103,9 +105,41 @@ class QA:
self.update_history() self.update_history()
if with_similarity: if with_similarity:
return self.cur_answer, similarity_docs if self.rerank:
return self.cur_answer, rerank_docs
else:
return self.cur_answer, similarity_docs
return self.cur_answer
def chat_with_history(self, _question,history,with_similarity=False):
self.cur_oquestion = _question
if self.contains_blocked_keywords(_question):
self.cur_answer = SAFE_RESPONSE
else:
# self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
if self.rerank:
self.cur_similarity = similarity.get_rerank(self.rerank_model)
else:
self.cur_similarity = similarity.get_similarity_doc()
similarity_docs = similarity.get_similarity_docs()
rerank_docs = similarity.get_rerank_docs()
self.cur_question = self.prompt.format(history=history,context=self.cur_similarity, question=self.cur_oquestion)
if not _question:
return ""
self.cur_answer = self.llm.run(history=history,context=self.cur_similarity, question=self.cur_oquestion)
if self.contains_blocked_keywords(self.cur_answer):
self.cur_answer = SAFE_RESPONSE
self.update_history()
if with_similarity:
if self.rerank:
return self.cur_answer, rerank_docs
else:
return self.cur_answer, similarity_docs
return self.cur_answer return self.cur_answer
# 异步输出,逐渐输出答案 # 异步输出,逐渐输出答案
async def async_chat(self, _question): async def async_chat(self, _question):
self.cur_oquestion = _question self.cur_oquestion = _question
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment