Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
7e960d29
Commit
7e960d29
authored
May 27, 2024
by
周峻哲
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
避免回答大模型类型的问题
parent
abe4fd0a
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
13 deletions
+40
-13
qa.py
src/server/qa.py
+40
-13
No files found.
src/server/qa.py
View file @
7e960d29
...
...
@@ -37,6 +37,9 @@ prompt1 = """'''
{question}"""
PROMPT1
=
PromptTemplate
(
input_variables
=
[
"context"
,
"question"
],
template
=
prompt1
)
# 预设的安全响应
SAFE_RESPONSE
=
"抱歉,我无法回答这个问题。"
BLOCKED_KEYWORDS
=
[
"文心一言"
,
"百度"
,
"模型"
]
class
QA
:
def
__init__
(
self
,
_prompt
,
_base_llm
,
_llm_kwargs
,
_prompt_kwargs
,
_db
,
_faiss_db
):
...
...
@@ -55,6 +58,10 @@ class QA:
self
.
cur_similarity
=
""
self
.
cur_oquestion
=
""
# 检查是否包含敏感信息
def
contains_blocked_keywords
(
self
,
text
):
return
any
(
keyword
in
text
for
keyword
in
BLOCKED_KEYWORDS
)
# 为所给问题返回similarity文本
def
get_similarity
(
self
,
_aquestion
):
return
GetSimilarity
(
_question
=
_aquestion
,
_faiss_db
=
self
.
faiss_db
)
.
get_similarity_doc
()
...
...
@@ -62,19 +69,32 @@ class QA:
# 一次性直接给出所有的答案
def
chat
(
self
,
_question
):
self
.
cur_oquestion
=
_question
self
.
cur_similarity
=
self
.
get_similarity
(
_aquestion
=
self
.
cur_oquestion
)
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))})
self
.
cur_answer
=
""
if
not
_question
:
return
""
self
.
cur_answer
=
self
.
llm
.
run
({
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))})
if
self
.
contains_blocked_keywords
(
_question
):
self
.
cur_answer
=
SAFE_RESPONSE
else
:
self
.
cur_similarity
=
self
.
get_similarity
(
_aquestion
=
self
.
cur_oquestion
)
self
.
cur_question
=
self
.
prompt
.
format
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
if
not
_question
:
return
""
self
.
cur_answer
=
self
.
llm
.
run
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
if
self
.
contains_blocked_keywords
(
self
.
cur_answer
):
self
.
cur_answer
=
SAFE_RESPONSE
self
.
update_history
()
return
self
.
cur_answer
# 异步输出,逐渐输出答案
async
def
async_chat
(
self
,
_question
):
self
.
cur_oquestion
=
_question
history
=
self
.
get_history
()
if
self
.
contains_blocked_keywords
(
_question
):
self
.
cur_answer
=
SAFE_RESPONSE
yield
[(
self
.
cur_oquestion
,
self
.
cur_answer
)]
return
self
.
cur_similarity
=
self
.
get_similarity
(
_aquestion
=
self
.
cur_oquestion
)
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))}
)
self
.
cur_question
=
self
.
prompt
.
format
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
callback
=
AsyncIteratorCallbackHandler
()
async
def
wrap_done
(
fn
:
Awaitable
,
event
:
asyncio
.
Event
):
...
...
@@ -88,13 +108,18 @@ class QA:
event
.
set
()
task
=
asyncio
.
create_task
(
wrap_done
(
self
.
llm
.
arun
(
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))}
,
callbacks
=
[
callback
]),
wrap_done
(
self
.
llm
.
arun
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
,
callbacks
=
[
callback
]),
callback
.
done
))
history
=
self
.
get_history
()
self
.
cur_answer
=
""
history
.
append
((
self
.
cur_oquestion
,
self
.
cur_answer
))
async
for
token
in
callback
.
aiter
():
self
.
cur_answer
=
self
.
cur_answer
+
token
self
.
cur_answer
+=
token
if
self
.
contains_blocked_keywords
(
self
.
cur_answer
):
self
.
cur_answer
=
SAFE_RESPONSE
history
[
-
1
]
=
(
self
.
cur_oquestion
,
self
.
cur_answer
)
yield
history
return
history
[
-
1
]
=
(
self
.
cur_oquestion
,
self
.
cur_answer
)
yield
history
await
task
...
...
@@ -104,9 +129,11 @@ class QA:
return
self
.
history
def
update_history
(
self
):
if
self
.
cur_oquestion
==
''
and
self
.
cur_answer
==
''
:
pass
else
:
if
self
.
cur_oquestion
and
self
.
cur_answer
:
if
not
self
.
history
:
self
.
history
=
[]
# 避免重复添加条目
self
.
history
=
[(
q
,
a
)
for
q
,
a
in
self
.
history
if
q
!=
self
.
cur_oquestion
or
a
!=
""
]
self
.
history
.
append
((
self
.
cur_oquestion
,
self
.
cur_answer
))
self
.
crud
.
update_last
(
chat_id
=
self
.
chat_id
)
self
.
crud
.
insert_turn_qa
(
chat_id
=
self
.
chat_id
,
question
=
self
.
cur_oquestion
,
answer
=
self
.
cur_answer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment