Commit 1d7bb00f by 文靖昊

输入问题优化

parent eacf477d
...@@ -9,9 +9,11 @@ import uvicorn ...@@ -9,9 +9,11 @@ import uvicorn
import json import json
from src.pgdb.chat.crud import CRUD from src.pgdb.chat.crud import CRUD
from src.pgdb.knowledge.similarity import VectorStore_FAISS from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.get_similarity import QAExt
from src.server.qa import QA from src.server.qa import QA
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
import re
from src.controller.request import ( from src.controller.request import (
PhoneLoginRequest, PhoneLoginRequest,
ChatRequest, ChatRequest,
...@@ -62,6 +64,9 @@ base_llm = ChatOpenAI( ...@@ -62,6 +64,9 @@ base_llm = ChatOpenAI(
model_name='Qwen2-7B', model_name='Qwen2-7B',
verbose=True verbose=True
) )
ext = QAExt(base_llm)
my_chat = QA(PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template), base_llm, my_chat = QA(PromptTemplate(input_variables=["history","context", "question"], template=prompt_enhancement_history_template), base_llm,
{"temperature": 0.9}, ['history','context', 'question'], _db=c_db, _faiss_db=vecstore_faiss,rerank=True) {"temperature": 0.9}, ['history','context', 'question'], _db=c_db, _faiss_db=vecstore_faiss,rerank=True)
...@@ -159,10 +164,13 @@ def question(chat_request: ChatRequest, token: str = Header(None)): ...@@ -159,10 +164,13 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
if session_id !="": if session_id !="":
history = crud.get_last_history(str(session_id)) history = crud.get_last_history(str(session_id))
# answer = my_chat.chat(question) # answer = my_chat.chat(question)
result = ext.extend_query(question, history)
matches = re.findall(r'"([^"]+)"', result.content)
print(matches)
prompt = "" prompt = ""
for h in history: for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1]) prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history(question,history=prompt, with_similarity=True) answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches,history=prompt, with_similarity=True)
docs_json = [] docs_json = []
for d in docs: for d in docs:
j ={} j ={}
...@@ -201,10 +209,12 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)): ...@@ -201,10 +209,12 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
crud = CRUD(_db=c_db) crud = CRUD(_db=c_db)
last_turn_id = crud.get_last_turn_num(str(session_id)) last_turn_id = crud.get_last_turn_num(str(session_id))
history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id) history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id)
result = ext.extend_query(question, history)
matches = re.findall(r'"([^"]+)"', result.content)
prompt = "" prompt = ""
for h in history: for h in history:
prompt += "问:{}\n答:{}\n\n".format(h[0], h[1]) prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
answer, docs = my_chat.chat_with_history(question, history=prompt, with_similarity=True) answer, docs = my_chat.chat_with_history_with_ext(question,ext=matches, history=prompt, with_similarity=True)
docs_json = [] docs_json = []
for d in docs: for d in docs:
j = {} j = {}
......
...@@ -31,6 +31,43 @@ class GetSimilarity: ...@@ -31,6 +31,43 @@ class GetSimilarity:
def get_rerank_docs(self): def get_rerank_docs(self):
return self.rerank_docs return self.rerank_docs
class GetSimilarityWithExt:
    """Retrieve similar documents for a set of extended query variants.

    Unlike GetSimilarity, which searches with a single question string,
    this class takes a list of query rewrites (the ``ext`` variants
    produced upstream) and merges the FAISS hits of every variant,
    raising recall before an optional rerank step narrows the results.
    """

    def __init__(self, _question, _faiss_db: VectorStore_FAISS):
        # _question: list of query strings (the extended variants);
        #            iterated by get_text_similarity_with_ext().
        # _faiss_db: vector store used both for search and for joining
        #            document lists into a single text blob.
        self.question = _question
        self.faiss_db = _faiss_db
        # Eagerly collect hits for every variant; cache the joined text.
        self.similarity_docs = self.get_text_similarity_with_ext()
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Stays "" (falsy) until get_rerank() runs.
        self.rerank_docs = ""

    def get_rerank(self, reranker: BgeRerank, top_k=5):
        """Rerank the merged hits against all variants joined by newlines.

        Keeps the top_k winners in self.rerank_docs and returns their
        joined text.
        """
        question = '\n'.join(self.question)
        rerank_docs = reranker.compress_documents(self.similarity_docs, question)
        # Remember the winners so get_rerank_docs() can expose them.
        self.rerank_docs = rerank_docs[:top_k]
        return self.faiss_db.join_document(self.rerank_docs)

    def get_similarity_doc(self):
        """Joined text of every merged hit (pre-rerank)."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Raw merged hit list (pre-rerank)."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Top-k reranked hits, or "" if get_rerank() was never called."""
        return self.rerank_docs

    def get_text_similarity_with_ext(self):
        """Search the vector store once per query variant and merge all hits.

        Duplicates across variants are kept; the rerank step (if used)
        is responsible for pruning.
        """
        similarity_docs = []
        for q in self.question:
            similarity_docs.extend(self.faiss_db.get_text_similarity(q))
        return similarity_docs
DEFAULT_PROMPT = """作为一个向量检索助手,你的任务是结合历史记录,从不同角度,为“原问题”生成个不同版本的“检索词”,从而提高向量检索的语义丰富度,提高向量检索的精度。生成的问题要求指向对象清晰明确,并与“原问题语言相同”。例如: DEFAULT_PROMPT = """作为一个向量检索助手,你的任务是结合历史记录,从不同角度,为“原问题”生成个不同版本的“检索词”,从而提高向量检索的语义丰富度,提高向量检索的精度。生成的问题要求指向对象清晰明确,并与“原问题语言相同”。例如:
历史记录: 历史记录:
''' '''
......
...@@ -10,7 +10,7 @@ from src.llm.ernie_with_sdk import ChatERNIESerLLM ...@@ -10,7 +10,7 @@ from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion from qianfan import ChatCompletion
from src.pgdb.chat.c_db import UPostgresDB from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD from src.pgdb.chat.crud import CRUD
from src.server.get_similarity import GetSimilarity from src.server.get_similarity import GetSimilarity,GetSimilarityWithExt
from src.pgdb.knowledge.similarity import VectorStore_FAISS from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import ( from src.config.consts import (
CHAT_DB_USER, CHAT_DB_USER,
...@@ -82,6 +82,9 @@ class QA: ...@@ -82,6 +82,9 @@ class QA:
def get_similarity_origin(self, _aquestion): def get_similarity_origin(self, _aquestion):
return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db) return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db)
def get_similarity_with_ext_origin(self, _ext):
return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db)
# 一次性直接给出所有的答案 # 一次性直接给出所有的答案
def chat(self, _question,with_similarity=False): def chat(self, _question,with_similarity=False):
self.cur_oquestion = _question self.cur_oquestion = _question
...@@ -139,6 +142,35 @@ class QA: ...@@ -139,6 +142,35 @@ class QA:
return self.cur_answer, similarity_docs return self.cur_answer, similarity_docs
return self.cur_answer return self.cur_answer
def chat_with_history_with_ext(self, _question,ext,history,with_similarity=False):
self.cur_oquestion = _question
if self.contains_blocked_keywords(_question):
self.cur_answer = SAFE_RESPONSE
else:
# self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
similarity = self.get_similarity_with_ext_origin(ext)
if self.rerank:
self.cur_similarity = similarity.get_rerank(self.rerank_model)
else:
self.cur_similarity = similarity.get_similarity_doc()
similarity_docs = similarity.get_similarity_docs()
rerank_docs = similarity.get_rerank_docs()
print(rerank_docs)
self.cur_question = self.prompt.format(history=history,context=self.cur_similarity, question=self.cur_oquestion)
if not _question:
return ""
self.cur_answer = self.llm.run(history=history,context=self.cur_similarity, question=self.cur_oquestion)
if self.contains_blocked_keywords(self.cur_answer):
self.cur_answer = SAFE_RESPONSE
self.update_history()
if with_similarity:
if self.rerank:
return self.cur_answer, rerank_docs
else:
return self.cur_answer, similarity_docs
return self.cur_answer
# 异步输出,逐渐输出答案 # 异步输出,逐渐输出答案
async def async_chat(self, _question): async def async_chat(self, _question):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment