Commit 4f95b54c by 文靖昊

修改重排bug,新增历史对话功能

parent 582deb2e
# =============================
# Document-store (vector) database configuration
# =============================
# Previous host was '192.168.10.93'; that dead duplicate assignment was
# removed — the value below is the one actually in effect.
VEC_DB_HOST = '192.168.10.189'
VEC_DB_DBNAME = 'lae'
VEC_DB_USER = 'postgres'
VEC_DB_PASSWORD = '111111'
......@@ -10,7 +10,7 @@ VEC_DB_PORT = '5433'
# =============================
# Chat database configuration
# =============================
# Previous host was '192.168.10.93'; that dead duplicate assignment was
# removed — the value below is the one actually in effect.
CHAT_DB_HOST = '192.168.10.189'
CHAT_DB_DBNAME = 'lae'
CHAT_DB_USER = 'postgres'
CHAT_DB_PASSWORD = '111111'
......@@ -19,12 +19,12 @@ CHAT_DB_PORT = '5433'
# =============================
# Embedding model path configuration
# =============================
# NOTE(review): "EMBEEDING" is a typo of "EMBEDDING", but renaming the
# constant would break existing importers, so the name is kept.
# Container path '/app/bge-large-zh-v1.5' was superseded by the local
# Windows path below (dead duplicate assignment removed).
EMBEEDING_MODEL_PATH = 'D:\\work\\py\\LAE\\bge-large-zh-v1.5'
# =============================
# Rerank model path configuration
# =============================
# Container path '/app/bge-reranker-large' was superseded by the local
# Windows path below (dead duplicate assignment removed).
RERANK_MODEL_PATH = 'D:\\work\\py\\LAE\\bge-reranker-large'
# RERANK_MODEL_PATH = 'BAAI/bge-reranker-large'
# =============================
......@@ -35,19 +35,19 @@ LLM_SERVER_URL = '192.168.10.102:8002'
# =============================
# FAISS similarity-search configuration
# =============================
# Number of candidates returned by the similarity search.
# Was 5; raised to 10 (dead duplicate assignment removed).
SIMILARITY_SHOW_NUMBER = 10
SIMILARITY_THRESHOLD = 0.8
# =============================
# FAISS vector-store file path configuration
# =============================
# Container path '/app/faiss' was superseded by the local Windows path
# below (dead duplicate assignment removed).
FAISS_STORE_PATH = 'D:\\work\\py\\LAE\\faiss'
INDEX_NAME = 'know'
# =============================
# Knowledge source-material configuration
# =============================
# Container path '/app/lae_data' was superseded by the local test-doc
# path below (dead duplicate assignment removed).
KNOWLEDGE_PATH = 'D:\\work\\py\\LAE\\testdoc'
# =============================
# gradio服务相关配置
......@@ -64,6 +64,12 @@ prompt1 = """'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
# Prompt template for history-aware QA. Placeholders: `history` (prior
# Q/A turns, pre-formatted by the caller), `context` (retrieved reference
# material, which the model may ignore if irrelevant), and `question`.
prompt_enhancement_history_template = """{history}
上面是之前的对话,下面是可参考的内容,参考内容中如果和问题不符合,可以不用参考。
{context}
请结合上述内容回答以下问题,不要提无关内容:
{question}
"""
# =============================
# NLP_BERT模型路径配置
# =============================
......
......@@ -32,7 +32,7 @@ from src.config.consts import (
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER,
prompt1
prompt_enhancement_history_template
)
app = FastAPI()
app.add_middleware(
......@@ -62,8 +62,8 @@ base_llm = ChatOpenAI(
model_name='Qwen2-7B',
verbose=True
)
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss)
my_chat = QA(PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template), base_llm,
{"temperature": 0.9}, ['history','context', 'question'], _db=c_db, _faiss_db=vecstore_faiss,rerank=True)
@app.post('/api/login')
def login(phone_request: PhoneLoginRequest):
......@@ -156,10 +156,13 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
question = chat_request.question
crud = CRUD(_db=c_db)
history = []
if session_id =="":
if session_id !="":
history = crud.get_last_history(str(session_id))
# answer = my_chat.chat(question)
answer, docs = my_chat.chat(question, with_similarity=True)
prompt = ""
for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history(question,history=prompt, with_similarity=True)
docs_json = []
for d in docs:
j ={}
......@@ -198,7 +201,10 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
crud = CRUD(_db=c_db)
last_turn_id = crud.get_last_turn_num(str(session_id))
history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id)
answer, docs = my_chat.chat(question, with_similarity=True)
prompt = ""
for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history(question, history=prompt, with_similarity=True)
docs_json = []
for d in docs:
j = {}
......
......@@ -227,7 +227,6 @@ class VectorStore_FAISS(FAISS):
@staticmethod
def join_document(docs: List[Document]) -> str:
print(docs)
return "".join([doc.page_content for doc in docs])
@staticmethod
......
......@@ -35,7 +35,7 @@ class TxtVector:
query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
self.db.execute_args(query, [search])
answer = self.db.fetchall()
print(answer)
# print(answer)
return answer[0]
def create_table(self):
......
......@@ -7,11 +7,15 @@ class GetSimilarity:
self.faiss_db = _faiss_db
self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
self.rerank_docs = ""
def get_rerank(self, reranker: "BgeRerank", top_k=5):
    """Rerank the cached similarity hits and return the joined text of the
    best `top_k` documents.

    Side effect: caches the top-k reranked documents on `self.rerank_docs`.
    """
    # compress_documents scores/reorders the candidate docs against the query.
    rerank_docs = reranker.compress_documents(self.similarity_docs, self.question)
    # Removed an unreachable stale `return` (old revision joined `d[1]`
    # tuples) left over as diff residue; only the top-k slice is kept.
    top_docs = list(rerank_docs[:top_k])
    self.rerank_docs = top_docs
    return self.faiss_db.join_document(top_docs)
def get_similarity_doc(self):
return self.similarity_doc_txt
......
......@@ -37,8 +37,14 @@ prompt1 = """'''
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
# Canned safe response returned whenever a blocked keyword is detected in
# the question or in the model's answer. (Fixed typo: "提是供" -> "提供".)
SAFE_RESPONSE = "您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提供定义和解释及建议。如果您有任何问题,请随时向我提问。"
# Keywords that trigger the safe response instead of a real answer.
BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]
......@@ -90,10 +96,6 @@ class QA:
self.cur_similarity = similarity.get_similarity_doc()
similarity_docs = similarity.get_similarity_docs()
rerank_docs = similarity.get_rerank_docs()
print("============== similarity ==============")
print(similarity_docs)
print("============== rerank ==============")
print(rerank_docs)
self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
if not _question:
return ""
......@@ -103,9 +105,41 @@ class QA:
self.update_history()
if with_similarity:
return self.cur_answer, similarity_docs
if self.rerank:
return self.cur_answer, rerank_docs
else:
return self.cur_answer, similarity_docs
return self.cur_answer
def chat_with_history(self, _question, history, with_similarity=False):
    """Answer `_question` using retrieved context plus a pre-formatted
    `history` string of earlier Q/A turns.

    Returns the answer string, or (answer, docs) when `with_similarity`
    is True, or "" for an empty question. Side effects: updates the
    self.cur_* state and the persisted chat history.
    """
    self.cur_oquestion = _question
    # Guard the empty question up front; the original ran the (expensive)
    # similarity lookup first and then discarded the result.
    if not _question:
        return ""
    # Defaults so the with_similarity return paths below cannot raise
    # NameError when the blocked-keyword branch skips retrieval (bug fix).
    similarity_docs = []
    rerank_docs = []
    if self.contains_blocked_keywords(_question):
        # Canned safe reply for blocked topics; no retrieval or LLM call.
        self.cur_answer = SAFE_RESPONSE
    else:
        similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
        # Use the reranked context when reranking is enabled.
        if self.rerank:
            self.cur_similarity = similarity.get_rerank(self.rerank_model)
        else:
            self.cur_similarity = similarity.get_similarity_doc()
        similarity_docs = similarity.get_similarity_docs()
        rerank_docs = similarity.get_rerank_docs()
        self.cur_question = self.prompt.format(
            history=history, context=self.cur_similarity, question=self.cur_oquestion
        )
        self.cur_answer = self.llm.run(
            history=history, context=self.cur_similarity, question=self.cur_oquestion
        )
        # Also screen the model's own answer for blocked keywords.
        if self.contains_blocked_keywords(self.cur_answer):
            self.cur_answer = SAFE_RESPONSE
    self.update_history()
    if with_similarity:
        # Hand back the docs matching the context source actually used.
        if self.rerank:
            return self.cur_answer, rerank_docs
        return self.cur_answer, similarity_docs
    return self.cur_answer
# 异步输出,逐渐输出答案
async def async_chat(self, _question):
self.cur_oquestion = _question
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment