Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
493cdd59
Commit
493cdd59
authored
Apr 25, 2024
by
陈正乐
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
代码格式化
parent
9d8ee0af
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
28 changed files
with
409 additions
and
394 deletions
+409
-394
consts.py
src/config/consts.py
+9
-10
__init__.py
src/llm/__init__.py
+1
-2
baichuan.py
src/llm/baichuan.py
+13
-22
chatglm.py
src/llm/chatglm.py
+57
-55
chatglm_openapi.py
src/llm/chatglm_openapi.py
+14
-14
ernie.py
src/llm/ernie.py
+24
-31
ernie_sdk.py
src/llm/ernie_sdk.py
+14
-8
ernie_with_sdk.py
src/llm/ernie_with_sdk.py
+28
-29
loader.py
src/llm/loader.py
+14
-13
spark.py
src/llm/spark.py
+25
-32
wrapper.py
src/llm/wrapper.py
+8
-13
callback.py
src/loader/callback.py
+3
-3
chinese_text_splitter.py
src/loader/chinese_text_splitter.py
+3
-3
config.py
src/loader/config.py
+1
-2
load.py
src/loader/load.py
+0
-0
zh_title_enhance.py
src/loader/zh_title_enhance.py
+1
-1
c_db.py
src/pgdb/chat/c_db.py
+18
-17
c_user_table.py
src/pgdb/chat/c_user_table.py
+4
-3
chat_table.py
src/pgdb/chat/chat_table.py
+4
-3
turn_qa_table.py
src/pgdb/chat/turn_qa_table.py
+4
-3
callback.py
src/pgdb/knowledge/callback.py
+22
-19
k_db.py
src/pgdb/knowledge/k_db.py
+20
-17
pgsqldocstore.py
src/pgdb/knowledge/pgsqldocstore.py
+42
-34
similarity.py
src/pgdb/knowledge/similarity.py
+0
-0
txt_doc_table.py
src/pgdb/knowledge/txt_doc_table.py
+14
-11
vec_txt_table.py
src/pgdb/knowledge/vec_txt_table.py
+14
-10
chat_table_test.py
test/chat_table_test.py
+21
-15
k_store_test.py
test/k_store_test.py
+31
-24
No files found.
src/config/consts.py
View file @
493cdd59
...
@@ -2,19 +2,19 @@
...
@@ -2,19 +2,19 @@
# 资料存储数据库配置
# 资料存储数据库配置
# =============================
# =============================
VEC_DB_HOST
=
'localhost'
VEC_DB_HOST
=
'localhost'
VEC_DB_DBNAME
=
'lae'
VEC_DB_DBNAME
=
'lae'
VEC_DB_USER
=
'postgres'
VEC_DB_USER
=
'postgres'
VEC_DB_PASSWORD
=
'chenzl'
VEC_DB_PASSWORD
=
'chenzl'
VEC_DB_PORT
=
'5432'
VEC_DB_PORT
=
'5432'
# =============================
# =============================
# 聊天相关数据库配置
# 聊天相关数据库配置
# =============================
# =============================
CHAT_DB_HOST
=
'localhost'
CHAT_DB_HOST
=
'localhost'
CHAT_DB_DBNAME
=
'laechat'
CHAT_DB_DBNAME
=
'laechat'
CHAT_DB_USER
=
'postgres'
CHAT_DB_USER
=
'postgres'
CHAT_DB_PASSWORD
=
'chenzl'
CHAT_DB_PASSWORD
=
'chenzl'
CHAT_DB_PORT
=
'5432'
CHAT_DB_PORT
=
'5432'
# =============================
# =============================
# 向量化模型路径配置
# 向量化模型路径配置
...
@@ -41,4 +41,4 @@ INDEX_NAME = 'know'
...
@@ -41,4 +41,4 @@ INDEX_NAME = 'know'
# =============================
# =============================
# 知识相关资料配置
# 知识相关资料配置
# =============================
# =============================
KNOWLEDGE_PATH
=
'C:
\\
Users
\\
15663
\\
Desktop
\\
work
\\
llm_gjjs
\\
兴火燎原知识库
\\
兴火燎原知识库
\\
law
\\
pdf'
KNOWLEDGE_PATH
=
'C:
\\
Users
\\
15663
\\
Desktop
\\
work
\\
llm_gjjs
\\
兴火燎原知识库
\\
兴火燎原知识库
\\
law
\\
pdf'
\ No newline at end of file
src/llm/__init__.py
View file @
493cdd59
"""各种大模型提供的服务"""
"""各种大模型提供的服务"""
\ No newline at end of file
src/llm/baichuan.py
View file @
493cdd59
import
os
import
os
from
typing
import
Dict
,
Optional
,
List
from
typing
import
Dict
,
Optional
,
List
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
import
torch
import
torch
from
transformers
import
AutoTokenizer
,
AutoModel
,
AutoConfig
,
AutoModelForCausalLM
from
transformers
import
AutoTokenizer
,
AutoModel
,
AutoConfig
,
AutoModelForCausalLM
from
transformers.generation.utils
import
GenerationConfig
from
transformers.generation.utils
import
GenerationConfig
from
pydantic
import
root_validator
from
pydantic
import
root_validator
class
BaichuanLLM
(
LLM
):
class
BaichuanLLM
(
LLM
):
model_name
:
str
=
"baichuan-inc/Baichuan-13B-Chat"
model_name
:
str
=
"baichuan-inc/Baichuan-13B-Chat"
quantization_bit
:
Optional
[
int
]
=
None
quantization_bit
:
Optional
[
int
]
=
None
tokenizer
:
AutoTokenizer
=
None
tokenizer
:
AutoTokenizer
=
None
model
:
AutoModel
=
None
model
:
AutoModel
=
None
def
_llm_type
(
self
)
->
str
:
def
_llm_type
(
self
)
->
str
:
return
"chatglm_local"
return
"chatglm_local"
@root_validator
()
@root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
def
validate_environment
(
self
,
values
:
Dict
)
->
Dict
:
if
not
values
[
"model_name"
]:
if
not
values
[
"model_name"
]:
raise
ValueError
(
"No model name provided."
)
raise
ValueError
(
"No model name provided."
)
model_name
=
values
[
"model_name"
]
model_name
=
values
[
"model_name"
]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
use_fast
=
False
,
trust_remote_code
=
True
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
use_fast
=
False
,
trust_remote_code
=
True
)
model
=
AutoModelForCausalLM
.
from_pretrained
(
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
model_name
,
torch_dtype
=
torch
.
float16
,
torch_dtype
=
torch
.
float16
,
...
@@ -43,7 +40,7 @@ class BaichuanLLM(LLM):
...
@@ -43,7 +40,7 @@ class BaichuanLLM(LLM):
print
(
f
"Quantized to {values['quantization_bit']} bit"
)
print
(
f
"Quantized to {values['quantization_bit']} bit"
)
model
=
model
.
quantize
(
values
[
"quantization_bit"
])
.
cuda
()
model
=
model
.
quantize
(
values
[
"quantization_bit"
])
.
cuda
()
else
:
else
:
model
=
model
.
half
()
.
cuda
()
model
=
model
.
half
()
.
cuda
()
model
=
model
.
eval
()
model
=
model
.
eval
()
...
@@ -51,14 +48,9 @@ class BaichuanLLM(LLM):
...
@@ -51,14 +48,9 @@ class BaichuanLLM(LLM):
values
[
"model"
]
=
model
values
[
"model"
]
=
model
return
values
return
values
def
_call
(
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
self
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
**
kwargs
)
->
str
:
prompt
:
str
,
message
=
[{
"role"
:
"user"
,
"content"
:
prompt
}]
stop
:
Optional
[
List
[
str
]]
=
None
,
resp
=
self
.
model
.
chat
(
self
.
tokenizer
,
message
)
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
)
->
str
:
message
=
[]
message
.
append
({
"role"
:
"user"
,
"content"
:
prompt
})
resp
=
self
.
model
.
chat
(
self
.
tokenizer
,
message
)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return
resp
return
resp
\ No newline at end of file
src/llm/chatglm.py
View file @
493cdd59
import
os
import
os
import
requests
import
requests
from
typing
import
Dict
,
Optional
,
List
,
Any
,
Mapping
,
Iterator
from
typing
import
Dict
,
Optional
,
List
,
Any
,
Mapping
,
Iterator
from
pydantic
import
root_validator
from
pydantic
import
root_validator
import
torch
import
torch
from
transformers
import
AutoTokenizer
,
AutoModel
,
AutoConfig
from
transformers
import
AutoTokenizer
,
AutoModel
,
AutoConfig
import
langchain
import
langchain
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.cache
import
InMemoryCache
from
langchain.cache
import
InMemoryCache
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
,
AsyncCallbackManagerForLLMRun
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
,
AsyncCallbackManagerForLLMRun
import
aiohttp
import
aiohttp
import
asyncio
import
asyncio
# 启动llm的缓存
# 启动llm的缓存
# langchain.llm_cache = InMemoryCache()
# langchain.llm_cache = InMemoryCache()
...
@@ -26,17 +27,17 @@ class ChatGLMLocLLM(LLM):
...
@@ -26,17 +27,17 @@ class ChatGLMLocLLM(LLM):
tokenizer
:
AutoTokenizer
=
None
tokenizer
:
AutoTokenizer
=
None
model
:
AutoModel
=
None
model
:
AutoModel
=
None
def
_llm_type
(
self
)
->
str
:
def
_llm_type
(
self
)
->
str
:
return
"chatglm_local"
return
"chatglm_local"
# @root_validator()
# @root_validator()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
@staticmethod
def
validate_environment
(
values
:
Dict
)
->
Dict
:
if
not
values
[
"model_name"
]:
if
not
values
[
"model_name"
]:
raise
ValueError
(
"No model name provided."
)
raise
ValueError
(
"No model name provided."
)
model_name
=
values
[
"model_name"
]
model_name
=
values
[
"model_name"
]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
config
=
AutoConfig
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
# model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
# model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
if
values
[
"pre_seq_len"
]:
if
values
[
"pre_seq_len"
]:
...
@@ -56,7 +57,7 @@ class ChatGLMLocLLM(LLM):
...
@@ -56,7 +57,7 @@ class ChatGLMLocLLM(LLM):
model
.
transformer
.
prefix_encoder
.
load_state_dict
(
new_prefix_state_dict
)
model
.
transformer
.
prefix_encoder
.
load_state_dict
(
new_prefix_state_dict
)
else
:
else
:
model
=
AutoModel
.
from_pretrained
(
model_name
,
config
=
config
,
trust_remote_code
=
True
)
.
half
()
.
cuda
()
model
=
AutoModel
.
from_pretrained
(
model_name
,
config
=
config
,
trust_remote_code
=
True
)
.
half
()
.
cuda
()
if
values
[
"pre_seq_len"
]:
if
values
[
"pre_seq_len"
]:
# P-tuning v2
# P-tuning v2
model
=
model
.
half
()
.
cuda
()
model
=
model
.
half
()
.
cuda
()
...
@@ -64,7 +65,7 @@ class ChatGLMLocLLM(LLM):
...
@@ -64,7 +65,7 @@ class ChatGLMLocLLM(LLM):
if
values
[
"quantization_bit"
]:
if
values
[
"quantization_bit"
]:
print
(
f
"Quantized to {values['quantization_bit']} bit"
)
print
(
f
"Quantized to {values['quantization_bit']} bit"
)
model
=
model
.
quantize
(
values
[
"quantization_bit"
])
model
=
model
.
quantize
(
values
[
"quantization_bit"
])
model
=
model
.
eval
()
model
=
model
.
eval
()
...
@@ -72,30 +73,26 @@ class ChatGLMLocLLM(LLM):
...
@@ -72,30 +73,26 @@ class ChatGLMLocLLM(LLM):
values
[
"model"
]
=
model
values
[
"model"
]
=
model
return
values
return
values
def
_call
(
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
self
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
**
kwargs
)
->
str
:
prompt
:
str
,
resp
,
his
=
self
.
model
.
chat
(
self
.
tokenizer
,
prompt
)
stop
:
Optional
[
List
[
str
]]
=
None
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
)
->
str
:
resp
,
his
=
self
.
model
.
chat
(
self
.
tokenizer
,
prompt
)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return
resp
return
resp
class
ChatGLMSerLLM
(
LLM
):
class
ChatGLMSerLLM
(
LLM
):
# 模型服务url
# 模型服务url
url
:
str
=
"http://127.0.0.1:8000"
url
:
str
=
"http://127.0.0.1:8000"
chat_history
:
dict
=
[]
chat_history
:
dict
=
[]
out_stream
:
bool
=
False
out_stream
:
bool
=
False
cache
:
bool
=
False
cache
:
bool
=
False
@property
@property
def
_llm_type
(
self
)
->
str
:
def
_llm_type
(
self
)
->
str
:
return
"chatglm3-6b"
return
"chatglm3-6b"
def
get_num_tokens
(
self
,
text
:
str
)
->
int
:
def
get_num_tokens
(
self
,
text
:
str
)
->
int
:
resp
=
self
.
_post
(
url
=
self
.
url
+
"/tokens"
,
query
=
self
.
_construct_query
(
text
))
resp
=
self
.
_post
(
url
=
self
.
url
+
"/tokens"
,
query
=
self
.
_construct_query
(
text
))
if
resp
.
status_code
==
200
:
if
resp
.
status_code
==
200
:
resp_json
=
resp
.
json
()
resp_json
=
resp
.
json
()
predictions
=
resp_json
[
'response'
]
predictions
=
resp_json
[
'response'
]
...
@@ -103,28 +100,29 @@ class ChatGLMSerLLM(LLM):
...
@@ -103,28 +100,29 @@ class ChatGLMSerLLM(LLM):
return
predictions
return
predictions
else
:
else
:
return
len
(
text
)
return
len
(
text
)
@staticmethod
def
convert_data
(
self
,
data
):
def
convert_data
(
data
):
result
=
[]
result
=
[]
for
item
in
data
:
for
item
in
data
:
result
.
append
({
'q'
:
item
[
0
],
'a'
:
item
[
1
]})
result
.
append
({
'q'
:
item
[
0
],
'a'
:
item
[
1
]})
return
result
return
result
def
_construct_query
(
self
,
prompt
:
str
,
temperature
=
0.95
)
->
Dict
:
def
_construct_query
(
self
,
prompt
:
str
,
temperature
=
0.95
)
->
Dict
:
"""构造请求体
"""构造请求体
"""
"""
# self.chat_history.append({"role": "user", "content": prompt})
# self.chat_history.append({"role": "user", "content": prompt})
query
=
{
query
=
{
"prompt"
:
prompt
,
"prompt"
:
prompt
,
"history"
:
self
.
chat_history
,
"history"
:
self
.
chat_history
,
"max_length"
:
4096
,
"max_length"
:
4096
,
"top_p"
:
0.7
,
"top_p"
:
0.7
,
"temperature"
:
temperature
"temperature"
:
temperature
}
}
return
query
return
query
@classmethod
@classmethod
def
_post
(
self
,
url
:
str
,
def
_post
(
cls
,
url
:
str
,
query
:
Dict
)
->
Any
:
query
:
Dict
)
->
Any
:
"""POST请求
"""POST请求
"""
"""
...
@@ -135,51 +133,55 @@ class ChatGLMSerLLM(LLM):
...
@@ -135,51 +133,55 @@ class ChatGLMSerLLM(LLM):
headers
=
_headers
,
headers
=
_headers
,
timeout
=
300
)
timeout
=
300
)
return
resp
return
resp
async
def
_post_stream
(
self
,
url
:
str
,
query
:
Dict
,
@staticmethod
run_manager
:
Optional
[
AsyncCallbackManagerForLLMRun
]
=
None
,
stream
=
False
)
->
Any
:
async
def
_post_stream
(
url
:
str
,
query
:
Dict
,
run_manager
:
Optional
[
AsyncCallbackManagerForLLMRun
]
=
None
,
stream
=
False
)
->
Any
:
"""POST请求
"""POST请求
"""
"""
_headers
=
{
"Content_Type"
:
"application/json"
}
_headers
=
{
"Content_Type"
:
"application/json"
}
async
with
aiohttp
.
ClientSession
()
as
sess
:
async
with
aiohttp
.
ClientSession
()
as
sess
:
async
with
sess
.
post
(
url
,
json
=
query
,
headers
=
_headers
,
timeout
=
300
)
as
response
:
async
with
sess
.
post
(
url
,
json
=
query
,
headers
=
_headers
,
timeout
=
300
)
as
response
:
if
response
.
status
==
200
:
if
response
.
status
==
200
:
if
stream
and
not
run_manager
:
if
stream
and
not
run_manager
:
print
(
'not callable'
)
print
(
'not callable'
)
if
run_manager
:
if
run_manager
:
for
callable
in
run_manager
.
get_sync
()
.
handlers
:
for
_
callable
in
run_manager
.
get_sync
()
.
handlers
:
await
callable
.
on_llm_start
(
None
,
None
)
await
_callable
.
on_llm_start
(
None
,
None
)
async
for
chunk
in
response
.
content
.
iter_any
():
async
for
chunk
in
response
.
content
.
iter_any
():
# 处理每个块的数据
# 处理每个块的数据
if
chunk
and
run_manager
:
if
chunk
and
run_manager
:
for
callable
in
run_manager
.
get_sync
()
.
handlers
:
for
_
callable
in
run_manager
.
get_sync
()
.
handlers
:
# print(chunk.decode("utf-8"),end="")
# print(chunk.decode("utf-8"),end="")
await
callable
.
on_llm_new_token
(
chunk
.
decode
(
"utf-8"
))
await
_
callable
.
on_llm_new_token
(
chunk
.
decode
(
"utf-8"
))
if
run_manager
:
if
run_manager
:
for
callable
in
run_manager
.
get_sync
()
.
handlers
:
for
_
callable
in
run_manager
.
get_sync
()
.
handlers
:
await
callable
.
on_llm_end
(
None
)
await
_
callable
.
on_llm_end
(
None
)
else
:
else
:
raise
ValueError
(
f
'glm 请求异常,http code:{response.status}'
)
raise
ValueError
(
f
'glm 请求异常,http code:{response.status}'
)
def
_call
(
self
,
prompt
:
str
,
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
stop
:
Optional
[
List
[
str
]]
=
None
,
run_manager
:
Optional
[
AsyncCallbackManagerForLLMRun
]
=
None
,
run_manager
:
Optional
[
AsyncCallbackManagerForLLMRun
]
=
None
,
stream
=
False
,
stream
=
False
,
**
kwargs
:
Any
)
->
str
:
**
kwargs
:
Any
)
->
str
:
query
=
self
.
_construct_query
(
prompt
=
prompt
,
temperature
=
kwargs
[
"temperature"
]
if
"temperature"
in
kwargs
else
0.95
)
query
=
self
.
_construct_query
(
prompt
=
prompt
,
temperature
=
kwargs
[
"temperature"
]
if
"temperature"
in
kwargs
else
0.95
)
# display("==============================")
# display("==============================")
# display(query)
# display(query)
# post
# post
if
stream
or
self
.
out_stream
:
if
stream
or
self
.
out_stream
:
async
def
_post_stream
():
async
def
_post_stream
():
await
self
.
_post_stream
(
url
=
self
.
url
+
"/stream"
,
await
self
.
_post_stream
(
url
=
self
.
url
+
"/stream"
,
query
=
query
,
run_manager
=
run_manager
,
stream
=
stream
or
self
.
out_stream
)
query
=
query
,
run_manager
=
run_manager
,
stream
=
stream
or
self
.
out_stream
)
asyncio
.
run
(
_post_stream
())
asyncio
.
run
(
_post_stream
())
return
''
return
''
else
:
else
:
resp
=
self
.
_post
(
url
=
self
.
url
,
resp
=
self
.
_post
(
url
=
self
.
url
,
query
=
query
)
query
=
query
)
if
resp
.
status_code
==
200
:
if
resp
.
status_code
==
200
:
resp_json
=
resp
.
json
()
resp_json
=
resp
.
json
()
...
@@ -189,18 +191,19 @@ class ChatGLMSerLLM(LLM):
...
@@ -189,18 +191,19 @@ class ChatGLMSerLLM(LLM):
return
predictions
return
predictions
else
:
else
:
raise
ValueError
(
f
'glm 请求异常,http code:{resp.status_code}'
)
raise
ValueError
(
f
'glm 请求异常,http code:{resp.status_code}'
)
async
def
_acall
(
async
def
_acall
(
self
,
self
,
prompt
:
str
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
stop
:
Optional
[
List
[
str
]]
=
None
,
run_manager
:
Optional
[
AsyncCallbackManagerForLLMRun
]
=
None
,
run_manager
:
Optional
[
AsyncCallbackManagerForLLMRun
]
=
None
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
str
:
)
->
str
:
query
=
self
.
_construct_query
(
prompt
=
prompt
,
temperature
=
kwargs
[
"temperature"
]
if
"temperature"
in
kwargs
else
0.95
)
query
=
self
.
_construct_query
(
prompt
=
prompt
,
await
self
.
_post_stream
(
url
=
self
.
url
+
"/stream"
,
temperature
=
kwargs
[
"temperature"
]
if
"temperature"
in
kwargs
else
0.95
)
query
=
query
,
run_manager
=
run_manager
,
stream
=
self
.
out_stream
)
await
self
.
_post_stream
(
url
=
self
.
url
+
"/stream"
,
return
''
query
=
query
,
run_manager
=
run_manager
,
stream
=
self
.
out_stream
)
return
''
@property
@property
def
_identifying_params
(
self
)
->
Mapping
[
str
,
Any
]:
def
_identifying_params
(
self
)
->
Mapping
[
str
,
Any
]:
...
@@ -209,4 +212,4 @@ class ChatGLMSerLLM(LLM):
...
@@ -209,4 +212,4 @@ class ChatGLMSerLLM(LLM):
_param_dict
=
{
_param_dict
=
{
"url"
:
self
.
url
"url"
:
self
.
url
}
}
return
_param_dict
return
_param_dict
\ No newline at end of file
src/llm/chatglm_openapi.py
View file @
493cdd59
import
os
import
os
import
requests
import
requests
from
typing
import
Dict
,
Optional
,
List
,
Any
,
Mapping
,
Iterator
from
typing
import
Dict
,
Optional
,
List
,
Any
,
Mapping
,
Iterator
from
pydantic
import
root_validator
from
pydantic
import
root_validator
import
torch
import
torch
from
transformers
import
AutoTokenizer
,
AutoModel
,
AutoConfig
from
transformers
import
AutoTokenizer
,
AutoModel
,
AutoConfig
import
langchain
import
langchain
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain_openai
import
OpenAI
from
langchain_openai
import
OpenAI
from
langchain.cache
import
InMemoryCache
from
langchain.cache
import
InMemoryCache
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
,
AsyncCallbackManagerForLLMRun
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
,
AsyncCallbackManagerForLLMRun
class
ChatGLMSerLLM
(
OpenAI
):
class
ChatGLMSerLLM
(
OpenAI
):
def
get_token_ids
(
self
,
text
:
str
)
->
List
[
int
]:
def
get_token_ids
(
self
,
text
:
str
)
->
List
[
int
]:
if
self
.
model_name
.
__contains__
(
"chatglm"
):
if
self
.
model_name
.
__contains__
(
"chatglm"
):
## 发起http请求,获取token_ids
## 发起http请求,获取token_ids
url
=
f
"{self.openai_api_base}/num_tokens"
url
=
f
"{self.openai_api_base}/num_tokens"
query
=
{
"prompt"
:
text
,
"model"
:
self
.
model_name
}
query
=
{
"prompt"
:
text
,
"model"
:
self
.
model_name
}
_headers
=
{
"Content_Type"
:
"application/json"
,
"Authorization"
:
"chatglm "
+
self
.
openai_api_key
}
_headers
=
{
"Content_Type"
:
"application/json"
,
"Authorization"
:
"chatglm "
+
self
.
openai_api_key
}
resp
=
self
.
_post
(
url
=
url
,
query
=
query
,
headers
=
_headers
)
resp
=
self
.
_post
(
url
=
url
,
query
=
query
,
headers
=
_headers
)
if
resp
.
status_code
==
200
:
if
resp
.
status_code
==
200
:
resp_json
=
resp
.
json
()
resp_json
=
resp
.
json
()
print
(
resp_json
)
print
(
resp_json
)
...
@@ -30,10 +31,10 @@ class ChatGLMSerLLM(OpenAI):
...
@@ -30,10 +31,10 @@ class ChatGLMSerLLM(OpenAI):
## predictions字符串转int
## predictions字符串转int
return
[
int
(
predictions
)]
return
[
int
(
predictions
)]
return
[
len
(
text
)]
return
[
len
(
text
)]
@classmethod
@classmethod
def
_post
(
self
,
url
:
str
,
def
_post
(
cls
,
url
:
str
,
query
:
Dict
,
headers
:
Dict
)
->
Any
:
query
:
Dict
,
headers
:
Dict
)
->
Any
:
"""POST请求
"""POST请求
"""
"""
_headers
=
{
"Content_Type"
:
"application/json"
}
_headers
=
{
"Content_Type"
:
"application/json"
}
...
@@ -43,4 +44,4 @@ class ChatGLMSerLLM(OpenAI):
...
@@ -43,4 +44,4 @@ class ChatGLMSerLLM(OpenAI):
json
=
query
,
json
=
query
,
headers
=
_headers
,
headers
=
_headers
,
timeout
=
300
)
timeout
=
300
)
return
resp
return
resp
\ No newline at end of file
src/llm/ernie.py
View file @
493cdd59
...
@@ -2,7 +2,7 @@ import logging
...
@@ -2,7 +2,7 @@ import logging
import
os
import
os
from
typing
import
Any
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Mapping
,
Optional
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.schema
import
LLMResult
from
langchain.schema
import
LLMResult
from
langchain.utils
import
get_from_dict_or_env
from
langchain.utils
import
get_from_dict_or_env
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
...
@@ -15,6 +15,7 @@ from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_m
...
@@ -15,6 +15,7 @@ from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_m
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
class
ModelType
(
Enum
):
class
ModelType
(
Enum
):
ERNIE
=
"ernie"
ERNIE
=
"ernie"
ERNIE_LITE
=
"ernie-lite"
ERNIE_LITE
=
"ernie-lite"
...
@@ -25,7 +26,7 @@ class ModelType(Enum):
...
@@ -25,7 +26,7 @@ class ModelType(Enum):
LLAMA2_13B
=
"llama2-13b"
LLAMA2_13B
=
"llama2-13b"
LLAMA2_70B
=
"llama2-70b"
LLAMA2_70B
=
"llama2-70b"
QFCN_LLAMA2_7B
=
"qfcn-llama2-7b"
QFCN_LLAMA2_7B
=
"qfcn-llama2-7b"
BLOOMZ_7B
=
"bloomz-7b"
BLOOMZ_7B
=
"bloomz-7b"
MODEL_SERVICE_BASE_URL
=
"https://aip.baidubce.com/rpc/2.0/"
MODEL_SERVICE_BASE_URL
=
"https://aip.baidubce.com/rpc/2.0/"
...
@@ -43,6 +44,7 @@ MODEL_SERVICE_Suffix = {
...
@@ -43,6 +44,7 @@ MODEL_SERVICE_Suffix = {
ModelType
.
BLOOMZ_7B
:
"ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
,
ModelType
.
BLOOMZ_7B
:
"ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
,
}
}
class
ErnieLLM
(
LLM
):
class
ErnieLLM
(
LLM
):
"""
"""
ErnieLLM is a LLM that uses Ernie to generate text.
ErnieLLM is a LLM that uses Ernie to generate text.
...
@@ -52,27 +54,23 @@ class ErnieLLM(LLM):
...
@@ -52,27 +54,23 @@ class ErnieLLM(LLM):
access_token
:
Optional
[
str
]
=
""
access_token
:
Optional
[
str
]
=
""
@root_validator
()
@root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
def
validate_environment
(
self
,
values
:
Dict
)
->
Dict
:
"""Validate the environment."""
"""Validate the environment."""
# print(values)
# print(values)
model_name
=
ModelType
(
get_from_dict_or_env
(
values
,
"model_name"
,
"model_name"
,
str
(
ModelType
.
ERNIE
)))
model_name
=
ModelType
(
get_from_dict_or_env
(
values
,
"model_name"
,
"model_name"
,
str
(
ModelType
.
ERNIE
)))
access_token
=
get_from_dict_or_env
(
values
,
"access_token"
,
"ERNIE_ACCESS_TOKEN"
,
""
)
access_token
=
get_from_dict_or_env
(
values
,
"access_token"
,
"ERNIE_ACCESS_TOKEN"
,
""
)
if
not
access_token
:
if
not
access_token
:
raise
ValueError
(
"No access token provided."
)
raise
ValueError
(
"No access token provided."
)
values
[
"model_name"
]
=
model_name
values
[
"model_name"
]
=
model_name
values
[
"access_token"
]
=
access_token
values
[
"access_token"
]
=
access_token
return
values
return
values
def
_call
(
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
self
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
**
kwargs
)
->
str
:
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
request
=
CompletionRequest
(
messages
=
[
Message
(
"user"
,
prompt
)])
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
)
->
str
:
request
=
CompletionRequest
(
messages
=
[
Message
(
"user"
,
prompt
)])
bot
=
ErnieBot
(
_get_model_service_url
(
self
.
model_name
),
self
.
access_token
or
""
,
request
)
bot
=
ErnieBot
(
_get_model_service_url
(
self
.
model_name
),
self
.
access_token
or
""
,
request
)
try
:
try
:
# 你的代码
# 你的代码
...
@@ -81,9 +79,8 @@ class ErnieLLM(LLM):
...
@@ -81,9 +79,8 @@ class ErnieLLM(LLM):
return
response
return
response
except
Exception
as
e
:
except
Exception
as
e
:
# 处理异常
# 处理异常
print
(
"exception:"
,
e
)
print
(
"exception:"
,
e
)
return
e
.
__str__
()
return
e
.
__str__
()
@property
@property
def
_llm_type
(
self
)
->
str
:
def
_llm_type
(
self
)
->
str
:
...
@@ -94,28 +91,25 @@ class ErnieLLM(LLM):
...
@@ -94,28 +91,25 @@ class ErnieLLM(LLM):
# return {
# return {
# "name": "ernie",
# "name": "ernie",
# }
# }
def
_get_model_service_url
(
model_name
)
->
str
:
def
_get_model_service_url
(
model_name
)
->
str
:
# print("_get_model_service_url model_name: ",model_name)
# print("_get_model_service_url model_name: ",model_name)
return
MODEL_SERVICE_BASE_URL
+
MODEL_SERVICE_Suffix
[
model_name
]
return
MODEL_SERVICE_BASE_URL
+
MODEL_SERVICE_Suffix
[
model_name
]
class
ErnieChat
(
LLM
):
class
ErnieChat
(
LLM
):
model_name
:
ModelType
model_name
:
ModelType
access_token
:
str
access_token
:
str
prefix_messages
:
List
=
Field
(
default_factory
=
list
)
prefix_messages
:
List
=
Field
(
default_factory
=
list
)
id
:
str
=
""
id
:
str
=
""
def
_call
(
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
self
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
**
kwargs
)
->
str
:
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
)
->
str
:
msg
=
user_message
(
prompt
)
msg
=
user_message
(
prompt
)
request
=
CompletionRequest
(
messages
=
self
.
prefix_messages
+
[
msg
])
request
=
CompletionRequest
(
messages
=
self
.
prefix_messages
+
[
msg
])
bot
=
ErnieBot
(
_get_model_service_url
(
self
.
model_name
),
self
.
access_token
,
request
)
bot
=
ErnieBot
(
_get_model_service_url
(
self
.
model_name
),
self
.
access_token
,
request
)
try
:
try
:
# 你的代码
# 你的代码
response
=
bot
.
get_response
()
.
result
response
=
bot
.
get_response
()
.
result
...
@@ -127,11 +121,11 @@ class ErnieChat(LLM):
...
@@ -127,11 +121,11 @@ class ErnieChat(LLM):
except
Exception
as
e
:
except
Exception
as
e
:
# 处理异常
# 处理异常
raise
e
raise
e
def
_get_id
(
self
)
->
str
:
def
_get_id
(
self
)
->
str
:
return
self
.
id
return
self
.
id
@property
@property
def
_llm_type
(
self
)
->
str
:
def
_llm_type
(
self
)
->
str
:
"""Return type of llm."""
"""Return type of llm."""
return
"ernie"
return
"ernie"
\ No newline at end of file
src/llm/ernie_sdk.py
View file @
493cdd59
from
dataclasses
import
asdict
,
dataclass
from
dataclasses
import
asdict
,
dataclass
from
typing
import
List
from
typing
import
List
...
@@ -7,27 +5,32 @@ from pydantic import BaseModel, Field
...
@@ -7,27 +5,32 @@ from pydantic import BaseModel, Field
from
enum
import
Enum
from
enum
import
Enum
class
MessageRole
(
str
,
Enum
):
class
MessageRole
(
str
,
Enum
):
USER
=
"user"
USER
=
"user"
BOT
=
"assistant"
BOT
=
"assistant"
@dataclass
@dataclass
class
Message
:
class
Message
:
role
:
str
role
:
str
content
:
str
content
:
str
@dataclass
@dataclass
class
CompletionRequest
:
class
CompletionRequest
:
messages
:
List
[
Message
]
messages
:
List
[
Message
]
stream
:
bool
=
False
stream
:
bool
=
False
user
:
str
=
""
user
:
str
=
""
@dataclass
@dataclass
class
Usage
:
class
Usage
:
prompt_tokens
:
int
prompt_tokens
:
int
completion_tokens
:
int
completion_tokens
:
int
total_tokens
:
int
total_tokens
:
int
@dataclass
@dataclass
class
CompletionResponse
:
class
CompletionResponse
:
id
:
str
id
:
str
...
@@ -42,12 +45,14 @@ class CompletionResponse:
...
@@ -42,12 +45,14 @@ class CompletionResponse:
is_safe
:
bool
=
False
is_safe
:
bool
=
False
is_truncated
:
bool
=
False
is_truncated
:
bool
=
False
class
ErrorResponse
(
BaseModel
):
class
ErrorResponse
(
BaseModel
):
error_code
:
int
=
Field
(
...
)
error_code
:
int
=
Field
(
...
)
error_msg
:
str
=
Field
(
...
)
error_msg
:
str
=
Field
(
...
)
id
:
str
=
Field
(
...
)
id
:
str
=
Field
(
...
)
class
ErnieBot
():
class
ErnieBot
:
url
:
str
url
:
str
access_token
:
str
access_token
:
str
request
:
CompletionRequest
request
:
CompletionRequest
...
@@ -64,17 +69,19 @@ class ErnieBot():
...
@@ -64,17 +69,19 @@ class ErnieBot():
headers
=
{
'Content-Type'
:
'application/json'
}
headers
=
{
'Content-Type'
:
'application/json'
}
params
=
{
'access_token'
:
self
.
access_token
}
params
=
{
'access_token'
:
self
.
access_token
}
request_dict
=
asdict
(
self
.
request
)
request_dict
=
asdict
(
self
.
request
)
response
=
requests
.
post
(
self
.
url
,
params
=
params
,
data
=
json
.
dumps
(
request_dict
),
headers
=
headers
)
response
=
requests
.
post
(
self
.
url
,
params
=
params
,
data
=
json
.
dumps
(
request_dict
),
headers
=
headers
)
# print(response.json())
# print(response.json())
try
:
try
:
return
CompletionResponse
(
**
response
.
json
())
return
CompletionResponse
(
**
response
.
json
())
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
print
(
e
)
raise
Exception
(
response
.
json
())
raise
Exception
(
response
.
json
())
def
user_message
(
prompt
:
str
)
->
Message
:
def
user_message
(
prompt
:
str
)
->
Message
:
return
Message
(
MessageRole
.
USER
,
prompt
)
return
Message
(
MessageRole
.
USER
,
prompt
)
def
bot_message
(
prompt
:
str
)
->
Message
:
def
bot_message
(
prompt
:
str
)
->
Message
:
return
Message
(
MessageRole
.
BOT
,
prompt
)
return
Message
(
MessageRole
.
BOT
,
prompt
)
\ No newline at end of file
src/llm/ernie_with_sdk.py
View file @
493cdd59
import
os
import
os
import
requests
import
requests
from
typing
import
Dict
,
Optional
,
List
,
Any
,
Mapping
,
Iterator
from
typing
import
Dict
,
Optional
,
List
,
Any
,
Mapping
,
Iterator
from
pydantic
import
root_validator
from
pydantic
import
root_validator
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.cache
import
InMemoryCache
from
langchain.cache
import
InMemoryCache
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
,
AsyncCallbackManagerForLLMRun
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
,
AsyncCallbackManagerForLLMRun
import
qianfan
import
qianfan
from
qianfan
import
ChatCompletion
from
qianfan
import
ChatCompletion
# 启动llm的缓存
# 启动llm的缓存
# langchain.llm_cache = InMemoryCache()
# langchain.llm_cache = InMemoryCache()
class
ChatERNIESerLLM
(
LLM
):
class
ChatERNIESerLLM
(
LLM
):
# 模型服务url
# 模型服务url
chat_completion
:
ChatCompletion
=
None
chat_completion
:
ChatCompletion
=
None
# url: str = "http://127.0.0.1:8000"
# url: str = "http://127.0.0.1:8000"
chat_history
:
dict
=
[]
chat_history
:
dict
=
[]
out_stream
:
bool
=
False
out_stream
:
bool
=
False
cache
:
bool
=
False
cache
:
bool
=
False
model_name
:
str
=
"ERNIE-Bot"
model_name
:
str
=
"ERNIE-Bot"
# def __init__(self):
# def __init__(self):
# self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")
# self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")
@property
def _llm_type(self) -> str:
    """LangChain LLM type tag; here the configured ERNIE model name is reused."""
    return self.model_name
def get_num_tokens(self, text: str) -> int:
    """Return a token count for *text*.

    NOTE(review): this uses the character count as a cheap stand-in for a
    real tokenizer-based count — confirm this approximation is acceptable.
    """
    return len(text)
@staticmethod
def convert_data(data):
    """Convert (question, answer) pairs into [{'q': ..., 'a': ...}, ...]."""
    return [{'q': pair[0], 'a': pair[1]} for pair in data]
def
_call
(
self
,
prompt
:
str
,
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
stop
:
Optional
[
List
[
str
]]
=
None
,
run_manager
:
Optional
[
AsyncCallbackManagerForLLMRun
]
=
None
,
run_manager
:
Optional
[
AsyncCallbackManagerForLLMRun
]
=
None
,
stream
=
False
,
stream
=
False
,
**
kwargs
:
Any
)
->
str
:
**
kwargs
:
Any
)
->
str
:
resp
=
self
.
chat_completion
.
do
(
model
=
self
.
model_name
,
messages
=
[{
resp
=
self
.
chat_completion
.
do
(
model
=
self
.
model_name
,
messages
=
[{
"role"
:
"user"
,
"role"
:
"user"
,
"content"
:
prompt
"content"
:
prompt
}])
}])
...
@@ -54,26 +54,26 @@ class ChatERNIESerLLM(LLM):
...
@@ -54,26 +54,26 @@ class ChatERNIESerLLM(LLM):
return
resp
.
body
[
"result"
]
return
resp
.
body
[
"result"
]
async def _post_stream(self,
                       query: Dict,
                       run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                       stream=False) -> Any:
    """Issue an (optionally streaming) POST request to the qianfan chat API.

    Each received chunk's text is forwarded to every registered callback
    handler via ``on_llm_new_token``; nothing is returned directly.
    """
    async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
        # qianfan responses carry an HTTP-like status code per chunk.
        assert r.code == 200
        if run_manager:
            # Fan each token out to the synchronous handlers of the run manager.
            for _callable in run_manager.get_sync().handlers:
                await _callable.on_llm_new_token(r.body["result"])
async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
) -> str:
    """Async generation entry point: stream the answer through callbacks.

    NOTE(review): the full text is delivered only via the callback handlers
    inside ``_post_stream``; this method always returns the empty string —
    confirm callers rely on callbacks rather than the return value.
    """
    await self._post_stream(query={
        "role": "user",
        "content": prompt
    }, stream=True, run_manager=run_manager)
    return ''
\ No newline at end of file
src/llm/loader.py
View file @
493cdd59
...
@@ -4,6 +4,7 @@ import torch
...
@@ -4,6 +4,7 @@ import torch
from
transformers
import
AutoModel
,
AutoTokenizer
,
AutoConfig
,
DataCollatorForSeq2Seq
from
transformers
import
AutoModel
,
AutoTokenizer
,
AutoConfig
,
DataCollatorForSeq2Seq
from
peft
import
PeftModel
from
peft
import
PeftModel
class
ModelLoader
:
class
ModelLoader
:
def
__init__
(
self
,
model_name_or_path
,
pre_seq_len
=
0
,
prefix_projection
=
False
):
def
__init__
(
self
,
model_name_or_path
,
pre_seq_len
=
0
,
prefix_projection
=
False
):
self
.
config
=
AutoConfig
.
from_pretrained
(
model_name_or_path
,
trust_remote_code
=
True
)
self
.
config
=
AutoConfig
.
from_pretrained
(
model_name_or_path
,
trust_remote_code
=
True
)
...
@@ -23,18 +24,19 @@ class ModelLoader:
...
@@ -23,18 +24,19 @@ class ModelLoader:
def models(self):
    """Return the loaded (model, tokenizer) pair."""
    return self.model, self.tokenizer
def collator(self):
    """Build a seq2seq data collator bound to this loader's tokenizer/model."""
    return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)
def load_lora(self, ckpt_path, name="default"):
    """Load a LoRA adapter from *ckpt_path* and merge it into the base model.

    Merging (``merge_and_unload``) folds the adapter weights into the base
    weights, which saves GPU memory compared to keeping the adapter attached.
    """
    peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
    self.model = peft_loaded.merge_and_unload()
    # Fixed: no placeholders here, so the f-string prefix was unnecessary.
    print("Load LoRA model successfully!")
def
load_loras
(
self
,
ckpt_paths
,
name
=
"default"
):
def
load_loras
(
self
,
ckpt_paths
,
name
=
"default"
):
if
len
(
ckpt_paths
)
==
0
:
global
peft_loaded
if
len
(
ckpt_paths
)
==
0
:
return
return
first
=
True
first
=
True
for
name
,
path
in
ckpt_paths
.
items
():
for
name
,
path
in
ckpt_paths
.
items
():
...
@@ -42,12 +44,12 @@ class ModelLoader:
...
@@ -42,12 +44,12 @@ class ModelLoader:
if
first
:
if
first
:
peft_loaded
=
PeftModel
.
from_pretrained
(
self
.
base_model
,
path
,
adapter_name
=
name
)
peft_loaded
=
PeftModel
.
from_pretrained
(
self
.
base_model
,
path
,
adapter_name
=
name
)
first
=
False
first
=
False
else
:
else
:
peft_loaded
.
load_adapter
(
path
,
adapter_name
=
name
)
peft_loaded
.
load_adapter
(
path
,
adapter_name
=
name
)
peft_loaded
.
set_adapter
(
name
)
peft_loaded
.
set_adapter
(
name
)
self
.
model
=
peft_loaded
self
.
model
=
peft_loaded
def
load_prefix
(
self
,
ckpt_path
):
def
load_prefix
(
self
,
ckpt_path
):
prefix_state_dict
=
torch
.
load
(
os
.
path
.
join
(
ckpt_path
,
"pytorch_model.bin"
))
prefix_state_dict
=
torch
.
load
(
os
.
path
.
join
(
ckpt_path
,
"pytorch_model.bin"
))
new_prefix_state_dict
=
{}
new_prefix_state_dict
=
{}
for
k
,
v
in
prefix_state_dict
.
items
():
for
k
,
v
in
prefix_state_dict
.
items
():
...
@@ -56,4 +58,3 @@ class ModelLoader:
...
@@ -56,4 +58,3 @@ class ModelLoader:
self
.
model
.
transformer
.
prefix_encoder
.
load_state_dict
(
new_prefix_state_dict
)
self
.
model
.
transformer
.
prefix_encoder
.
load_state_dict
(
new_prefix_state_dict
)
self
.
model
.
transformer
.
prefix_encoder
.
float
()
self
.
model
.
transformer
.
prefix_encoder
.
float
()
print
(
f
"Load prefix model successfully!"
)
print
(
f
"Load prefix model successfully!"
)
src/llm/spark.py
View file @
493cdd59
...
@@ -2,7 +2,7 @@ import logging
...
@@ -2,7 +2,7 @@ import logging
import
os
import
os
from
typing
import
Any
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Mapping
,
Optional
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.schema
import
LLMResult
from
langchain.schema
import
LLMResult
from
langchain.utils
import
get_from_dict_or_env
from
langchain.utils
import
get_from_dict_or_env
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
...
@@ -16,18 +16,17 @@ from .xinghuo.ws import SparkAPI
...
@@ -16,18 +16,17 @@ from .xinghuo.ws import SparkAPI
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
text
=
[]
text
=
[]
# length = 0
# length = 0
def getText(role, content):
    """Append a {role, content} message to the module-level history and return it.

    NOTE(review): this mutates the module-global ``text`` list, so messages
    accumulate across calls and across users of this module — confirm that
    shared, ever-growing history is intended.
    """
    jsoncon = {"role": role, "content": content}
    text.append(jsoncon)
    return text
def
getlength
(
text
):
def
getlength
(
text
):
length
=
0
length
=
0
for
content
in
text
:
for
content
in
text
:
...
@@ -36,11 +35,12 @@ def getlength(text):
...
@@ -36,11 +35,12 @@ def getlength(text):
length
+=
leng
length
+=
leng
return
length
return
length
def checklen(_text):
    """Drop the oldest messages until the total length fits the 8000-char budget."""
    while getlength(_text) > 8000:
        _text.pop(0)
    return _text
class
SparkLLM
(
LLM
):
class
SparkLLM
(
LLM
):
"""
"""
...
@@ -62,16 +62,16 @@ class SparkLLM(LLM):
...
@@ -62,16 +62,16 @@ class SparkLLM(LLM):
None
,
None
,
description
=
"version"
,
description
=
"version"
,
)
)
api
:
SparkAPI
=
Field
(
api
:
SparkAPI
=
Field
(
None
,
None
,
description
=
"api"
,
description
=
"api"
,
)
)
@root_validator
()
@root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
def
validate_environment
(
self
,
values
:
Dict
)
->
Dict
:
"""Validate the environment."""
"""Validate the environment."""
# print(values)
# print(values)
appid
=
get_from_dict_or_env
(
values
,
"appid"
,
"XH_APPID"
,
""
)
appid
=
get_from_dict_or_env
(
values
,
"appid"
,
"XH_APPID"
,
""
)
api_key
=
get_from_dict_or_env
(
values
,
"api_key"
,
"XH_API_KEY"
,
""
)
api_key
=
get_from_dict_or_env
(
values
,
"api_key"
,
"XH_API_KEY"
,
""
)
api_secret
=
get_from_dict_or_env
(
values
,
"api_secret"
,
"XH_API_SECRET"
,
""
)
api_secret
=
get_from_dict_or_env
(
values
,
"api_secret"
,
"XH_API_SECRET"
,
""
)
...
@@ -84,23 +84,19 @@ class SparkLLM(LLM):
...
@@ -84,23 +84,19 @@ class SparkLLM(LLM):
raise
ValueError
(
"No api_key provided."
)
raise
ValueError
(
"No api_key provided."
)
if
not
api_secret
:
if
not
api_secret
:
raise
ValueError
(
"No api_secret provided."
)
raise
ValueError
(
"No api_secret provided."
)
values
[
"appid"
]
=
appid
values
[
"appid"
]
=
appid
values
[
"api_key"
]
=
api_key
values
[
"api_key"
]
=
api_key
values
[
"api_secret"
]
=
api_secret
values
[
"api_secret"
]
=
api_secret
api
=
SparkAPI
(
appid
,
api_key
,
api_secret
,
version
)
api
=
SparkAPI
(
appid
,
api_key
,
api_secret
,
version
)
values
[
"api"
]
=
api
values
[
"api"
]
=
api
return
values
return
values
def
_call
(
def
_call
(
self
,
prompt
:
str
,
stop
:
Optional
[
List
[
str
]]
=
None
,
self
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
**
kwargs
)
->
str
:
prompt
:
str
,
question
=
self
.
getText
(
"user"
,
prompt
)
stop
:
Optional
[
List
[
str
]]
=
None
,
run_manager
:
Optional
[
CallbackManagerForLLMRun
]
=
None
,
)
->
str
:
question
=
self
.
getText
(
"user"
,
prompt
)
try
:
try
:
# 你的代码
# 你的代码
# SparkApi.main(self.appid,self.api_key,self.api_secret,self.Spark_url,self.domain,question)
# SparkApi.main(self.appid,self.api_key,self.api_secret,self.Spark_url,self.domain,question)
...
@@ -109,20 +105,18 @@ class SparkLLM(LLM):
...
@@ -109,20 +105,18 @@ class SparkLLM(LLM):
return
response
return
response
except
Exception
as
e
:
except
Exception
as
e
:
# 处理异常
# 处理异常
print
(
"exception:"
,
e
)
print
(
"exception:"
,
e
)
raise
e
raise
e
def getText(self, role, content):
    """Build a fresh single-message conversation list for the Spark API."""
    message = {"role": role, "content": content}
    return [message]
@property
def _llm_type(self) -> str:
    """Return type of llm (LangChain identifier for this backend)."""
    return "xinghuo"
\ No newline at end of file
src/llm/wrapper.py
View file @
493cdd59
from
langchain.llms.base
import
LLM
from
langchain.llms.base
import
LLM
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
from
pydantic
import
root_validator
from
pydantic
import
root_validator
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
from
transformers
import
PreTrainedModel
,
PreTrainedTokenizer
from
transformers
import
PreTrainedModel
,
PreTrainedTokenizer
class
WrapperLLM
(
LLM
):
class
WrapperLLM
(
LLM
):
tokenizer
:
PreTrainedTokenizer
=
None
tokenizer
:
PreTrainedTokenizer
=
None
model
:
PreTrainedModel
=
None
model
:
PreTrainedModel
=
None
@root_validator
()
@root_validator
()
def
validate_environment
(
cls
,
values
:
Dict
)
->
Dict
:
def
validate_environment
(
self
,
values
:
Dict
)
->
Dict
:
"""Validate the environment."""
"""Validate the environment."""
# print(values)
# print(values)
if
values
.
get
(
"model"
)
is
None
:
if
values
.
get
(
"model"
)
is
None
:
...
@@ -19,16 +19,12 @@ class WrapperLLM(LLM):
...
@@ -19,16 +19,12 @@ class WrapperLLM(LLM):
raise
ValueError
(
"No tokenizer provided."
)
raise
ValueError
(
"No tokenizer provided."
)
return
values
return
values
def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
    """Generate a reply by delegating to the wrapped HF model's ``chat``.

    The conversation history returned by ``chat`` is intentionally discarded:
    every call is treated as a fresh single-turn exchange.
    """
    resp, _history = self.model.chat(self.tokenizer, prompt)
    return resp
@property
def _llm_type(self) -> str:
    """Return type of llm (LangChain identifier for this wrapper)."""
    return "wrapper"
\ No newline at end of file
src/loader/callback.py
View file @
493cdd59
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
class BaseCallback(ABC):
    """Hook for filtering document paragraphs before they are processed."""

    @abstractmethod
    def filter(self, title: str, content: str) -> bool:
        # Return True to discard the current paragraph.
        pass
\ No newline at end of file
src/loader/chinese_text_splitter.py
View file @
493cdd59
...
@@ -25,7 +25,7 @@ class ChineseTextSplitter(CharacterTextSplitter):
...
@@ -25,7 +25,7 @@ class ChineseTextSplitter(CharacterTextSplitter):
sent_list
.
append
(
ele
)
sent_list
.
append
(
ele
)
return
sent_list
return
sent_list
def
split_text
(
self
,
text
:
str
)
->
List
[
str
]:
##此处需要进一步优化逻辑
def
split_text
(
self
,
text
:
str
)
->
List
[
str
]:
##此处需要进一步优化逻辑
if
self
.
pdf
:
if
self
.
pdf
:
text
=
re
.
sub
(
r"\n{3,}"
,
r"\n"
,
text
)
text
=
re
.
sub
(
r"\n{3,}"
,
r"\n"
,
text
)
text
=
re
.
sub
(
'
\
s'
,
" "
,
text
)
text
=
re
.
sub
(
'
\
s'
,
" "
,
text
)
...
@@ -56,6 +56,6 @@ class ChineseTextSplitter(CharacterTextSplitter):
...
@@ -56,6 +56,6 @@ class ChineseTextSplitter(CharacterTextSplitter):
ele_id
=
ele1_ls
.
index
(
ele_ele1
)
ele_id
=
ele1_ls
.
index
(
ele_ele1
)
ele1_ls
=
ele1_ls
[:
ele_id
]
+
[
i
for
i
in
ele2_ls
if
i
]
+
ele1_ls
[
ele_id
+
1
:]
ele1_ls
=
ele1_ls
[:
ele_id
]
+
[
i
for
i
in
ele2_ls
if
i
]
+
ele1_ls
[
ele_id
+
1
:]
id
=
ls
.
index
(
ele
)
_
id
=
ls
.
index
(
ele
)
ls
=
ls
[:
id
]
+
[
i
for
i
in
ele1_ls
if
i
]
+
ls
[
id
+
1
:]
ls
=
ls
[:
_id
]
+
[
i
for
i
in
ele1_ls
if
i
]
+
ls
[
_
id
+
1
:]
return
ls
return
ls
src/loader/config.py
View file @
493cdd59
# 文本分句长度
# 文本分句长度
SENTENCE_SIZE
=
100
SENTENCE_SIZE
=
100
ZH_TITLE_ENHANCE
=
False
ZH_TITLE_ENHANCE
=
False
\ No newline at end of file
src/loader/load.py
View file @
493cdd59
This diff is collapsed.
Click to expand it.
src/loader/zh_title_enhance.py
View file @
493cdd59
...
@@ -33,7 +33,7 @@ def is_possible_title(
...
@@ -33,7 +33,7 @@ def is_possible_title(
title_max_word_length
:
int
=
20
,
title_max_word_length
:
int
=
20
,
non_alpha_threshold
:
float
=
0.5
,
non_alpha_threshold
:
float
=
0.5
,
)
->
bool
:
)
->
bool
:
"""Checks to see if the text passes all
of
the checks for a valid title.
"""Checks to see if the text passes all the checks for a valid title.
Parameters
Parameters
----------
----------
...
...
src/pgdb/chat/c_db.py
View file @
493cdd59
import
psycopg2
import
psycopg2
from
psycopg2
import
OperationalError
,
InterfaceError
from
psycopg2
import
OperationalError
,
InterfaceError
class
UPostgresDB
:
class
UPostgresDB
:
'''
"""
psycopg2.connect(
psycopg2.connect(
dsn #指定连接参数。可以使用参数形式或 DSN 形式指定。
dsn #指定连接参数。可以使用参数形式或 DSN 形式指定。
host #指定连接数据库的主机名。
host #指定连接数据库的主机名。
...
@@ -18,8 +19,9 @@ class UPostgresDB:
...
@@ -18,8 +19,9 @@ class UPostgresDB:
sslkey #指定私钥文件名。
sslkey #指定私钥文件名。
sslcert #指定公钥文件名。
sslcert #指定公钥文件名。
)
)
'''
"""
def
__init__
(
self
,
host
,
database
,
user
,
password
,
port
=
5432
):
def
__init__
(
self
,
host
,
database
,
user
,
password
,
port
=
5432
):
self
.
host
=
host
self
.
host
=
host
self
.
database
=
database
self
.
database
=
database
self
.
user
=
user
self
.
user
=
user
...
@@ -35,7 +37,7 @@ class UPostgresDB:
...
@@ -35,7 +37,7 @@ class UPostgresDB:
database
=
self
.
database
,
database
=
self
.
database
,
user
=
self
.
user
,
user
=
self
.
user
,
password
=
self
.
password
,
password
=
self
.
password
,
port
=
self
.
port
port
=
self
.
port
)
)
self
.
cur
=
self
.
conn
.
cursor
()
self
.
cur
=
self
.
conn
.
cursor
()
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -45,7 +47,7 @@ class UPostgresDB:
...
@@ -45,7 +47,7 @@ class UPostgresDB:
try
:
try
:
if
self
.
conn
is
None
or
self
.
conn
.
closed
:
if
self
.
conn
is
None
or
self
.
conn
.
closed
:
self
.
connect
()
self
.
connect
()
self
.
cur
.
execute
(
query
)
self
.
cur
.
execute
(
query
)
self
.
conn
.
commit
()
self
.
conn
.
commit
()
except
InterfaceError
as
e
:
except
InterfaceError
as
e
:
print
(
f
"数据库连接已经关闭: {e}"
)
print
(
f
"数据库连接已经关闭: {e}"
)
...
@@ -53,8 +55,8 @@ class UPostgresDB:
...
@@ -53,8 +55,8 @@ class UPostgresDB:
print
(
f
"数据库连接出现问题: {e}"
)
print
(
f
"数据库连接出现问题: {e}"
)
self
.
connect
()
self
.
connect
()
self
.
retry_execute
(
query
)
self
.
retry_execute
(
query
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"执行sql语句出现错误: {e}"
)
print
(
f
"执行sql语句出现错误: {e}"
)
self
.
conn
.
rollback
()
self
.
conn
.
rollback
()
def
retry_execute
(
self
,
query
):
def
retry_execute
(
self
,
query
):
...
@@ -69,16 +71,16 @@ class UPostgresDB:
...
@@ -69,16 +71,16 @@ class UPostgresDB:
try
:
try
:
if
self
.
conn
is
None
or
self
.
conn
.
closed
:
if
self
.
conn
is
None
or
self
.
conn
.
closed
:
self
.
connect
()
self
.
connect
()
self
.
cur
.
execute
(
query
,
args
)
self
.
cur
.
execute
(
query
,
args
)
self
.
conn
.
commit
()
self
.
conn
.
commit
()
except
InterfaceError
as
e
:
except
InterfaceError
as
e
:
print
(
f
"数据库连接已经关闭: {e}"
)
print
(
f
"数据库连接已经关闭: {e}"
)
except
OperationalError
as
e
:
except
OperationalError
as
e
:
print
(
f
"数据库操作出现问题: {e}"
)
print
(
f
"数据库操作出现问题: {e}"
)
self
.
connect
()
self
.
connect
()
self
.
retry_execute_args
(
query
,
args
)
self
.
retry_execute_args
(
query
,
args
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"执行sql语句出现错误: {e}"
)
print
(
f
"执行sql语句出现错误: {e}"
)
self
.
conn
.
rollback
()
self
.
conn
.
rollback
()
def
retry_execute_args
(
self
,
query
,
args
):
def
retry_execute_args
(
self
,
query
,
args
):
...
@@ -89,7 +91,6 @@ class UPostgresDB:
...
@@ -89,7 +91,6 @@ class UPostgresDB:
print
(
f
"重新执行sql语句再次出现错误: {type(e).__name__}: {e}"
)
print
(
f
"重新执行sql语句再次出现错误: {type(e).__name__}: {e}"
)
self
.
conn
.
rollback
()
self
.
conn
.
rollback
()
def
search
(
self
,
query
,
params
=
None
):
def
search
(
self
,
query
,
params
=
None
):
if
self
.
conn
is
None
or
self
.
conn
.
closed
:
if
self
.
conn
is
None
or
self
.
conn
.
closed
:
self
.
connect
()
self
.
connect
()
...
@@ -97,7 +98,7 @@ class UPostgresDB:
...
@@ -97,7 +98,7 @@ class UPostgresDB:
def fetchall(self):
    """Return all rows of the most recently executed query."""
    return self.cur.fetchall()
def fetchone(self):
    """Return the next row of the most recently executed query, or None."""
    return self.cur.fetchone()
...
@@ -109,8 +110,8 @@ class UPostgresDB:
...
@@ -109,8 +110,8 @@ class UPostgresDB:
try
:
try
:
if
self
.
conn
is
None
or
self
.
conn
.
closed
:
if
self
.
conn
is
None
or
self
.
conn
.
closed
:
self
.
connect
()
self
.
connect
()
self
.
cur
.
execute
(
query
)
self
.
cur
.
execute
(
query
)
self
.
conn
.
commit
()
self
.
conn
.
commit
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"An error occurred: {e}"
)
print
(
f
"An error occurred: {e}"
)
self
.
conn
.
rollback
()
self
.
conn
.
rollback
()
src/pgdb/chat/c_user_table.py
View file @
493cdd59
from
.c_db
import
UPostgresDB
from
.c_db
import
UPostgresDB
import
json
import
json
TABLE_USER
=
"""
TABLE_USER
=
"""
DROP TABLE IF EXISTS "c_user";
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
CREATE TABLE c_user (
...
@@ -13,14 +14,15 @@ COMMENT ON COLUMN "c_user"."password" IS '用户密码';
...
@@ -13,14 +14,15 @@ COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
COMMENT ON TABLE "c_user" IS '用户表';
"""
"""
class CUser:
    """DAO for the ``c_user`` table (user accounts)."""

    def __init__(self, db: UPostgresDB) -> None:
        # Shared Postgres connection wrapper used for all statements.
        self.db = db

    def insert(self, value):
        """Insert one user row; *value* is (user_id, account, password)."""
        # Fixed: plain string (no f-prefix needed) — psycopg2 fills the %s
        # placeholders from the args tuple.
        query = "INSERT INTO c_user(user_id, account, password) VALUES (%s, %s, %s)"
        self.db.execute_args(query, (value[0], value[1], value[2]))

    def create_table(self):
        """(Re)create the c_user table from the module-level DDL."""
        query = TABLE_USER
        self.db.execute(query)
\ No newline at end of file
src/pgdb/chat/chat_table.py
View file @
493cdd59
from
.c_db
import
UPostgresDB
from
.c_db
import
UPostgresDB
import
json
import
json
TABLE_CHAT
=
"""
TABLE_CHAT
=
"""
DROP TABLE IF EXISTS "chat";
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
CREATE TABLE chat (
...
@@ -17,6 +18,7 @@ COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
...
@@ -17,6 +18,7 @@ COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
COMMENT ON TABLE "chat" IS '会话信息表';
"""
"""
class
Chat
:
class
Chat
:
def __init__(self, db: UPostgresDB) -> None:
    # Shared Postgres connection wrapper used by all queries of this DAO.
    self.db = db
...
@@ -24,9 +26,9 @@ class Chat:
...
@@ -24,9 +26,9 @@ class Chat:
# Insert one row.
def insert(self, value):
    """Insert one chat row; *value* is (chat_id, user_id, info, deleted)."""
    # Fixed: plain string (no f-prefix needed) — psycopg2 fills the %s
    # placeholders from the args tuple.
    query = "INSERT INTO chat(chat_id, user_id, info, deleted) VALUES (%s, %s, %s, %s)"
    self.db.execute_args(query, (value[0], value[1], value[2], value[3]))
# Create the table.
def create_table(self):
    """(Re)create the chat table from the module-level DDL."""
    query = TABLE_CHAT
    self.db.execute(query)
\ No newline at end of file
src/pgdb/chat/turn_qa_table.py
View file @
493cdd59
from
.c_db
import
UPostgresDB
from
.c_db
import
UPostgresDB
import
json
import
json
TABLE_CHAT
=
"""
TABLE_CHAT
=
"""
DROP TABLE IF EXISTS "turn_qa";
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
CREATE TABLE turn_qa (
...
@@ -21,6 +22,7 @@ COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,
...
@@ -21,6 +22,7 @@ COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
"""
"""
class
TurnQa
:
class
TurnQa
:
def __init__(self, db: UPostgresDB) -> None:
    # Shared Postgres connection wrapper used by all queries of this DAO.
    self.db = db
...
@@ -28,9 +30,9 @@ class TurnQa:
...
@@ -28,9 +30,9 @@ class TurnQa:
# Insert one row.
def insert(self, value):
    """Insert one turn row.

    *value* is (turn_id, chat_id, question, answer, turn_number, is_last).
    """
    # Fixed: plain string (no f-prefix needed) — psycopg2 fills the %s
    # placeholders from the args tuple.
    query = ("INSERT INTO turn_qa(turn_id, chat_id, question, answer, turn_number, is_last) "
             "VALUES (%s, %s, %s, %s, %s, %s)")
    self.db.execute_args(query, (value[0], value[1], value[2], value[3], value[4], value[5]))
# Create the table.
def create_table(self):
    """(Re)create the turn_qa table from the module-level DDL."""
    query = TABLE_CHAT
    self.db.execute(query)
\ No newline at end of file
src/pgdb/knowledge/callback.py
View file @
493cdd59
import
os
,
sys
import
os
,
sys
from
os
import
path
from
os
import
path
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
import
json
import
json
from
typing
import
List
,
Any
,
Tuple
,
Dict
from
typing
import
List
,
Any
,
Tuple
,
Dict
from
langchain.schema
import
Document
from
langchain.schema
import
Document
from
src.pgdb.knowledge.pgsqldocstore
import
PgSqlDocstore
,
str2hash_base64
from
src.pgdb.knowledge.pgsqldocstore
import
PgSqlDocstore
,
str2hash_base64
class DocumentCallback(ABC):
    """Hooks around the vector store: pre-store and post-search processing."""

    @abstractmethod
    # Document processing before storing into the vector store.
    def before_store(self, docstore: PgSqlDocstore, documents):
        pass

    @abstractmethod
    # Document processing after querying the vector store — used to rebuild structure.
    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:
        # Post-search document processing.
        pass
class
DefaultDocumentCallback
(
DocumentCallback
):
class
DefaultDocumentCallback
(
DocumentCallback
):
def
before_store
(
self
,
docstore
:
PgSqlDocstore
,
documents
):
def
before_store
(
self
,
docstore
:
PgSqlDocstore
,
documents
):
output_doc
=
[]
output_doc
=
[]
for
doc
in
documents
:
for
doc
in
documents
:
if
"next_doc"
in
doc
.
metadata
:
if
"next_doc"
in
doc
.
metadata
:
...
@@ -27,22 +29,24 @@ class DefaultDocumentCallback(DocumentCallback):
...
@@ -27,22 +29,24 @@ class DefaultDocumentCallback(DocumentCallback):
doc
.
metadata
.
pop
(
"next_doc"
)
doc
.
metadata
.
pop
(
"next_doc"
)
output_doc
.
append
(
doc
)
output_doc
.
append
(
doc
)
return
output_doc
return
output_doc
def
after_search
(
self
,
docstore
:
PgSqlDocstore
,
documents
:
List
[
Tuple
[
Document
,
float
]],
number
:
int
=
1000
)
->
List
[
Tuple
[
Document
,
float
]]:
#向量库查询后文档处理
output_doc
:
List
[
Tuple
[
Document
,
float
]]
=
[]
def
after_search
(
self
,
docstore
:
PgSqlDocstore
,
documents
:
List
[
Tuple
[
Document
,
float
]],
number
:
int
=
1000
)
->
\
List
[
Tuple
[
Document
,
float
]]:
# 向量库查询后文档处理
output_doc
:
List
[
Tuple
[
Document
,
float
]]
=
[]
exist_hash
=
[]
exist_hash
=
[]
for
doc
,
score
in
documents
:
for
doc
,
score
in
documents
:
print
(
exist_hash
)
print
(
exist_hash
)
dochash
=
str2hash_base64
(
doc
.
page_content
)
dochash
=
str2hash_base64
(
doc
.
page_content
)
if
dochash
in
exist_hash
:
if
dochash
in
exist_hash
:
continue
continue
else
:
else
:
exist_hash
.
append
(
dochash
)
exist_hash
.
append
(
dochash
)
output_doc
.
append
((
doc
,
score
))
output_doc
.
append
((
doc
,
score
))
if
len
(
output_doc
)
>
number
:
if
len
(
output_doc
)
>
number
:
return
output_doc
return
output_doc
fordoc
=
doc
fordoc
=
doc
while
(
"next_hash"
in
fordoc
.
metadata
)
:
while
"next_hash"
in
fordoc
.
metadata
:
if
len
(
fordoc
.
metadata
[
"next_hash"
])
>
0
:
if
len
(
fordoc
.
metadata
[
"next_hash"
])
>
0
:
if
fordoc
.
metadata
[
"next_hash"
]
in
exist_hash
:
if
fordoc
.
metadata
[
"next_hash"
]
in
exist_hash
:
break
break
else
:
else
:
...
@@ -50,11 +54,11 @@ class DefaultDocumentCallback(DocumentCallback):
...
@@ -50,11 +54,11 @@ class DefaultDocumentCallback(DocumentCallback):
content
=
docstore
.
TXT_DOC
.
search
(
fordoc
.
metadata
[
"next_hash"
])
content
=
docstore
.
TXT_DOC
.
search
(
fordoc
.
metadata
[
"next_hash"
])
if
content
:
if
content
:
fordoc
=
Document
(
page_content
=
content
[
0
],
metadata
=
json
.
loads
(
content
[
1
]))
fordoc
=
Document
(
page_content
=
content
[
0
],
metadata
=
json
.
loads
(
content
[
1
]))
output_doc
.
append
((
fordoc
,
score
))
output_doc
.
append
((
fordoc
,
score
))
if
len
(
output_doc
)
>
number
:
if
len
(
output_doc
)
>
number
:
return
output_doc
return
output_doc
else
:
else
:
break
break
else
:
else
:
break
break
return
output_doc
return
output_doc
\ No newline at end of file
src/pgdb/knowledge/k_db.py
View file @
493cdd59
import
psycopg2
import
psycopg2
class
PostgresDB
:
class
PostgresDB
:
'''
"""
psycopg2.connect(
psycopg2.connect(
dsn #指定连接参数。可以使用参数形式或 DSN 形式指定。
dsn #指定连接参数。可以使用参数形式或 DSN 形式指定。
host #指定连接数据库的主机名。
host #指定连接数据库的主机名。
...
@@ -17,8 +18,9 @@ class PostgresDB:
...
@@ -17,8 +18,9 @@ class PostgresDB:
sslkey #指定私钥文件名。
sslkey #指定私钥文件名。
sslcert #指定公钥文件名。
sslcert #指定公钥文件名。
)
)
'''
"""
def
__init__
(
self
,
host
,
database
,
user
,
password
,
port
=
5432
):
def
__init__
(
self
,
host
,
database
,
user
,
password
,
port
=
5432
):
self
.
host
=
host
self
.
host
=
host
self
.
database
=
database
self
.
database
=
database
self
.
user
=
user
self
.
user
=
user
...
@@ -28,28 +30,29 @@ class PostgresDB:
...
@@ -28,28 +30,29 @@ class PostgresDB:
self
.
cur
=
None
self
.
cur
=
None
def connect(self):
    """Open a new psycopg2 connection and cursor using the stored credentials."""
    self.conn = psycopg2.connect(
        host=self.host,
        database=self.database,
        user=self.user,
        password=self.password,
        port=self.port
    )
    self.cur = self.conn.cursor()
def execute(self, query):
    """Execute *query* and commit; roll back (and print) on any failure.

    NOTE(review): errors are swallowed after printing — callers cannot tell
    success from failure. Confirm this best-effort behavior is intended.
    """
    try:
        self.cur.execute(query)
        self.conn.commit()
    except Exception as e:
        print(f"An error occurred: {e}")
        self.conn.rollback()
def execute_args(self, query, args):
    """Execute a parameterized *query* with *args* and commit; roll back on error.

    NOTE(review): errors are swallowed after printing, same as ``execute``.
    """
    try:
        self.cur.execute(query, args)
        self.conn.commit()
    except Exception as e:
        print(f"An error occurred: {e}")
        self.conn.rollback()
def
search
(
self
,
query
,
params
=
None
):
def
search
(
self
,
query
,
params
=
None
):
...
@@ -64,8 +67,8 @@ class PostgresDB:
...
@@ -64,8 +67,8 @@ class PostgresDB:
def format(self, query):
    """Execute a DDL/maintenance statement (e.g. DROP/CREATE TABLE).

    This was a byte-for-byte duplicate of :meth:`execute`; it now
    delegates to it so the commit/rollback error handling lives in one
    place. Semantics are unchanged: commit on success, rollback and
    print on failure, exceptions swallowed.
    """
    self.execute(query)
src/pgdb/knowledge/pgsqldocstore.py
View file @
493cdd59
import
sys
import
sys
from
os
import
path
from
os
import
path
# 这里相当于把当前目录添加到pythonpath中
# 这里相当于把当前目录添加到pythonpath中
sys
.
path
.
append
(
path
.
dirname
(
path
.
abspath
(
__file__
)))
sys
.
path
.
append
(
path
.
dirname
(
path
.
abspath
(
__file__
)))
from
typing
import
List
,
Union
,
Dict
,
Optional
from
typing
import
List
,
Union
,
Dict
,
Optional
from
langchain.docstore.base
import
AddableMixin
,
Docstore
from
langchain.docstore.base
import
AddableMixin
,
Docstore
from
k_db
import
PostgresDB
from
k_db
import
PostgresDB
from
.txt_doc_table
import
TxtDoc
from
.txt_doc_table
import
TxtDoc
from
.vec_txt_table
import
TxtVector
from
.vec_txt_table
import
TxtVector
import
json
,
hashlib
,
base64
import
json
,
hashlib
,
base64
from
langchain.schema
import
Document
from
langchain.schema
import
Document
def str2hash_base64(inp: str) -> str:
    """Return a compact, stable text key for *inp*.

    The key is the base64 encoding of the SHA-1 digest of the UTF-8
    bytes of *inp* (28 characters). SHA-1 is used only as a content
    fingerprint for deduplication here, not for security.
    """
    digest = hashlib.sha1(inp.encode()).digest()
    return base64.b64encode(digest).decode()
class
PgSqlDocstore
(
Docstore
,
AddableMixin
):
class
PgSqlDocstore
(
Docstore
,
AddableMixin
):
host
:
str
host
:
str
dbname
:
str
dbname
:
str
username
:
str
username
:
str
password
:
str
password
:
str
port
:
str
port
:
str
'''
'''
说明,重写__getstate__,__setstate__,适用于langchain的序列化存储,基于pickle进行存储。返回数组包含pgsql连接信息。
说明,重写__getstate__,__setstate__,适用于langchain的序列化存储,基于pickle进行存储。返回数组包含pgsql连接信息。
'''
'''
def __getstate__(self):
    """Pickle support: serialize only the connection parameters.

    Live resources (the PostgresDB connection and the table wrappers)
    are intentionally excluded; __setstate__ rebuilds them via
    __init__ from this dict.
    """
    return {
        "host": self.host,
        "dbname": self.dbname,
        "username": self.username,
        "password": self.password,
        "port": self.port,
    }
def __setstate__(self, info):
    # Pickle support: rebuild the docstore from the connection-info dict
    # produced by __getstate__. Re-running __init__ reconnects to the
    # database and recreates the table wrappers.
    self.__init__(info)
def
__init__
(
self
,
info
:
dict
,
reset
:
bool
=
False
):
def
__init__
(
self
,
info
:
dict
,
reset
:
bool
=
False
):
self
.
host
=
info
[
"host"
]
self
.
host
=
info
[
"host"
]
self
.
dbname
=
info
[
"dbname"
]
self
.
dbname
=
info
[
"dbname"
]
self
.
username
=
info
[
"username"
]
self
.
username
=
info
[
"username"
]
self
.
password
=
info
[
"password"
]
self
.
password
=
info
[
"password"
]
self
.
port
=
info
[
"port"
]
if
"port"
in
info
else
"5432"
;
self
.
port
=
info
[
"port"
]
if
"port"
in
info
else
"5432"
self
.
pgdb
=
PostgresDB
(
self
.
host
,
self
.
dbname
,
self
.
username
,
self
.
password
,
port
=
self
.
port
)
self
.
pgdb
=
PostgresDB
(
self
.
host
,
self
.
dbname
,
self
.
username
,
self
.
password
,
port
=
self
.
port
)
self
.
TXT_DOC
=
TxtDoc
(
self
.
pgdb
)
self
.
TXT_DOC
=
TxtDoc
(
self
.
pgdb
)
self
.
VEC_TXT
=
TxtVector
(
self
.
pgdb
)
self
.
VEC_TXT
=
TxtVector
(
self
.
pgdb
)
if
reset
:
if
reset
:
...
@@ -48,12 +50,15 @@ class PgSqlDocstore(Docstore,AddableMixin):
...
@@ -48,12 +50,15 @@ class PgSqlDocstore(Docstore,AddableMixin):
self
.
VEC_TXT
.
drop_table
()
self
.
VEC_TXT
.
drop_table
()
self
.
TXT_DOC
.
create_table
()
self
.
TXT_DOC
.
create_table
()
self
.
VEC_TXT
.
create_table
()
self
.
VEC_TXT
.
create_table
()
def __sub_init__(self):
    """Lazily (re)open the underlying PostgreSQL connection.

    No-op when a connection object already exists; used by the public
    methods before touching the database (e.g. after unpickling).
    """
    if not self.pgdb.conn:
        self.pgdb.connect()
'''
'''
从本地库中查找向量对应的文本段落,封装成Document返回
从本地库中查找向量对应的文本段落,封装成Document返回
'''
'''
def
search
(
self
,
search
:
str
)
->
Union
[
str
,
Document
]:
def
search
(
self
,
search
:
str
)
->
Union
[
str
,
Document
]:
if
not
self
.
pgdb
.
conn
:
if
not
self
.
pgdb
.
conn
:
self
.
__sub_init__
()
self
.
__sub_init__
()
...
@@ -63,40 +68,44 @@ class PgSqlDocstore(Docstore,AddableMixin):
...
@@ -63,40 +68,44 @@ class PgSqlDocstore(Docstore,AddableMixin):
return
Document
(
page_content
=
content
[
0
],
metadata
=
json
.
loads
(
content
[
1
]))
return
Document
(
page_content
=
content
[
0
],
metadata
=
json
.
loads
(
content
[
1
]))
else
:
else
:
return
Document
()
return
Document
()
'''
'''
从本地库中删除向量对应的文本,批量删除
从本地库中删除向量对应的文本,批量删除
'''
'''
def delete(self, ids: List) -> None:
    """从本地库中删除向量对应的文本,批量删除。

    Remove the given vector ids and the paragraph rows they point to.
    For each vector id the backing paragraph id is looked up first,
    then both tables are purged in bulk.
    """
    if not self.pgdb.conn:
        self.__sub_init__()
    # Resolve each vector id to its paragraph id before deleting anything.
    paragraph_ids = [self.VEC_TXT.search(vec_id)[0] for vec_id in ids]
    self.VEC_TXT.delete(ids)
    self.TXT_DOC.delete(paragraph_ids)
'''
'''
向本地库添加向量和文本信息
向本地库添加向量和文本信息
[vector_id,Document(page_content=问题, metadata=dict(paragraph=段落文本))]
[vector_id,Document(page_content=问题, metadata=dict(paragraph=段落文本))]
'''
'''
def add(self, texts: Dict[str, Document]) -> None:
    """向本地库添加向量和文本信息。

    ``texts`` maps vector_id -> Document(page_content=question,
    metadata=dict(paragraph=paragraph text, ...)). For every entry a
    (vector_id, question, paragraph_hash) row is queued for vec_txt;
    each distinct paragraph (deduplicated by content hash) is queued
    once for txt_doc together with its remaining metadata as JSON.

    Changes vs. the previous revision: the debug ``print(txt_hash)``
    was removed, and hash deduplication uses a set instead of a list
    (O(1) membership instead of O(n)). Note this still mutates the
    incoming Documents by popping the "paragraph" metadata key, as
    before.
    """
    if not self.pgdb.conn:
        self.__sub_init__()
    seen_hashes = set()  # paragraph content hashes already queued
    paragraph_txts = []  # (hash, paragraph, metadata-json) rows for txt_doc
    vec_inserts = []     # (vector_id, question, paragraph-hash) rows for vec_txt
    for vec, doc in texts.items():
        txt_hash = str2hash_base64(doc.metadata["paragraph"])
        vec_inserts.append((vec, doc.page_content, txt_hash))
        if txt_hash not in seen_hashes:
            seen_hashes.add(txt_hash)
            paragraph = doc.metadata.pop("paragraph")
            paragraph_txts.append(
                (txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False))
            )
    self.TXT_DOC.insert(paragraph_txts)
    self.VEC_TXT.insert(vec_inserts)
...
@@ -105,7 +114,7 @@ class PgSqlDocstore(Docstore,AddableMixin):
...
@@ -105,7 +114,7 @@ class PgSqlDocstore(Docstore,AddableMixin):
class
InMemorySecondaryDocstore
(
Docstore
,
AddableMixin
):
class
InMemorySecondaryDocstore
(
Docstore
,
AddableMixin
):
"""Simple in memory docstore in the form of a dict."""
"""Simple in memory docstore in the form of a dict."""
def __init__(self, _dict: Optional[Dict[str, Document]] = None, _sec_dict: Optional[Dict[str, Document]] = None):
    """Initialize with dict.

    _dict maps vector_id -> question Document; _sec_dict maps
    paragraph_id -> paragraph Document (the secondary store).
    Both default to fresh empty dicts when not supplied.
    """
    self._dict = _dict if _dict is not None else {}
    self._sec_dict = _sec_dict if _sec_dict is not None else {}
...
@@ -123,19 +132,19 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
...
@@ -123,19 +132,19 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
if
overlapping
:
if
overlapping
:
raise
ValueError
(
f
"Tried to add ids that already exist: {overlapping}"
)
raise
ValueError
(
f
"Tried to add ids that already exist: {overlapping}"
)
self
.
_dict
=
{
**
self
.
_dict
,
**
texts
}
self
.
_dict
=
{
**
self
.
_dict
,
**
texts
}
dict1
=
{}
dict1
=
{}
dict_sec
=
{}
dict_sec
=
{}
for
vec
,
doc
in
texts
.
items
():
for
vec
,
doc
in
texts
.
items
():
txt_hash
=
str2hash_base64
(
doc
.
metadata
[
"paragraph"
])
txt_hash
=
str2hash_base64
(
doc
.
metadata
[
"paragraph"
])
metadata
=
doc
.
metadata
metadata
=
doc
.
metadata
paragraph
=
metadata
.
pop
(
'paragraph'
)
paragraph
=
metadata
.
pop
(
'paragraph'
)
# metadata.update({"paragraph_id":txt_hash})
# metadata.update({"paragraph_id":txt_hash})
metadata
[
'paragraph_id'
]
=
txt_hash
metadata
[
'paragraph_id'
]
=
txt_hash
dict_sec
[
txt_hash
]
=
Document
(
page_content
=
paragraph
,
metadata
=
metadata
)
dict_sec
[
txt_hash
]
=
Document
(
page_content
=
paragraph
,
metadata
=
metadata
)
dict1
[
vec
]
=
Document
(
page_content
=
doc
.
page_content
,
metadata
=
{
'paragraph_id'
:
txt_hash
})
dict1
[
vec
]
=
Document
(
page_content
=
doc
.
page_content
,
metadata
=
{
'paragraph_id'
:
txt_hash
})
self
.
_dict
=
{
**
self
.
_dict
,
**
dict1
}
self
.
_dict
=
{
**
self
.
_dict
,
**
dict1
}
self
.
_sec_dict
=
{
**
self
.
_sec_dict
,
**
dict_sec
}
self
.
_sec_dict
=
{
**
self
.
_sec_dict
,
**
dict_sec
}
def
delete
(
self
,
ids
:
List
)
->
None
:
def
delete
(
self
,
ids
:
List
)
->
None
:
"""Deleting IDs from in memory dictionary."""
"""Deleting IDs from in memory dictionary."""
...
@@ -143,7 +152,7 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
...
@@ -143,7 +152,7 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
if
not
overlapping
:
if
not
overlapping
:
raise
ValueError
(
f
"Tried to delete ids that does not exist: {ids}"
)
raise
ValueError
(
f
"Tried to delete ids that does not exist: {ids}"
)
for
_id
in
ids
:
for
_id
in
ids
:
self
.
_sec_dict
.
pop
(
self
.
_dict
[
id
]
.
metadata
[
'paragraph_id'
])
self
.
_sec_dict
.
pop
(
self
.
_dict
[
_
id
]
.
metadata
[
'paragraph_id'
])
self
.
_dict
.
pop
(
_id
)
self
.
_dict
.
pop
(
_id
)
def
search
(
self
,
search
:
str
)
->
Union
[
str
,
Document
]:
def
search
(
self
,
search
:
str
)
->
Union
[
str
,
Document
]:
...
@@ -159,4 +168,4 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
...
@@ -159,4 +168,4 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
return
f
"ID {search} not found."
return
f
"ID {search} not found."
else
:
else
:
print
(
self
.
_dict
[
search
]
.
page_content
)
print
(
self
.
_dict
[
search
]
.
page_content
)
return
self
.
_sec_dict
[
self
.
_dict
[
search
]
.
metadata
[
'paragraph_id'
]]
return
self
.
_sec_dict
[
self
.
_dict
[
search
]
.
metadata
[
'paragraph_id'
]]
\ No newline at end of file
src/pgdb/knowledge/similarity.py
View file @
493cdd59
This diff is collapsed.
Click to expand it.
src/pgdb/knowledge/txt_doc_table.py
View file @
493cdd59
from
.k_db
import
PostgresDB
from
.k_db
import
PostgresDB
# paragraph_id BIGSERIAL primary key,
# paragraph_id BIGSERIAL primary key,
TABLE_TXT_DOC
=
"""
TABLE_TXT_DOC
=
"""
create table txt_doc (
create table txt_doc (
hash varchar(40) primary key,
hash varchar(40) primary key,
...
@@ -11,6 +12,8 @@ TABLE_TXT_DOC_HASH_INDEX = """
...
@@ -11,6 +12,8 @@ TABLE_TXT_DOC_HASH_INDEX = """
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
"""
"""
# CREATE UNIQUE INDEX idx_name ON your_table (column_name);
# CREATE UNIQUE INDEX idx_name ON your_table (column_name);
class
TxtDoc
:
class
TxtDoc
:
def
__init__
(
self
,
db
:
PostgresDB
)
->
None
:
def
__init__
(
self
,
db
:
PostgresDB
)
->
None
:
...
@@ -21,19 +24,20 @@ class TxtDoc:
...
@@ -21,19 +24,20 @@ class TxtDoc:
args
=
[]
args
=
[]
for
value
in
texts
:
for
value
in
texts
:
value
=
list
(
value
)
value
=
list
(
value
)
query
+=
"(
%
s,
%
s,
%
s),"
query
+=
"(
%
s,
%
s,
%
s),"
args
.
extend
(
value
)
args
.
extend
(
value
)
query
=
query
[:
len
(
query
)
-
1
]
query
=
query
[:
len
(
query
)
-
1
]
query
+=
f
"ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
query
+=
f
"ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
self
.
db
.
execute_args
(
query
,
args
)
self
.
db
.
execute_args
(
query
,
args
)
def delete(self, ids):
    """Delete paragraph rows from txt_doc by their hash keys.

    Fixed: the previous version interpolated the value into the SQL
    string (``f"... hash = %s" % item``) without quoting, which both
    produced invalid SQL for varchar keys and was an SQL-injection
    vector. Values are now bound via a parameterized statement.
    """
    for hash_key in ids:
        self.db.execute_args("delete FROM txt_doc WHERE hash = %s", [hash_key])
def
search
(
self
,
id
):
def
search
(
self
,
item
):
query
=
"SELECT text,matadate FROM txt_doc WHERE hash =
%
s"
query
=
"SELECT text,matadate FROM txt_doc WHERE hash =
%
s"
self
.
db
.
execute_args
(
query
,
[
id
])
self
.
db
.
execute_args
(
query
,
[
item
])
answer
=
self
.
db
.
fetchall
()
answer
=
self
.
db
.
fetchall
()
if
len
(
answer
)
>
0
:
if
len
(
answer
)
>
0
:
return
answer
[
0
]
return
answer
[
0
]
...
@@ -60,4 +64,3 @@ class TxtDoc:
...
@@ -60,4 +64,3 @@ class TxtDoc:
query
=
"DROP TABLE txt_doc"
query
=
"DROP TABLE txt_doc"
self
.
db
.
format
(
query
)
self
.
db
.
format
(
query
)
print
(
"drop table txt_doc ok"
)
print
(
"drop table txt_doc ok"
)
src/pgdb/knowledge/vec_txt_table.py
View file @
493cdd59
from
.k_db
import
PostgresDB
from
.k_db
import
PostgresDB
TABLE_VEC_TXT
=
"""
TABLE_VEC_TXT
=
"""
CREATE TABLE vec_txt (
CREATE TABLE vec_txt (
vector_id varchar(36) PRIMARY KEY,
vector_id varchar(36) PRIMARY KEY,
...
@@ -6,7 +7,9 @@ CREATE TABLE vec_txt (
...
@@ -6,7 +7,9 @@ CREATE TABLE vec_txt (
paragraph_id varchar(40) not null
paragraph_id varchar(40) not null
)
)
"""
"""
#025a9bee-2eb2-47f5-9722-525e05a0442b
# 025a9bee-2eb2-47f5-9722-525e05a0442b
class
TxtVector
:
class
TxtVector
:
def __init__(self, db: PostgresDB) -> None:
    # Thin table-access wrapper: all statements run through the shared
    # PostgresDB connection passed in by the caller.
    self.db = db
...
@@ -16,19 +19,21 @@ class TxtVector:
...
@@ -16,19 +19,21 @@ class TxtVector:
args
=
[]
args
=
[]
for
value
in
vectors
:
for
value
in
vectors
:
value
=
list
(
value
)
value
=
list
(
value
)
query
+=
"(
%
s,
%
s,
%
s),"
query
+=
"(
%
s,
%
s,
%
s),"
args
.
extend
(
value
)
args
.
extend
(
value
)
query
=
query
[:
len
(
query
)
-
1
]
query
=
query
[:
len
(
query
)
-
1
]
query
+=
f
"ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text,paragraph_id = EXCLUDED.paragraph_id;"
query
+=
f
"ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text,paragraph_id = EXCLUDED.paragraph_id;"
# query += ";"
# query += ";"
self
.
db
.
execute_args
(
query
,
args
)
self
.
db
.
execute_args
(
query
,
args
)
def delete(self, ids):
    """Delete rows from vec_txt by vector_id, one statement per id.

    Fixed: the previous version built the SQL with string
    interpolation (``f"... = '%s'" % (item,)``), an SQL-injection
    vector. Values are now bound via a parameterized statement.
    """
    for vector_id in ids:
        self.db.execute_args("delete FROM vec_txt WHERE vector_id = %s", [vector_id])
def search(self, search: str):
    """Return the (paragraph_id, text) row for a vector_id.

    Fixed: removed the stray debug ``print(answer)`` and the spurious
    ``f`` prefix on a plain string. Raises IndexError when the
    vector_id has no row (unchanged behavior — callers index the
    result immediately).
    """
    query = "SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
    self.db.execute_args(query, [search])
    answer = self.db.fetchall()
    return answer[0]
...
@@ -48,4 +53,4 @@ class TxtVector:
...
@@ -48,4 +53,4 @@ class TxtVector:
if
exists
:
if
exists
:
query
=
"DROP TABLE vec_txt"
query
=
"DROP TABLE vec_txt"
self
.
db
.
format
(
query
)
self
.
db
.
format
(
query
)
print
(
"drop table vec_txt ok"
)
print
(
"drop table vec_txt ok"
)
\ No newline at end of file
test/chat_table_test.py
View file @
493cdd59
import
sys
import
sys
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
from
src.pgdb.chat.c_db
import
UPostgresDB
from
src.pgdb.chat.c_db
import
UPostgresDB
from
src.pgdb.chat.chat_table
import
Chat
from
src.pgdb.chat.chat_table
import
Chat
from
src.pgdb.chat.c_user_table
import
CUser
from
src.pgdb.chat.c_user_table
import
CUser
from
src.pgdb.chat.turn_qa_table
import
TurnQa
from
src.pgdb.chat.turn_qa_table
import
TurnQa
"""测试会话相关数据可的连接"""
"""测试会话相关数据可的连接"""
c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
chat = Chat(db=c_db)
c_user = CUser(db=c_db)
turn_qa = TurnQa(db=c_db)
chat.create_table()
c_user.create_table()
turn_qa.create_table()


def test():
    """Insert one sample row into each chat-related table."""
    # chat_id, user_id, info, deleted
    chat.insert(["3333", "1111", "没有info", 0])
    # user_id, account, password
    c_user.insert(["111", "zhangsan", "111111"])
    # turn_id, chat_id, question, answer, turn_number, is_last
    turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])


# Fixed: the guard was ``if __name__ == "main":`` which never matches —
# the module attribute is "__main__", so test() was never executed.
if __name__ == "__main__":
    test()
test/k_store_test.py
View file @
493cdd59
import
sys
import
sys
sys
.
path
.
append
(
"../"
)
import
time
sys
.
path
.
append
(
"../"
)
from
src.loader.load
import
loads_path
,
loads
from
src.loader.load
import
loads_path
from
src.pgdb.knowledge.similarity
import
VectorStore_FAISS
from
src.pgdb.knowledge.similarity
import
VectorStore_FAISS
from
src.config.consts
import
(
from
src.config.consts
import
(
VEC_DB_DBNAME
,
VEC_DB_DBNAME
,
...
@@ -18,24 +18,27 @@ from src.config.consts import (
...
@@ -18,24 +18,27 @@ from src.config.consts import (
from
src.loader.callback
import
BaseCallback
from
src.loader.callback
import
BaseCallback
# 当返回值中带有“思考题”字样的时候,默认将其忽略。
# 当返回值中带有“思考题”字样的时候,默认将其忽略。
class localCallback(BaseCallback):
    """Loader callback that drops boilerplate sections during ingestion.

    A section is filtered out (filter() returns True) when it is empty,
    when its average line length is under 20 characters (layout noise
    such as tables of contents), or when its title contains "思考题"
    (review-question sections, ignored by default).
    """

    def filter(self, title: str, content: str) -> bool:
        combined = title + content
        if not combined:
            return True
        line_count = len(title.splitlines()) + len(content.splitlines())
        return (len(combined) / line_count < 20) or "思考题" in title
"""测试资料入库(pgsql和faiss)"""
"""测试资料入库(pgsql和faiss)"""
def
test_faiss_from_dir
():
def
test_faiss_from_dir
():
vecstore_faiss
=
VectorStore_FAISS
(
vecstore_faiss
=
VectorStore_FAISS
(
embedding_model_name
=
EMBEEDING_MODEL_PATH
,
embedding_model_name
=
EMBEEDING_MODEL_PATH
,
store_path
=
FAISS_STORE_PATH
,
store_path
=
FAISS_STORE_PATH
,
index_name
=
INDEX_NAME
,
index_name
=
INDEX_NAME
,
info
=
{
"port"
:
VEC_DB_PORT
,
"host"
:
VEC_DB_HOST
,
"dbname"
:
VEC_DB_DBNAME
,
"username"
:
VEC_DB_USER
,
"password"
:
VEC_DB_PASSWORD
},
info
=
{
"port"
:
VEC_DB_PORT
,
"host"
:
VEC_DB_HOST
,
"dbname"
:
VEC_DB_DBNAME
,
"username"
:
VEC_DB_USER
,
show_number
=
SIMILARITY_SHOW_NUMBER
,
"password"
:
VEC_DB_PASSWORD
},
reset
=
True
)
show_number
=
SIMILARITY_SHOW_NUMBER
,
docs
=
loads_path
(
KNOWLEDGE_PATH
,
mode
=
"paged"
,
sentence_size
=
512
,
callbacks
=
[
localCallback
()])
reset
=
True
)
docs
=
loads_path
(
KNOWLEDGE_PATH
,
mode
=
"paged"
,
sentence_size
=
512
,
callbacks
=
[
localCallback
()])
print
(
len
(
docs
))
print
(
len
(
docs
))
last_doc
=
None
last_doc
=
None
docs1
=
[]
docs1
=
[]
...
@@ -45,7 +48,8 @@ def test_faiss_from_dir():
...
@@ -45,7 +48,8 @@ def test_faiss_from_dir():
continue
continue
if
"font-size"
not
in
doc
.
metadata
or
"page_number"
not
in
doc
.
metadata
:
if
"font-size"
not
in
doc
.
metadata
or
"page_number"
not
in
doc
.
metadata
:
continue
continue
if
doc
.
metadata
[
"font-size"
]
==
last_doc
.
metadata
[
"font-size"
]
and
doc
.
metadata
[
"page_number"
]
==
last_doc
.
metadata
[
"page_number"
]
and
len
(
doc
.
page_content
)
+
len
(
last_doc
.
page_content
)
<
512
/
4
*
3
:
if
doc
.
metadata
[
"font-size"
]
==
last_doc
.
metadata
[
"font-size"
]
and
doc
.
metadata
[
"page_number"
]
==
\
last_doc
.
metadata
[
"page_number"
]
and
len
(
doc
.
page_content
)
+
len
(
last_doc
.
page_content
)
<
512
/
4
*
3
:
last_doc
.
page_content
+=
doc
.
page_content
last_doc
.
page_content
+=
doc
.
page_content
else
:
else
:
docs1
.
append
(
last_doc
)
docs1
.
append
(
last_doc
)
...
@@ -56,22 +60,26 @@ def test_faiss_from_dir():
...
@@ -56,22 +60,26 @@ def test_faiss_from_dir():
print
(
len
(
docs
))
print
(
len
(
docs
))
print
(
vecstore_faiss
.
_faiss
.
index
.
ntotal
)
print
(
vecstore_faiss
.
_faiss
.
index
.
ntotal
)
for
i
in
range
(
0
,
len
(
docs
),
300
):
for
i
in
range
(
0
,
len
(
docs
),
300
):
vecstore_faiss
.
_add_documents
(
docs
[
i
:
i
+
300
if
i
+
300
<
len
(
docs
)
else
len
(
docs
)],
need_split
=
True
)
vecstore_faiss
.
_add_documents
(
docs
[
i
:
i
+
300
if
i
+
300
<
len
(
docs
)
else
len
(
docs
)],
need_split
=
True
)
print
(
vecstore_faiss
.
_faiss
.
index
.
ntotal
)
print
(
vecstore_faiss
.
_faiss
.
index
.
ntotal
)
vecstore_faiss
.
_save_local
()
vecstore_faiss
.
_save_local
()
"""测试faiss向量数据库查询结果"""
"""测试faiss向量数据库查询结果"""
def test_faiss_load():
    """Query the persisted FAISS store and print the joined result.

    (测试faiss向量数据库查询结果 — loads the existing index, so
    reset=False, and runs one similarity search.)
    """
    store = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={
            "port": VEC_DB_PORT,
            "host": VEC_DB_HOST,
            "dbname": VEC_DB_DBNAME,
            "username": VEC_DB_USER,
            "password": VEC_DB_PASSWORD,
        },
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False,
    )
    hits = store.get_text_similarity("征信业务有什么情况")
    print(store._join_document(hits))


if __name__ == "__main__":
    # test_faiss_from_dir()
    test_faiss_load()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment