Commit 7aad8ddc by tinywell

chore: Update chat function to include similarity with docs

parent 9a06d464
...@@ -161,7 +161,9 @@ def question(): ...@@ -161,7 +161,9 @@ def question():
if session_id != "": if session_id != "":
history = crud.get_last_history(str(session_id)) history = crud.get_last_history(str(session_id))
print(history) print(history)
answer = my_chat.chat(question) # answer = my_chat.chat(question)
answer, docs = my_chat.chat(question, with_similarity=True)
print(docs)
# answer = "test Answer" # answer = "test Answer"
if session_id == "": if session_id == "":
session_id = crud.create_chat(token, '\t\t', '0') session_id = crud.create_chat(token, '\t\t', '0')
......
...@@ -5,8 +5,12 @@ class GetSimilarity: ...@@ -5,8 +5,12 @@ class GetSimilarity:
def __init__(self, _question, _faiss_db: VectorStore_FAISS): def __init__(self, _question, _faiss_db: VectorStore_FAISS):
self.question = _question self.question = _question
self.faiss_db = _faiss_db self.faiss_db = _faiss_db
self.similarity_doc = self.faiss_db.join_document(self.faiss_db.get_text_similarity(self.question)) self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
def get_similarity_doc(self): def get_similarity_doc(self):
return self.similarity_doc return self.similarity_doc_txt
def get_similarity_docs(self):
return self.similarity_docs
...@@ -67,13 +67,19 @@ class QA: ...@@ -67,13 +67,19 @@ class QA:
def get_similarity(self, _aquestion): def get_similarity(self, _aquestion):
return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc() return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()
def get_similarity_origin(self, _aquestion):
return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db)
# 一次性直接给出所有的答案 # 一次性直接给出所有的答案
def chat(self, _question): def chat(self, _question,with_similarity=False):
self.cur_oquestion = _question self.cur_oquestion = _question
if self.contains_blocked_keywords(_question): if self.contains_blocked_keywords(_question):
self.cur_answer = SAFE_RESPONSE self.cur_answer = SAFE_RESPONSE
else: else:
self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion) # self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
self.cur_similarity = similarity.get_similarity_doc()
similarity_docs = similarity.get_similarity_docs()
self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion) self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
if not _question: if not _question:
return "" return ""
...@@ -82,6 +88,8 @@ class QA: ...@@ -82,6 +88,8 @@ class QA:
self.cur_answer = SAFE_RESPONSE self.cur_answer = SAFE_RESPONSE
self.update_history() self.update_history()
if with_similarity:
return self.cur_answer, similarity_docs
return self.cur_answer return self.cur_answer
# 异步输出,逐渐输出答案 # 异步输出,逐渐输出答案
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment