Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
7e960d29
Commit
7e960d29
authored
May 27, 2024
by
周峻哲
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
避免回答大模型类型的问题
parent
abe4fd0a
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
10 deletions
+37
-10
qa.py
src/server/qa.py
+37
-10
No files found.
src/server/qa.py
View file @
7e960d29
...
@@ -37,6 +37,9 @@ prompt1 = """'''
...
@@ -37,6 +37,9 @@ prompt1 = """'''
{question}"""
{question}"""
PROMPT1
=
PromptTemplate
(
input_variables
=
[
"context"
,
"question"
],
template
=
prompt1
)
PROMPT1
=
PromptTemplate
(
input_variables
=
[
"context"
,
"question"
],
template
=
prompt1
)
# 预设的安全响应
SAFE_RESPONSE
=
"抱歉,我无法回答这个问题。"
BLOCKED_KEYWORDS
=
[
"文心一言"
,
"百度"
,
"模型"
]
class
QA
:
class
QA
:
def
__init__
(
self
,
_prompt
,
_base_llm
,
_llm_kwargs
,
_prompt_kwargs
,
_db
,
_faiss_db
):
def
__init__
(
self
,
_prompt
,
_base_llm
,
_llm_kwargs
,
_prompt_kwargs
,
_db
,
_faiss_db
):
...
@@ -55,6 +58,10 @@ class QA:
...
@@ -55,6 +58,10 @@ class QA:
self
.
cur_similarity
=
""
self
.
cur_similarity
=
""
self
.
cur_oquestion
=
""
self
.
cur_oquestion
=
""
# 检查是否包含敏感信息
def
contains_blocked_keywords
(
self
,
text
):
return
any
(
keyword
in
text
for
keyword
in
BLOCKED_KEYWORDS
)
# 为所给问题返回similarity文本
# 为所给问题返回similarity文本
def
get_similarity
(
self
,
_aquestion
):
def
get_similarity
(
self
,
_aquestion
):
return
GetSimilarity
(
_question
=
_aquestion
,
_faiss_db
=
self
.
faiss_db
)
.
get_similarity_doc
()
return
GetSimilarity
(
_question
=
_aquestion
,
_faiss_db
=
self
.
faiss_db
)
.
get_similarity_doc
()
...
@@ -62,19 +69,32 @@ class QA:
...
@@ -62,19 +69,32 @@ class QA:
# 一次性直接给出所有的答案
# 一次性直接给出所有的答案
def
chat
(
self
,
_question
):
def
chat
(
self
,
_question
):
self
.
cur_oquestion
=
_question
self
.
cur_oquestion
=
_question
if
self
.
contains_blocked_keywords
(
_question
):
self
.
cur_answer
=
SAFE_RESPONSE
else
:
self
.
cur_similarity
=
self
.
get_similarity
(
_aquestion
=
self
.
cur_oquestion
)
self
.
cur_similarity
=
self
.
get_similarity
(
_aquestion
=
self
.
cur_oquestion
)
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))})
self
.
cur_question
=
self
.
prompt
.
format
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
self
.
cur_answer
=
""
if
not
_question
:
if
not
_question
:
return
""
return
""
self
.
cur_answer
=
self
.
llm
.
run
({
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))})
self
.
cur_answer
=
self
.
llm
.
run
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
if
self
.
contains_blocked_keywords
(
self
.
cur_answer
):
self
.
cur_answer
=
SAFE_RESPONSE
self
.
update_history
()
return
self
.
cur_answer
return
self
.
cur_answer
# 异步输出,逐渐输出答案
# 异步输出,逐渐输出答案
async
def
async_chat
(
self
,
_question
):
async
def
async_chat
(
self
,
_question
):
self
.
cur_oquestion
=
_question
self
.
cur_oquestion
=
_question
history
=
self
.
get_history
()
if
self
.
contains_blocked_keywords
(
_question
):
self
.
cur_answer
=
SAFE_RESPONSE
yield
[(
self
.
cur_oquestion
,
self
.
cur_answer
)]
return
self
.
cur_similarity
=
self
.
get_similarity
(
_aquestion
=
self
.
cur_oquestion
)
self
.
cur_similarity
=
self
.
get_similarity
(
_aquestion
=
self
.
cur_oquestion
)
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))}
)
self
.
cur_question
=
self
.
prompt
.
format
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
callback
=
AsyncIteratorCallbackHandler
()
callback
=
AsyncIteratorCallbackHandler
()
async
def
wrap_done
(
fn
:
Awaitable
,
event
:
asyncio
.
Event
):
async
def
wrap_done
(
fn
:
Awaitable
,
event
:
asyncio
.
Event
):
...
@@ -88,13 +108,18 @@ class QA:
...
@@ -88,13 +108,18 @@ class QA:
event
.
set
()
event
.
set
()
task
=
asyncio
.
create_task
(
task
=
asyncio
.
create_task
(
wrap_done
(
self
.
llm
.
arun
(
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))}
,
callbacks
=
[
callback
]),
wrap_done
(
self
.
llm
.
arun
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
,
callbacks
=
[
callback
]),
callback
.
done
))
callback
.
done
))
history
=
self
.
get_history
()
self
.
cur_answer
=
""
self
.
cur_answer
=
""
history
.
append
((
self
.
cur_oquestion
,
self
.
cur_answer
))
history
.
append
((
self
.
cur_oquestion
,
self
.
cur_answer
))
async
for
token
in
callback
.
aiter
():
async
for
token
in
callback
.
aiter
():
self
.
cur_answer
=
self
.
cur_answer
+
token
self
.
cur_answer
+=
token
if
self
.
contains_blocked_keywords
(
self
.
cur_answer
):
self
.
cur_answer
=
SAFE_RESPONSE
history
[
-
1
]
=
(
self
.
cur_oquestion
,
self
.
cur_answer
)
yield
history
return
history
[
-
1
]
=
(
self
.
cur_oquestion
,
self
.
cur_answer
)
history
[
-
1
]
=
(
self
.
cur_oquestion
,
self
.
cur_answer
)
yield
history
yield
history
await
task
await
task
...
@@ -104,9 +129,11 @@ class QA:
...
@@ -104,9 +129,11 @@ class QA:
return
self
.
history
return
self
.
history
def
update_history
(
self
):
def
update_history
(
self
):
if
self
.
cur_oquestion
==
''
and
self
.
cur_answer
==
''
:
if
self
.
cur_oquestion
and
self
.
cur_answer
:
pass
if
not
self
.
history
:
else
:
self
.
history
=
[]
# 避免重复添加条目
self
.
history
=
[(
q
,
a
)
for
q
,
a
in
self
.
history
if
q
!=
self
.
cur_oquestion
or
a
!=
""
]
self
.
history
.
append
((
self
.
cur_oquestion
,
self
.
cur_answer
))
self
.
history
.
append
((
self
.
cur_oquestion
,
self
.
cur_answer
))
self
.
crud
.
update_last
(
chat_id
=
self
.
chat_id
)
self
.
crud
.
update_last
(
chat_id
=
self
.
chat_id
)
self
.
crud
.
insert_turn_qa
(
chat_id
=
self
.
chat_id
,
question
=
self
.
cur_oquestion
,
answer
=
self
.
cur_answer
,
self
.
crud
.
insert_turn_qa
(
chat_id
=
self
.
chat_id
,
question
=
self
.
cur_oquestion
,
answer
=
self
.
cur_answer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment