Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
4f95b54c
Commit
4f95b54c
authored
Jun 28, 2024
by
文靖昊
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
修改重排bug,新增历史对话功能
parent
582deb2e
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
71 additions
and
22 deletions
+71
-22
consts.py
src/config/consts.py
+13
-7
web.py
src/controller/web.py
+12
-6
similarity.py
src/pgdb/knowledge/similarity.py
+0
-1
vec_txt_table.py
src/pgdb/knowledge/vec_txt_table.py
+1
-1
get_similarity.py
src/server/get_similarity.py
+6
-2
qa.py
src/server/qa.py
+39
-5
No files found.
src/config/consts.py
View file @
4f95b54c
# =============================
# =============================
# 资料存储数据库配置
# 资料存储数据库配置
# =============================
# =============================
VEC_DB_HOST
=
'192.168.10.
93
'
VEC_DB_HOST
=
'192.168.10.
189
'
VEC_DB_DBNAME
=
'lae'
VEC_DB_DBNAME
=
'lae'
VEC_DB_USER
=
'postgres'
VEC_DB_USER
=
'postgres'
VEC_DB_PASSWORD
=
'111111'
VEC_DB_PASSWORD
=
'111111'
...
@@ -10,7 +10,7 @@ VEC_DB_PORT = '5433'
...
@@ -10,7 +10,7 @@ VEC_DB_PORT = '5433'
# =============================
# =============================
# 聊天相关数据库配置
# 聊天相关数据库配置
# =============================
# =============================
CHAT_DB_HOST
=
'192.168.10.
93
'
CHAT_DB_HOST
=
'192.168.10.
189
'
CHAT_DB_DBNAME
=
'lae'
CHAT_DB_DBNAME
=
'lae'
CHAT_DB_USER
=
'postgres'
CHAT_DB_USER
=
'postgres'
CHAT_DB_PASSWORD
=
'111111'
CHAT_DB_PASSWORD
=
'111111'
...
@@ -19,12 +19,12 @@ CHAT_DB_PORT = '5433'
...
@@ -19,12 +19,12 @@ CHAT_DB_PORT = '5433'
# =============================
# =============================
# 向量化模型路径配置
# 向量化模型路径配置
# =============================
# =============================
EMBEEDING_MODEL_PATH
=
'
/app/
bge-large-zh-v1.5'
EMBEEDING_MODEL_PATH
=
'
D:
\\
work
\\
py
\\
LAE
\\
bge-large-zh-v1.5'
# =============================
# =============================
# 重排序模型路径配置
# 重排序模型路径配置
# =============================
# =============================
RERANK_MODEL_PATH
=
'
/app/
bge-reranker-large'
RERANK_MODEL_PATH
=
'
D:
\\
work
\\
py
\\
LAE
\\
bge-reranker-large'
# RERANK_MODEL_PATH = 'BAAI/bge-reranker-large'
# RERANK_MODEL_PATH = 'BAAI/bge-reranker-large'
# =============================
# =============================
...
@@ -35,19 +35,19 @@ LLM_SERVER_URL = '192.168.10.102:8002'
...
@@ -35,19 +35,19 @@ LLM_SERVER_URL = '192.168.10.102:8002'
# =============================
# =============================
# FAISS相似性查找配置
# FAISS相似性查找配置
# =============================
# =============================
SIMILARITY_SHOW_NUMBER
=
5
SIMILARITY_SHOW_NUMBER
=
10
SIMILARITY_THRESHOLD
=
0.8
SIMILARITY_THRESHOLD
=
0.8
# =============================
# =============================
# FAISS向量库文件存储路径配置
# FAISS向量库文件存储路径配置
# =============================
# =============================
FAISS_STORE_PATH
=
'
/app/
faiss'
FAISS_STORE_PATH
=
'
D:
\\
work
\\
py
\\
LAE
\\
faiss'
INDEX_NAME
=
'know'
INDEX_NAME
=
'know'
# =============================
# =============================
# 知识相关资料配置
# 知识相关资料配置
# =============================
# =============================
KNOWLEDGE_PATH
=
'
/app/lae_data
'
KNOWLEDGE_PATH
=
'
D:
\\
work
\\
py
\\
LAE
\\
testdoc
'
# =============================
# =============================
# gradio服务相关配置
# gradio服务相关配置
...
@@ -64,6 +64,12 @@ prompt1 = """'''
...
@@ -64,6 +64,12 @@ prompt1 = """'''
请你根据上述已知资料回答下面的问题,问题如下:
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
{question}"""
prompt_enhancement_history_template
=
"""{history}
上面是之前的对话,下面是可参考的内容,参考内容中如果和问题不符合,可以不用参考。
{context}
请结合上述内容回答以下问题,不要提无关内容:
{question}
"""
# =============================
# =============================
# NLP_BERT模型路径配置
# NLP_BERT模型路径配置
# =============================
# =============================
...
...
src/controller/web.py
View file @
4f95b54c
...
@@ -32,7 +32,7 @@ from src.config.consts import (
...
@@ -32,7 +32,7 @@ from src.config.consts import (
VEC_DB_USER
,
VEC_DB_USER
,
VEC_DB_DBNAME
,
VEC_DB_DBNAME
,
SIMILARITY_SHOW_NUMBER
,
SIMILARITY_SHOW_NUMBER
,
prompt
1
prompt
_enhancement_history_template
)
)
app
=
FastAPI
()
app
=
FastAPI
()
app
.
add_middleware
(
app
.
add_middleware
(
...
@@ -62,8 +62,8 @@ base_llm = ChatOpenAI(
...
@@ -62,8 +62,8 @@ base_llm = ChatOpenAI(
model_name
=
'Qwen2-7B'
,
model_name
=
'Qwen2-7B'
,
verbose
=
True
verbose
=
True
)
)
my_chat
=
QA
(
PromptTemplate
(
input_variables
=
[
"
context"
,
"question"
],
template
=
prompt1
),
base_llm
,
my_chat
=
QA
(
PromptTemplate
(
input_variables
=
[
"
history"
,
"context"
,
"question"
],
template
=
prompt_enhancement_history_template
),
base_llm
,
{
"temperature"
:
0.9
},
[
'
context'
,
'question'
],
_db
=
c_db
,
_faiss_db
=
vecstore_faiss
)
{
"temperature"
:
0.9
},
[
'
history'
,
'context'
,
'question'
],
_db
=
c_db
,
_faiss_db
=
vecstore_faiss
,
rerank
=
True
)
@app.post
(
'/api/login'
)
@app.post
(
'/api/login'
)
def
login
(
phone_request
:
PhoneLoginRequest
):
def
login
(
phone_request
:
PhoneLoginRequest
):
...
@@ -156,10 +156,13 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
...
@@ -156,10 +156,13 @@ def question(chat_request: ChatRequest, token: str = Header(None)):
question
=
chat_request
.
question
question
=
chat_request
.
question
crud
=
CRUD
(
_db
=
c_db
)
crud
=
CRUD
(
_db
=
c_db
)
history
=
[]
history
=
[]
if
session_id
=
=
""
:
if
session_id
!
=
""
:
history
=
crud
.
get_last_history
(
str
(
session_id
))
history
=
crud
.
get_last_history
(
str
(
session_id
))
# answer = my_chat.chat(question)
# answer = my_chat.chat(question)
answer
,
docs
=
my_chat
.
chat
(
question
,
with_similarity
=
True
)
prompt
=
""
for
h
in
history
:
prompt
+=
"问:{}
\n
答:{}
\n\n
"
.
format
(
h
[
0
],
h
[
1
])
answer
,
docs
=
my_chat
.
chat_with_history
(
question
,
history
=
prompt
,
with_similarity
=
True
)
docs_json
=
[]
docs_json
=
[]
for
d
in
docs
:
for
d
in
docs
:
j
=
{}
j
=
{}
...
@@ -198,7 +201,10 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
...
@@ -198,7 +201,10 @@ def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
crud
=
CRUD
(
_db
=
c_db
)
crud
=
CRUD
(
_db
=
c_db
)
last_turn_id
=
crud
.
get_last_turn_num
(
str
(
session_id
))
last_turn_id
=
crud
.
get_last_turn_num
(
str
(
session_id
))
history
=
crud
.
get_last_history_before_turn_id
(
str
(
session_id
),
last_turn_id
)
history
=
crud
.
get_last_history_before_turn_id
(
str
(
session_id
),
last_turn_id
)
answer
,
docs
=
my_chat
.
chat
(
question
,
with_similarity
=
True
)
prompt
=
""
for
h
in
history
:
prompt
+=
"问:{}
\n
答:{}
\n\n
"
.
format
(
h
[
0
],
h
[
1
])
answer
,
docs
=
my_chat
.
chat_with_history
(
question
,
history
=
prompt
,
with_similarity
=
True
)
docs_json
=
[]
docs_json
=
[]
for
d
in
docs
:
for
d
in
docs
:
j
=
{}
j
=
{}
...
...
src/pgdb/knowledge/similarity.py
View file @
4f95b54c
...
@@ -227,7 +227,6 @@ class VectorStore_FAISS(FAISS):
...
@@ -227,7 +227,6 @@ class VectorStore_FAISS(FAISS):
@staticmethod
@staticmethod
def
join_document
(
docs
:
List
[
Document
])
->
str
:
def
join_document
(
docs
:
List
[
Document
])
->
str
:
print
(
docs
)
return
""
.
join
([
doc
.
page_content
for
doc
in
docs
])
return
""
.
join
([
doc
.
page_content
for
doc
in
docs
])
@staticmethod
@staticmethod
...
...
src/pgdb/knowledge/vec_txt_table.py
View file @
4f95b54c
...
@@ -35,7 +35,7 @@ class TxtVector:
...
@@ -35,7 +35,7 @@ class TxtVector:
query
=
f
"SELECT paragraph_id,text FROM vec_txt WHERE vector_id =
%
s"
query
=
f
"SELECT paragraph_id,text FROM vec_txt WHERE vector_id =
%
s"
self
.
db
.
execute_args
(
query
,
[
search
])
self
.
db
.
execute_args
(
query
,
[
search
])
answer
=
self
.
db
.
fetchall
()
answer
=
self
.
db
.
fetchall
()
print
(
answer
)
#
print(answer)
return
answer
[
0
]
return
answer
[
0
]
def
create_table
(
self
):
def
create_table
(
self
):
...
...
src/server/get_similarity.py
View file @
4f95b54c
...
@@ -7,11 +7,15 @@ class GetSimilarity:
...
@@ -7,11 +7,15 @@ class GetSimilarity:
self
.
faiss_db
=
_faiss_db
self
.
faiss_db
=
_faiss_db
self
.
similarity_docs
=
self
.
faiss_db
.
get_text_similarity
(
self
.
question
)
self
.
similarity_docs
=
self
.
faiss_db
.
get_text_similarity
(
self
.
question
)
self
.
similarity_doc_txt
=
self
.
faiss_db
.
join_document
(
self
.
similarity_docs
)
self
.
similarity_doc_txt
=
self
.
faiss_db
.
join_document
(
self
.
similarity_docs
)
self
.
rerank_docs
=
""
def
get_rerank
(
self
,
reranker
:
BgeRerank
,
top_k
=
5
):
def
get_rerank
(
self
,
reranker
:
BgeRerank
,
top_k
=
5
):
rerank_docs
=
reranker
.
compress_documents
(
self
.
similarity_docs
,
self
.
question
)
rerank_docs
=
reranker
.
compress_documents
(
self
.
similarity_docs
,
self
.
question
)
self
.
rerank_docs
=
rerank_docs
d_list
=
[]
return
self
.
faiss_db
.
join_document
([
d
[
1
]
for
d
in
rerank_docs
[:
top_k
]])
for
d
in
rerank_docs
[:
top_k
]:
d_list
.
append
(
d
)
self
.
rerank_docs
=
rerank_docs
[:
top_k
]
return
self
.
faiss_db
.
join_document
(
d_list
)
def
get_similarity_doc
(
self
):
def
get_similarity_doc
(
self
):
return
self
.
similarity_doc_txt
return
self
.
similarity_doc_txt
...
...
src/server/qa.py
View file @
4f95b54c
...
@@ -37,8 +37,14 @@ prompt1 = """'''
...
@@ -37,8 +37,14 @@ prompt1 = """'''
'''
'''
请你根据上述已知资料回答下面的问题,问题如下:
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
{question}"""
PROMPT1
=
PromptTemplate
(
input_variables
=
[
"context"
,
"question"
],
template
=
prompt1
)
PROMPT1
=
PromptTemplate
(
input_variables
=
[
"context"
,
"question"
],
template
=
prompt1
)
# 预设的安全响应
# 预设的安全响应
SAFE_RESPONSE
=
"您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提是供定义和解释及建议。如果您有任何问题,请随时向我提问。"
SAFE_RESPONSE
=
"您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提是供定义和解释及建议。如果您有任何问题,请随时向我提问。"
BLOCKED_KEYWORDS
=
[
"文心一言"
,
"百度"
,
"模型"
]
BLOCKED_KEYWORDS
=
[
"文心一言"
,
"百度"
,
"模型"
]
...
@@ -90,10 +96,6 @@ class QA:
...
@@ -90,10 +96,6 @@ class QA:
self
.
cur_similarity
=
similarity
.
get_similarity_doc
()
self
.
cur_similarity
=
similarity
.
get_similarity_doc
()
similarity_docs
=
similarity
.
get_similarity_docs
()
similarity_docs
=
similarity
.
get_similarity_docs
()
rerank_docs
=
similarity
.
get_rerank_docs
()
rerank_docs
=
similarity
.
get_rerank_docs
()
print
(
"============== similarity =============="
)
print
(
similarity_docs
)
print
(
"============== rerank =============="
)
print
(
rerank_docs
)
self
.
cur_question
=
self
.
prompt
.
format
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
self
.
cur_question
=
self
.
prompt
.
format
(
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
if
not
_question
:
if
not
_question
:
return
""
return
""
...
@@ -103,9 +105,41 @@ class QA:
...
@@ -103,9 +105,41 @@ class QA:
self
.
update_history
()
self
.
update_history
()
if
with_similarity
:
if
with_similarity
:
return
self
.
cur_answer
,
similarity_docs
if
self
.
rerank
:
return
self
.
cur_answer
,
rerank_docs
else
:
return
self
.
cur_answer
,
similarity_docs
return
self
.
cur_answer
def
chat_with_history
(
self
,
_question
,
history
,
with_similarity
=
False
):
self
.
cur_oquestion
=
_question
if
self
.
contains_blocked_keywords
(
_question
):
self
.
cur_answer
=
SAFE_RESPONSE
else
:
# self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
similarity
=
self
.
get_similarity_origin
(
_aquestion
=
self
.
cur_oquestion
)
if
self
.
rerank
:
self
.
cur_similarity
=
similarity
.
get_rerank
(
self
.
rerank_model
)
else
:
self
.
cur_similarity
=
similarity
.
get_similarity_doc
()
similarity_docs
=
similarity
.
get_similarity_docs
()
rerank_docs
=
similarity
.
get_rerank_docs
()
self
.
cur_question
=
self
.
prompt
.
format
(
history
=
history
,
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
if
not
_question
:
return
""
self
.
cur_answer
=
self
.
llm
.
run
(
history
=
history
,
context
=
self
.
cur_similarity
,
question
=
self
.
cur_oquestion
)
if
self
.
contains_blocked_keywords
(
self
.
cur_answer
):
self
.
cur_answer
=
SAFE_RESPONSE
self
.
update_history
()
if
with_similarity
:
if
self
.
rerank
:
return
self
.
cur_answer
,
rerank_docs
else
:
return
self
.
cur_answer
,
similarity_docs
return
self
.
cur_answer
return
self
.
cur_answer
# 异步输出,逐渐输出答案
# 异步输出,逐渐输出答案
async
def
async_chat
(
self
,
_question
):
async
def
async_chat
(
self
,
_question
):
self
.
cur_oquestion
=
_question
self
.
cur_oquestion
=
_question
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment