Commit 7e960d29 by 周峻哲

避免回答大模型类型的问题

parent abe4fd0a
......@@ -37,6 +37,9 @@ prompt1 = """'''
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
# 预设的安全响应
SAFE_RESPONSE = "抱歉,我无法回答这个问题。"
BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]
class QA:
def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db):
......@@ -55,6 +58,10 @@ class QA:
self.cur_similarity = ""
self.cur_oquestion = ""
# 检查是否包含敏感信息
def contains_blocked_keywords(self, text):
return any(keyword in text for keyword in BLOCKED_KEYWORDS)
# 为所给问题返回similarity文本
def get_similarity(self, _aquestion):
return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()
......@@ -62,19 +69,32 @@ class QA:
# 一次性直接给出所有的答案
def chat(self, _question):
self.cur_oquestion = _question
self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion))})
self.cur_answer = ""
if not _question:
return ""
self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion))})
if self.contains_blocked_keywords(_question):
self.cur_answer = SAFE_RESPONSE
else:
self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
if not _question:
return ""
self.cur_answer = self.llm.run(context=self.cur_similarity, question=self.cur_oquestion)
if self.contains_blocked_keywords(self.cur_answer):
self.cur_answer = SAFE_RESPONSE
self.update_history()
return self.cur_answer
# 异步输出,逐渐输出答案
async def async_chat(self, _question):
self.cur_oquestion = _question
history = self.get_history()
if self.contains_blocked_keywords(_question):
self.cur_answer = SAFE_RESPONSE
yield [(self.cur_oquestion, self.cur_answer)]
return
self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion))})
self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
callback = AsyncIteratorCallbackHandler()
async def wrap_done(fn: Awaitable, event: asyncio.Event):
......@@ -88,13 +108,18 @@ class QA:
event.set()
task = asyncio.create_task(
wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion))}, callbacks=[callback]),
wrap_done(self.llm.arun(context=self.cur_similarity, question=self.cur_oquestion, callbacks=[callback]),
callback.done))
history = self.get_history()
self.cur_answer = ""
history.append((self.cur_oquestion, self.cur_answer))
async for token in callback.aiter():
self.cur_answer = self.cur_answer + token
self.cur_answer += token
if self.contains_blocked_keywords(self.cur_answer):
self.cur_answer = SAFE_RESPONSE
history[-1] = (self.cur_oquestion, self.cur_answer)
yield history
return
history[-1] = (self.cur_oquestion, self.cur_answer)
yield history
await task
......@@ -104,9 +129,11 @@ class QA:
return self.history
def update_history(self):
if self.cur_oquestion == '' and self.cur_answer == '':
pass
else:
if self.cur_oquestion and self.cur_answer:
if not self.history:
self.history = []
# 避免重复添加条目
self.history = [(q, a) for q, a in self.history if q != self.cur_oquestion or a != ""]
self.history.append((self.cur_oquestion, self.cur_answer))
self.crud.update_last(chat_id=self.chat_id)
self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment