Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
14a6b87b
Commit
14a6b87b
authored
Apr 30, 2024
by
陈正乐
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
修改在询问模型时,所需参数个数
parent
1e9efa79
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
147 additions
and
25 deletions
+147
-25
consts.py
src/config/consts.py
+1
-1
similarity.py
src/pgdb/knowledge/similarity.py
+1
-1
get_similarity.py
src/server/get_similarity.py
+12
-0
qa.py
src/server/qa.py
+41
-14
chat_table_test.py
test/chat_table_test.py
+3
-2
gradio_text.py
test/gradio_text.py
+62
-0
k_store_test.py
test/k_store_test.py
+2
-2
lk_test.py
test/lk_test.py
+25
-5
No files found.
src/config/consts.py
View file @
14a6b87b
...
...
@@ -41,4 +41,4 @@ INDEX_NAME = 'know'
# =============================
# 知识相关资料配置
# =============================
KNOWLEDGE_PATH
=
'C:
\\
Users
\\
15663
\\
Desktop
\\
work
\\
llm_gjjs
\\
兴火燎原知识库
\\
兴火燎原知识库
\\
law
\\
pdf
'
KNOWLEDGE_PATH
=
'C:
\\
Users
\\
15663
\\
Desktop
\\
低空经济数据库
'
src/pgdb/knowledge/similarity.py
View file @
14a6b87b
...
...
@@ -227,7 +227,7 @@ class VectorStore_FAISS(FAISS):
# return deduplicated_documents
@staticmethod
def
_
join_document
(
docs
:
List
[
Document
])
->
str
:
def
join_document
(
docs
:
List
[
Document
])
->
str
:
print
(
docs
)
return
""
.
join
([
doc
.
page_content
for
doc
in
docs
])
...
...
src/server/get_similarity.py
0 → 100644
View file @
14a6b87b
from
src.pgdb.knowledge.similarity
import
VectorStore_FAISS
class
GetSimilarity
:
def
__init__
(
self
,
_question
,
_faiss_db
:
VectorStore_FAISS
):
self
.
question
=
_question
self
.
faiss_db
=
_faiss_db
self
.
similarity_doc
=
self
.
faiss_db
.
join_document
(
self
.
faiss_db
.
get_text_similarity
(
"什么是低空飞行"
))
def
get_similarity_doc
(
self
):
return
self
.
similarity_doc
src/server/qa.py
View file @
14a6b87b
...
...
@@ -11,12 +11,23 @@ from src.llm.ernie_with_sdk import ChatERNIESerLLM
from
qianfan
import
ChatCompletion
from
src.pgdb.chat.c_db
import
UPostgresDB
from
src.pgdb.chat.crud
import
CRUD
from
src.server.get_similarity
import
GetSimilarity
from
src.pgdb.knowledge.similarity
import
VectorStore_FAISS
from
src.config.consts
import
(
CHAT_DB_USER
,
CHAT_DB_HOST
,
CHAT_DB_PORT
,
CHAT_DB_DBNAME
,
CHAT_DB_PASSWORD
CHAT_DB_PASSWORD
,
EMBEEDING_MODEL_PATH
,
FAISS_STORE_PATH
,
INDEX_NAME
,
VEC_DB_HOST
,
VEC_DB_PASSWORD
,
VEC_DB_PORT
,
VEC_DB_USER
,
VEC_DB_DBNAME
,
SIMILARITY_SHOW_NUMBER
)
sys
.
path
.
append
(
"../.."
)
...
...
@@ -29,32 +40,40 @@ PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=promp
class
QA
:
def
__init__
(
self
,
_prompt
,
_base_llm
,
_llm_kwargs
,
_prompt_kwargs
,
_db
,
_chat_id
):
def
__init__
(
self
,
_prompt
,
_base_llm
,
_llm_kwargs
,
_prompt_kwargs
,
_db
,
_chat_id
,
_faiss_db
):
self
.
prompt
=
_prompt
self
.
base_llm
=
_base_llm
self
.
llm_kwargs
=
_llm_kwargs
self
.
prompt_kwargs
=
_prompt_kwargs
self
.
db
=
_db
self
.
chat_id
=
_chat_id
self
.
faiss_db
=
_faiss_db
self
.
crud
=
CRUD
(
self
.
db
)
self
.
history
=
self
.
crud
.
get_history
(
self
.
chat_id
)
self
.
llm
=
LLMChain
(
llm
=
self
.
base_llm
,
prompt
=
self
.
prompt
,
llm_kwargs
=
self
.
llm_kwargs
)
self
.
cur_answer
=
""
self
.
cur_question
=
""
# 为所给问题返回similarity文本
def
get_similarity
(
self
,
_aquestion
):
return
GetSimilarity
(
_question
=
_aquestion
,
_faiss_db
=
self
.
faiss_db
)
.
get_similarity_doc
()
# 一次性直接给出所有的答案
async
def
chat
(
self
,
*
args
):
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
args
)})
async
def
chat
(
self
,
_question
):
similarity
=
self
.
get_similarity
(
_aquestion
=
_question
)
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
similarity
,
_question
))})
self
.
cur_answer
=
""
if
not
args
:
if
not
_question
:
return
""
self
.
cur_answer
=
self
.
llm
.
run
({
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
args
)})
self
.
cur_answer
=
self
.
llm
.
run
({
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
similarity
,
_question
)
)})
return
self
.
cur_answer
# 异步输出,逐渐输出答案
async
def
async_chat
(
self
,
*
args
):
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
args
)})
async
def
async_chat
(
self
,
_question
):
similarity
=
self
.
get_similarity
(
_aquestion
=
_question
)
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
similarity
,
_question
))})
callback
=
AsyncIteratorCallbackHandler
()
async
def
wrap_done
(
fn
:
Awaitable
,
event
:
asyncio
.
Event
):
try
:
await
fn
...
...
@@ -64,8 +83,9 @@ class QA:
print
(
f
"Caught exception: {e}"
)
finally
:
event
.
set
()
task
=
asyncio
.
create_task
(
wrap_done
(
self
.
llm
.
arun
({
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
args
)},
callbacks
=
[
callback
]),
wrap_done
(
self
.
llm
.
arun
({
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
similarity
,
_question
)
)},
callbacks
=
[
callback
]),
callback
.
done
))
self
.
cur_answer
=
""
async
for
token
in
callback
.
aiter
():
...
...
@@ -73,11 +93,10 @@ class QA:
yield
f
"{self.cur_answer}"
print
(
datetime
.
now
())
await
task
print
(
'----------------'
,
self
.
cur_question
)
print
(
'================'
,
self
.
cur_answer
)
print
(
'----------------'
,
self
.
cur_question
)
print
(
'================'
,
self
.
cur_answer
)
print
(
datetime
.
now
())
def
get_history
(
self
):
return
self
.
history
...
...
@@ -97,8 +116,16 @@ if __name__ == "__main__":
port
=
CHAT_DB_PORT
,
)
base_llm
=
ChatERNIESerLLM
(
chat_completion
=
ChatCompletion
(
ak
=
"pT7sV1smp4AeDl0LjyZuHBV9"
,
sk
=
"b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"
))
my_chat
=
QA
(
PROMPT1
,
base_llm
,
{
"temperature"
:
0.9
},
[
'context'
,
'question'
],
_db
=
c_db
,
_chat_id
=
'2'
)
print
(
my_chat
.
async_chat
(
"当别人想你说你好的时候,你也应该说你好"
,
"你好"
))
vecstore_faiss
=
VectorStore_FAISS
(
embedding_model_name
=
EMBEEDING_MODEL_PATH
,
store_path
=
FAISS_STORE_PATH
,
index_name
=
INDEX_NAME
,
info
=
{
"port"
:
VEC_DB_PORT
,
"host"
:
VEC_DB_HOST
,
"dbname"
:
VEC_DB_DBNAME
,
"username"
:
VEC_DB_USER
,
"password"
:
VEC_DB_PASSWORD
},
show_number
=
SIMILARITY_SHOW_NUMBER
,
reset
=
False
)
my_chat
=
QA
(
PROMPT1
,
base_llm
,
{
"temperature"
:
0.9
},
[
'context'
,
'question'
],
_db
=
c_db
,
_chat_id
=
'2'
,
_faiss_db
=
vecstore_faiss
)
print
(
my_chat
.
chat
(
"什么是低空经济"
))
my_chat
.
update_history
()
time
.
sleep
(
20
)
print
(
my_chat
.
cur_answer
)
test/chat_table_test.py
View file @
14a6b87b
import
sys
sys
.
path
.
append
(
"../"
)
from
src.pgdb.chat.c_db
import
UPostgresDB
from
src.pgdb.chat.chat_table
import
Chat
from
src.pgdb.chat.c_user_table
import
CUser
...
...
@@ -12,14 +13,14 @@ from src.config.consts import (
CHAT_DB_PASSWORD
)
sys
.
path
.
append
(
"../"
)
"""测试会话相关数据可的连接"""
def
test
():
c_db
=
UPostgresDB
(
host
=
CHAT_DB_HOST
,
database
=
CHAT_DB_DBNAME
,
user
=
CHAT_DB_USER
,
password
=
CHAT_DB_PASSWORD
,
port
=
CHAT_DB_PORT
,
)
print
(
c_db
)
crud
=
CRUD
(
c_db
)
crud
.
create_table
()
crud
.
insert_turn_qa
(
"2"
,
"wen4"
,
"da1"
,
1
,
0
)
...
...
test/gradio_text.py
0 → 100644
View file @
14a6b87b
# -*- coding: utf-8 -*-
import
gradio
as
gr
from
langchain.prompts
import
PromptTemplate
from
src.llm.ernie_with_sdk
import
ChatERNIESerLLM
from
qianfan
import
ChatCompletion
from
src.pgdb.chat.c_db
import
UPostgresDB
from
src.server.get_similarity
import
GetSimilarity
from
src.pgdb.knowledge.similarity
import
VectorStore_FAISS
from
src.config.consts
import
(
CHAT_DB_USER
,
CHAT_DB_HOST
,
CHAT_DB_PORT
,
CHAT_DB_DBNAME
,
CHAT_DB_PASSWORD
,
EMBEEDING_MODEL_PATH
,
FAISS_STORE_PATH
,
INDEX_NAME
,
VEC_DB_HOST
,
VEC_DB_PASSWORD
,
VEC_DB_PORT
,
VEC_DB_USER
,
VEC_DB_DBNAME
,
SIMILARITY_SHOW_NUMBER
)
from
src.server.qa
import
QA
prompt1
=
"""'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1
=
PromptTemplate
(
input_variables
=
[
"context"
,
"question"
],
template
=
prompt1
)
def
main
():
c_db
=
UPostgresDB
(
host
=
CHAT_DB_HOST
,
database
=
CHAT_DB_DBNAME
,
user
=
CHAT_DB_USER
,
password
=
CHAT_DB_PASSWORD
,
port
=
CHAT_DB_PORT
,
)
vecstore_faiss
=
VectorStore_FAISS
(
embedding_model_name
=
EMBEEDING_MODEL_PATH
,
store_path
=
FAISS_STORE_PATH
,
index_name
=
INDEX_NAME
,
info
=
{
"port"
:
VEC_DB_PORT
,
"host"
:
VEC_DB_HOST
,
"dbname"
:
VEC_DB_DBNAME
,
"username"
:
VEC_DB_USER
,
"password"
:
VEC_DB_PASSWORD
},
show_number
=
SIMILARITY_SHOW_NUMBER
,
reset
=
False
)
base_llm
=
ChatERNIESerLLM
(
chat_completion
=
ChatCompletion
(
ak
=
"pT7sV1smp4AeDl0LjyZuHBV9"
,
sk
=
"b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"
))
my_chat
=
QA
(
PROMPT1
,
base_llm
,
{
"temperature"
:
0.9
},
[
'context'
,
'question'
],
_db
=
c_db
,
_chat_id
=
'2'
,
_faiss_db
=
vecstore_faiss
)
with
gr
.
Blocks
()
as
demo
:
with
gr
.
Row
():
in1
=
gr
.
Textbox
(
show_label
=
True
,
lines
=
10
,
visible
=
False
)
in2
=
gr
.
Textbox
(
show_label
=
True
,
lines
=
10
)
with
gr
.
Row
():
qabtn
=
gr
.
Button
(
"SUBMIT"
)
out
=
gr
.
Textbox
(
show_label
=
True
,
lines
=
10
)
qabtn
.
click
(
my_chat
.
async_chat
,
[
in2
],
[
out
])
demo
.
queue
()
.
launch
(
share
=
False
,
inbrowser
=
True
,
server_name
=
"192.168.100.76"
,
server_port
=
8888
)
if
__name__
==
"__main__"
:
main
()
\ No newline at end of file
test/k_store_test.py
View file @
14a6b87b
...
...
@@ -78,9 +78,9 @@ def test_faiss_load():
"password"
:
VEC_DB_PASSWORD
},
show_number
=
SIMILARITY_SHOW_NUMBER
,
reset
=
False
)
print
(
vecstore_faiss
.
_join_document
(
vecstore_faiss
.
get_text_similarity
(
"征信业务有什么情况
"
)))
print
(
vecstore_faiss
.
join_document
(
vecstore_faiss
.
get_text_similarity
(
"什么是低空飞行
"
)))
if
__name__
==
"__main__"
:
#
test_faiss_from_dir()
test_faiss_from_dir
()
test_faiss_load
()
test/lk_test.py
View file @
14a6b87b
from
functools
import
reduce
def
add_three
(
x
,
y
):
return
x
+
y
li
=
[
1
,
2
,
3
,
5
]
reduce
(
add_three
,
li
)
#=> 11
class
Solution
:
@staticmethod
def
numDecodings
(
s
:
str
)
->
int
:
length
=
len
(
s
)
ans
=
[
0
for
i
in
range
(
length
+
1
)]
ans
[
0
]
=
0
ans
[
1
]
=
1
for
i
in
range
(
1
,
length
+
1
):
print
(
i
)
if
s
[
i
-
1
]
+
s
[
i
]
==
'10'
or
s
[
i
-
1
]
+
s
[
i
]
==
'11'
or
s
[
i
-
1
]
+
s
[
i
]
==
'12'
or
s
[
i
-
1
]
+
s
[
i
]
==
'13'
or
s
[
i
-
1
]
+
s
[
i
]
==
'14'
or
s
[
i
-
1
]
+
s
[
i
]
==
'15'
or
s
[
i
-
1
]
+
s
[
i
]
==
'16'
or
s
[
i
-
1
]
+
s
[
i
]
==
'17'
or
s
[
i
-
1
]
+
s
[
i
]
==
'18'
or
s
[
i
-
1
]
+
s
[
i
]
==
'19'
or
s
[
i
-
1
]
+
s
[
i
]
==
'20'
or
s
[
i
-
1
]
+
s
[
i
]
==
'21'
or
s
[
i
-
1
]
+
s
[
i
]
==
'22'
or
s
[
i
-
1
]
+
s
[
i
]
==
'23'
or
s
[
i
-
1
]
+
s
[
i
]
==
'24'
or
s
[
i
-
1
]
+
s
[
i
]
==
'25'
or
s
[
i
-
1
]
+
s
[
i
]
==
'26'
:
if
s
[
i
]
==
'0'
:
ans
[
i
]
=
ans
[
i
-
1
]
+
1
else
:
ans
[
i
]
=
ans
[
i
-
1
]
+
2
else
:
ans
[
i
]
=
ans
[
i
-
1
]
+
1
print
(
ans
)
return
ans
[
length
-
1
]
Solution
.
numDecodings
(
"226"
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment