文靖昊 / LAE / Commits

Commit 493cdd59
authored Apr 25, 2024 by 陈正乐

代码格式化 (code formatting)

parent 9d8ee0af

Showing 28 changed files with 496 additions and 397 deletions (+496 -397)
src/config/consts.py                      +8    -8
src/llm/__init__.py                       +0    -0
src/llm/baichuan.py                       +10   -18
src/llm/chatglm.py                        +39   -36
src/llm/chatglm_openapi.py                +9    -8
src/llm/ernie.py                          +15   -21
src/llm/ernie_sdk.py                      +11   -4
src/llm/ernie_with_sdk.py                 +14   -15
src/llm/loader.py                         +10   -9
src/llm/spark.py                          +18   -25
src/llm/wrapper.py                        +5    -9
src/loader/callback.py                    +2    -1
src/loader/chinese_text_splitter.py       +2    -2
src/loader/config.py                      +0    -0
src/loader/load.py                        +123  -88
src/loader/zh_title_enhance.py            +1    -1
src/pgdb/chat/c_db.py                     +6    -5
src/pgdb/chat/c_user_table.py             +3    -1
src/pgdb/chat/chat_table.py               +3    -1
src/pgdb/chat/turn_qa_table.py            +3    -1
src/pgdb/knowledge/callback.py            +18   -14
src/pgdb/knowledge/k_db.py                +7    -4
src/pgdb/knowledge/pgsqldocstore.py       +37   -28
src/pgdb/knowledge/similarity.py          +86   -54
src/pgdb/knowledge/txt_doc_table.py       +13   -10
src/pgdb/knowledge/vec_txt_table.py       +13   -8
test/chat_table_test.py                   +21   -15
test/k_store_test.py                      +19   -11
src/config/consts.py

@@ -2,19 +2,19 @@
  # 资料存储数据库配置
  # =============================
  VEC_DB_HOST = 'localhost'
  VEC_DB_DBNAME = 'lae'
  VEC_DB_USER = 'postgres'
  VEC_DB_PASSWORD = 'chenzl'
  VEC_DB_PORT = '5432'
  # =============================
  # 聊天相关数据库配置
  # =============================
  CHAT_DB_HOST = 'localhost'
  CHAT_DB_DBNAME = 'laechat'
  CHAT_DB_USER = 'postgres'
  CHAT_DB_PASSWORD = 'chenzl'
  CHAT_DB_PORT = '5432'
  # =============================
  # 向量化模型路径配置
  ...
(the changed lines in this hunk appear to differ only in whitespace)
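
For orientation, a minimal sketch of how these constants might be assembled into a PostgreSQL connection URL; the build_dsn helper below is illustrative and not part of the repository:

    # Hypothetical helper: compose a libpq/SQLAlchemy-style URL from the
    # settings defined in src/config/consts.py.
    VEC_DB_HOST = 'localhost'
    VEC_DB_DBNAME = 'lae'
    VEC_DB_USER = 'postgres'
    VEC_DB_PASSWORD = 'chenzl'
    VEC_DB_PORT = '5432'

    def build_dsn(user, password, host, port, dbname):
        return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"

    print(build_dsn(VEC_DB_USER, VEC_DB_PASSWORD, VEC_DB_HOST, VEC_DB_PORT, VEC_DB_DBNAME))
    # postgresql://postgres:chenzl@localhost:5432/lae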
src/llm/__init__.py
(no content changes)
src/llm/baichuan.py

import os
from typing import Dict, Optional, List
from langchain.llms.base import BaseLLM, LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
from pydantic import root_validator


class BaichuanLLM(LLM):
    model_name: str = "baichuan-inc/Baichuan-13B-Chat"
    quantization_bit: Optional[int] = None
...
@@ -19,17 +17,16 @@ class BaichuanLLM(LLM):
    tokenizer: AutoTokenizer = None
    model: AutoModel = None

    def _llm_type(self) -> str:
        return "chatglm_local"

    @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   def validate_environment(self, values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")
        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
...
@@ -43,7 +40,7 @@ class BaichuanLLM(LLM):
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"]).cuda()
        else:
            model = model.half().cuda()
        model = model.eval()
...
@@ -51,14 +48,9 @@ class BaichuanLLM(LLM):
        values["model"] = model
        return values

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs) -> str:
        message = [{"role": "user", "content": prompt}]
        resp = self.model.chat(self.tokenizer, message)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp
src/llm/chatglm.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import aiohttp
import asyncio

# 启动llm的缓存
# langchain.llm_cache = InMemoryCache()
...
@@ -26,17 +27,17 @@ class ChatGLMLocLLM(LLM):
    tokenizer: AutoTokenizer = None
    model: AutoModel = None

    def _llm_type(self) -> str:
        return "chatglm_local"

    # @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   @staticmethod
+   def validate_environment(values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")
        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        # model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        if values["pre_seq_len"]:
...
@@ -72,18 +73,14 @@ class ChatGLMLocLLM(LLM):
        values["model"] = model
        return values

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs) -> str:
        resp, his = self.model.chat(self.tokenizer, prompt)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp


class ChatGLMSerLLM(LLM):
    # 模型服务url
    url: str = "http://127.0.0.1:8000"
    chat_history: dict = []
...
@@ -95,7 +92,7 @@ class ChatGLMSerLLM(LLM):
        return "chatglm3-6b"

    def get_num_tokens(self, text: str) -> int:
        resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
        if resp.status_code == 200:
            resp_json = resp.json()
            predictions = resp_json['response']
...
@@ -104,27 +101,28 @@ class ChatGLMSerLLM(LLM):
        else:
            return len(text)

-   def convert_data(self, data):
+   @staticmethod
+   def convert_data(data):
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
        """构造请求体
        """
        # self.chat_history.append({"role": "user", "content": prompt})
        query = {
            "prompt": prompt,
            "history": self.chat_history,
            "max_length": 4096,
            "top_p": 0.7,
            "temperature": temperature
        }
        return query

    @classmethod
-   def _post(self, url: str,
+   def _post(cls, url: str,
              query: Dict) -> Any:
        """POST请求
        """
...
@@ -135,46 +133,50 @@ class ChatGLMSerLLM(LLM):
                             headers=_headers,
                             timeout=300)
        return resp

-   async def _post_stream(self, url: str,
+   @staticmethod
+   async def _post_stream(url: str,
                           query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                           stream=False) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}
        async with aiohttp.ClientSession() as sess:
            async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
                if response.status == 200:
                    if stream and not run_manager:
                        print('not callable')
                    if run_manager:
-                       for callable in run_manager.get_sync().handlers:
-                           await callable.on_llm_start(None, None)
+                       for _callable in run_manager.get_sync().handlers:
+                           await _callable.on_llm_start(None, None)
                    async for chunk in response.content.iter_any():
                        # 处理每个块的数据
                        if chunk and run_manager:
-                           for callable in run_manager.get_sync().handlers:
+                           for _callable in run_manager.get_sync().handlers:
                                # print(chunk.decode("utf-8"),end="")
-                               await callable.on_llm_new_token(chunk.decode("utf-8"))
+                               await _callable.on_llm_new_token(chunk.decode("utf-8"))
                    if run_manager:
-                       for callable in run_manager.get_sync().handlers:
-                           await callable.on_llm_end(None)
+                       for _callable in run_manager.get_sync().handlers:
+                           await _callable.on_llm_end(None)
                else:
                    raise ValueError(f'glm 请求异常,http code:{response.status}')

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
        # display("==============================")
        # display(query)
        # post
        if stream or self.out_stream:
            async def _post_stream():
                await self._post_stream(url=self.url + "/stream",
                                        query=query, run_manager=run_manager,
                                        stream=stream or self.out_stream)

            asyncio.run(_post_stream())
            return ''
        else:
...
@@ -197,9 +199,10 @@ class ChatGLMSerLLM(LLM):
                     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                     **kwargs: Any,
                     ) -> str:
-       query = self._construct_query(prompt=prompt, temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
+       query = self._construct_query(prompt=prompt,
+                                     temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
        await self._post_stream(url=self.url + "/stream",
                                query=query, run_manager=run_manager, stream=self.out_stream)
        return ''

    @property
...
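
For reference, a minimal, self-contained sketch of the JSON body that _construct_query assembles before it is posted to the chatglm service (the prompt and history values below are made up; the field names come from the diff above):

    import json

    def construct_query(prompt, history, temperature=0.95):
        # Mirrors ChatGLMSerLLM._construct_query: prompt, history, max_length,
        # top_p and temperature are sent as the request body.
        return {
            "prompt": prompt,
            "history": history,
            "max_length": 4096,
            "top_p": 0.7,
            "temperature": temperature,
        }

    payload = construct_query("你好", history=[])
    print(json.dumps(payload, ensure_ascii=False))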
src/llm/chatglm_openapi.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain.llms.base import BaseLLM, LLM
from langchain_openai import OpenAI
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun


class ChatGLMSerLLM(OpenAI):
    def get_token_ids(self, text: str) -> List[int]:
...
@@ -20,9 +21,9 @@ class ChatGLMSerLLM(OpenAI):
        ## 发起http请求,获取token_ids
        url = f"{self.openai_api_base}/num_tokens"
        query = {"prompt": text, "model": self.model_name}
        _headers = {"Content_Type": "application/json", "Authorization": "chatglm " + self.openai_api_key}
        resp = self._post(url=url, query=query, headers=_headers)
        if resp.status_code == 200:
            resp_json = resp.json()
            print(resp_json)
...
@@ -32,8 +33,8 @@ class ChatGLMSerLLM(OpenAI):
        return [len(text)]

    @classmethod
-   def _post(self, url: str,
+   def _post(cls, url: str,
              query: Dict, headers: Dict) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}
...
src/llm/ernie.py

@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
...
@@ -15,6 +15,7 @@ from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_m
logger = logging.getLogger(__name__)


class ModelType(Enum):
    ERNIE = "ernie"
    ERNIE_LITE = "ernie-lite"
...
@@ -25,7 +26,7 @@ class ModelType(Enum):
    LLAMA2_13B = "llama2-13b"
    LLAMA2_70B = "llama2-70b"
    QFCN_LLAMA2_7B = "qfcn-llama2-7b"
    BLOOMZ_7B = "bloomz-7b"


MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"
...
@@ -43,6 +44,7 @@ MODEL_SERVICE_Suffix = {
    ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}


class ErnieLLM(LLM):
    """
    ErnieLLM is a LLM that uses Ernie to generate text.
...
@@ -52,7 +54,7 @@ class ErnieLLM(LLM):
    access_token: Optional[str] = ""

    @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
        model_name = ModelType(get_from_dict_or_env(values, "model_name", "model_name", str(ModelType.ERNIE)))
...
@@ -65,14 +67,10 @@ class ErnieLLM(LLM):
        values["access_token"] = access_token
        return values

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs) -> str:
        request = CompletionRequest(messages=[Message("user", prompt)])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
        try:
            # 你的代码
...
@@ -81,10 +79,9 @@ class ErnieLLM(LLM):
            return response
        except Exception as e:
            # 处理异常
            print("exception:", e)
            return e.__str__()

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
...
@@ -95,9 +92,10 @@ class ErnieLLM(LLM):
    #     "name": "ernie",
    # }


def _get_model_service_url(model_name) -> str:
    # print("_get_model_service_url model_name: ",model_name)
    return MODEL_SERVICE_BASE_URL + MODEL_SERVICE_Suffix[model_name]


class ErnieChat(LLM):
...
@@ -106,16 +104,12 @@ class ErnieChat(LLM):
    prefix_messages: List = Field(default_factory=list)
    id: str = ""

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs) -> str:
        msg = user_message(prompt)
        request = CompletionRequest(messages=self.prefix_messages + [msg])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
        try:
            # 你的代码
            response = bot.get_response().result
...
src/llm/ernie_sdk.py

from dataclasses import asdict, dataclass
from typing import List
...
@@ -7,27 +5,32 @@ from pydantic import BaseModel, Field
from enum import Enum


class MessageRole(str, Enum):
    USER = "user"
    BOT = "assistant"


@dataclass
class Message:
    role: str
    content: str


@dataclass
class CompletionRequest:
    messages: List[Message]
    stream: bool = False
    user: str = ""


@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class CompletionResponse:
    id: str
...
@@ -42,12 +45,14 @@ class CompletionResponse:
    is_safe: bool = False
    is_truncated: bool = False


class ErrorResponse(BaseModel):
    error_code: int = Field(...)
    error_msg: str = Field(...)
    id: str = Field(...)


-class ErnieBot():
+class ErnieBot:
    url: str
    access_token: str
    request: CompletionRequest
...
@@ -65,7 +70,7 @@ class ErnieBot():
        headers = {'Content-Type': 'application/json'}
        params = {'access_token': self.access_token}
        request_dict = asdict(self.request)
        response = requests.post(self.url, params=params, data=json.dumps(request_dict), headers=headers)
        # print(response.json())
        try:
            return CompletionResponse(**response.json())
...
@@ -73,8 +78,10 @@ class ErnieBot():
            print(e)
            raise Exception(response.json())


def user_message(prompt: str) -> Message:
    return Message(MessageRole.USER, prompt)


def bot_message(prompt: str) -> Message:
    return Message(MessageRole.BOT, prompt)
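
A small, self-contained sketch of how these request dataclasses serialize with asdict before being posted; the class definitions are repeated here so the snippet runs on its own, and the sample prompt is made up:

    import json
    from dataclasses import asdict, dataclass
    from typing import List

    @dataclass
    class Message:
        role: str
        content: str

    @dataclass
    class CompletionRequest:
        messages: List[Message]
        stream: bool = False
        user: str = ""

    # asdict() recursively converts nested dataclasses into plain dicts,
    # which is what ErnieBot sends as the request body.
    request = CompletionRequest(messages=[Message("user", "你好")])
    print(json.dumps(asdict(request), ensure_ascii=False))
    # {"messages": [{"role": "user", "content": "你好"}], "stream": false, "user": ""}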
src/llm/ernie_with_sdk.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import qianfan
from qianfan import ChatCompletion

# 启动llm的缓存
# langchain.llm_cache = InMemoryCache()


class ChatERNIESerLLM(LLM):
    # 模型服务url
    chat_completion: ChatCompletion = None
    # url: str = "http://127.0.0.1:8000"
    chat_history: dict = []
    out_stream: bool = False
    cache: bool = False
    model_name: str = "ERNIE-Bot"

    # def __init__(self):
    #     self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")
...
@@ -32,20 +33,19 @@ class ChatERNIESerLLM(LLM):
    def get_num_tokens(self, text: str) -> int:
        return len(text)

-   def convert_data(self, data):
+   @staticmethod
+   def convert_data(data):
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        resp = self.chat_completion.do(model=self.model_name, messages=[{
            "role": "user",
            "content": prompt
        }])
...
@@ -59,11 +59,12 @@ class ChatERNIESerLLM(LLM):
                           stream=False) -> Any:
        """POST请求
        """
        async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
            assert r.code == 200
            if run_manager:
-               for callable in run_manager.get_sync().handlers:
-                   await callable.on_llm_new_token(r.body["result"])
+               for _callable in run_manager.get_sync().handlers:
+                   await _callable.on_llm_new_token(r.body["result"])

    async def _acall(
            self,
            prompt: str,
...
@@ -74,6 +75,5 @@ class ChatERNIESerLLM(LLM):
        await self._post_stream(query={
            "role": "user",
            "content": prompt
        }, stream=True, run_manager=run_manager)
        return ''
\ No newline at end of file
src/llm/loader.py

@@ -4,6 +4,7 @@ import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
from peft import PeftModel


class ModelLoader:
    def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
        self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
...
@@ -27,14 +28,15 @@ class ModelLoader:
    def collator(self):
        return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)

    def load_lora(self, ckpt_path, name="default"):
        # 训练时节约GPU占用
-       peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
-       self.model = peft_loaded.merge_and_unload()
+       _peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
+       self.model = _peft_loaded.merge_and_unload()
        print(f"Load LoRA model successfully!")

    def load_loras(self, ckpt_paths, name="default"):
+       global peft_loaded
        if len(ckpt_paths) == 0:
            return
        first = True
        for name, path in ckpt_paths.items():
...
@@ -43,11 +45,11 @@ class ModelLoader:
                peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=name)
                first = False
            else:
                peft_loaded.load_adapter(path, adapter_name=name)
        peft_loaded.set_adapter(name)
        self.model = peft_loaded

    def load_prefix(self, ckpt_path):
        prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
...
@@ -56,4 +58,3 @@ class ModelLoader:
        self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        self.model.transformer.prefix_encoder.float()
        print(f"Load prefix model successfully!")
src/llm/spark.py

@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
...
@@ -16,18 +16,17 @@ from .xinghuo.ws import SparkAPI
logger = logging.getLogger(__name__)

text = []


# length = 0
def getText(role, content):
-   jsoncon = {}
-   jsoncon["role"] = role
-   jsoncon["content"] = content
+   jsoncon = {"role": role, "content": content}
    text.append(jsoncon)
    return text


def getlength(text):
    length = 0
    for content in text:
...
@@ -36,10 +35,11 @@ def getlength(text):
        length += leng
    return length


-def checklen(text):
-   while (getlength(text) > 8000):
-       del text[0]
-   return text
+def checklen(_text):
+   while getlength(_text) > 8000:
+       del _text[0]
+   return _text


class SparkLLM(LLM):
...
@@ -68,7 +68,7 @@ class SparkLLM(LLM):
    )

    @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
...
@@ -89,18 +89,14 @@ class SparkLLM(LLM):
        values["api_key"] = api_key
        values["api_secret"] = api_secret
        api = SparkAPI(appid, api_key, api_secret, version)
        values["api"] = api
        return values

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs) -> str:
        question = self.getText("user", prompt)
        try:
            # 你的代码
            # SparkApi.main(self.appid,self.api_key,self.api_secret,self.Spark_url,self.domain,question)
...
@@ -109,10 +105,10 @@ class SparkLLM(LLM):
            return response
        except Exception as e:
            # 处理异常
            print("exception:", e)
            raise e

    def getText(self, role, content):
        text = []
        jsoncon = {}
        jsoncon["role"] = role
...
@@ -124,5 +120,3 @@ class SparkLLM(LLM):
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "xinghuo"
\ No newline at end of file
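
A runnable sketch of the trimming behaviour after the checklen/_text rename above; the per-message length rule is an assumption here, since the diff elides how `leng` is computed inside getlength:

    def getlength(text):
        # Assumed: a message's length is the length of its role plus its content.
        length = 0
        for content in text:
            length += len(content["role"]) + len(content["content"])
        return length

    def checklen(_text, limit=40):  # the real code uses a fixed limit of 8000
        # Drop the oldest messages until the history fits under the limit.
        while getlength(_text) > limit:
            del _text[0]
        return _text

    history = [{"role": "user", "content": "x" * 30},
               {"role": "assistant", "content": "y" * 30}]
    print(len(checklen(history)))  # 1 — the oldest message was dropped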
src/llm/wrapper.py

from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from typing import Dict, List, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer


class WrapperLLM(LLM):
    tokenizer: PreTrainedTokenizer = None
    model: PreTrainedModel = None

    @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
        if values.get("model") is None:
...
@@ -19,13 +19,9 @@ class WrapperLLM(LLM):
            raise ValueError("No tokenizer provided.")
        return values

    def _call(self,
              prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs) -> str:
        resp, his = self.model.chat(self.tokenizer, prompt)
        return resp

    @property
...
src/loader/callback.py

from abc import ABC, abstractmethod


class BaseCallback(ABC):
    @abstractmethod
    def filter(self, title: str, content: str) -> bool:  # return True舍弃当前段落
        pass
src/loader/chinese_text_splitter.py

@@ -56,6 +56,6 @@ class ChineseTextSplitter(CharacterTextSplitter):
                ele_id = ele1_ls.index(ele_ele1)
                ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
-       id = ls.index(ele)
-       ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
+       _id = ls.index(ele)
+       ls = ls[:_id] + [i for i in ele1_ls if i] + ls[_id + 1:]
        return ls
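
A minimal, self-contained illustration of the splice pattern used above, which replaces one element with its own sub-splits while dropping empty strings; the sample sentences are made up:

    # Rebuild a list around index `_id`, mirroring
    # ls[:_id] + [i for i in ele1_ls if i] + ls[_id + 1:]
    ls = ["第一句。", "第二句!第三句。", "第四句。"]
    ele = "第二句!第三句。"
    ele1_ls = ["第二句!", "第三句。", ""]

    _id = ls.index(ele)
    ls = ls[:_id] + [i for i in ele1_ls if i] + ls[_id + 1:]
    print(ls)  # ['第一句。', '第二句!', '第三句。', '第四句。']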
src/loader/config.py
(no content changes)
src/loader/load.py
View file @
493cdd59
import
os
,
copy
import
os
,
copy
from
langchain.document_loaders
import
UnstructuredFileLoader
,
TextLoader
,
CSVLoader
,
UnstructuredPDFLoader
,
UnstructuredWordDocumentLoader
,
PDFMinerPDFasHTMLLoader
from
langchain.document_loaders
import
UnstructuredFileLoader
,
TextLoader
,
CSVLoader
,
UnstructuredPDFLoader
,
\
UnstructuredWordDocumentLoader
,
PDFMinerPDFasHTMLLoader
from
.config
import
SENTENCE_SIZE
,
ZH_TITLE_ENHANCE
from
.config
import
SENTENCE_SIZE
,
ZH_TITLE_ENHANCE
from
.chinese_text_splitter
import
ChineseTextSplitter
from
.chinese_text_splitter
import
ChineseTextSplitter
from
.zh_title_enhance
import
zh_title_enhance
from
.zh_title_enhance
import
zh_title_enhance
from
langchain.schema
import
Document
from
langchain.schema
import
Document
from
typing
import
List
,
Dict
,
Optional
from
typing
import
List
,
Dict
,
Optional
from
src.loader.callback
import
BaseCallback
from
src.loader.callback
import
BaseCallback
import
re
import
re
from
bs4
import
BeautifulSoup
from
bs4
import
BeautifulSoup
def
load
(
filepath
,
mode
:
str
=
None
,
sentence_size
:
int
=
0
,
metadata
=
None
,
callbacks
=
None
,
**
kwargs
):
def
load
(
filepath
,
mode
:
str
=
None
,
sentence_size
:
int
=
0
,
metadata
=
None
,
callbacks
=
None
,
**
kwargs
):
r"""
r"""
加载文档,参数说明
加载文档,参数说明
mode:文档切割方式,"single", "elements", "paged"
mode:文档切割方式,"single", "elements", "paged"
...
@@ -19,37 +21,44 @@ def load(filepath,mode:str = None,sentence_size:int = 0,metadata = None,callback
...
@@ -19,37 +21,44 @@ def load(filepath,mode:str = None,sentence_size:int = 0,metadata = None,callback
kwargs
kwargs
"""
"""
if
filepath
.
lower
()
.
endswith
(
".md"
):
if
filepath
.
lower
()
.
endswith
(
".md"
):
loader
=
UnstructuredFileLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
loader
=
UnstructuredFileLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
elif
filepath
.
lower
()
.
endswith
(
".txt"
):
elif
filepath
.
lower
()
.
endswith
(
".txt"
):
loader
=
TextLoader
(
filepath
,
autodetect_encoding
=
True
,
**
kwargs
)
loader
=
TextLoader
(
filepath
,
autodetect_encoding
=
True
,
**
kwargs
)
elif
filepath
.
lower
()
.
endswith
(
".csv"
):
elif
filepath
.
lower
()
.
endswith
(
".csv"
):
loader
=
CSVLoader
(
filepath
,
**
kwargs
)
loader
=
CSVLoader
(
filepath
,
**
kwargs
)
elif
filepath
.
lower
()
.
endswith
(
".pdf"
):
elif
filepath
.
lower
()
.
endswith
(
".pdf"
):
# loader = UnstructuredPDFLoader(filepath, mode=mode or "elements",**kwargs)
# loader = UnstructuredPDFLoader(filepath, mode=mode or "elements",**kwargs)
# 使用自定义pdf loader
# 使用自定义pdf loader
return
__pdf_loader
(
filepath
,
sentence_size
=
sentence_size
,
metadata
=
metadata
,
callbacks
=
callbacks
)
return
__pdf_loader
(
filepath
,
sentence_size
=
sentence_size
,
metadata
=
metadata
,
callbacks
=
callbacks
)
elif
filepath
.
lower
()
.
endswith
(
".docx"
)
or
filepath
.
lower
()
.
endswith
(
".doc"
):
elif
filepath
.
lower
()
.
endswith
(
".docx"
)
or
filepath
.
lower
()
.
endswith
(
".doc"
):
loader
=
UnstructuredWordDocumentLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
loader
=
UnstructuredWordDocumentLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
else
:
else
:
loader
=
UnstructuredFileLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
loader
=
UnstructuredFileLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
if
sentence_size
>
0
:
if
sentence_size
>
0
:
return
split
(
loader
.
load
(),
sentence_size
)
return
split
(
loader
.
load
(),
sentence_size
)
return
loader
.
load
()
return
loader
.
load
()
def
loads_path
(
path
:
str
,
**
kwargs
):
return
loads
(
get_files_in_directory
(
path
),
**
kwargs
)
def
loads
(
filepaths
,
**
kwargs
):
def
loads_path
(
path
:
str
,
**
kwargs
):
default_kwargs
=
{
"mode"
:
"paged"
}
return
loads
(
get_files_in_directory
(
path
),
**
kwargs
)
def
loads
(
filepaths
,
**
kwargs
):
default_kwargs
=
{
"mode"
:
"paged"
}
default_kwargs
.
update
(
**
kwargs
)
default_kwargs
.
update
(
**
kwargs
)
documents
=
[
load
(
filepath
=
file
,
**
default_kwargs
)
for
file
in
filepaths
]
documents
=
[
load
(
filepath
=
file
,
**
default_kwargs
)
for
file
in
filepaths
]
return
[
item
for
sublist
in
documents
for
item
in
sublist
]
return
[
item
for
sublist
in
documents
for
item
in
sublist
]
def
append
(
documents
:
List
[
Document
]
=
[],
sentence_size
:
int
=
SENTENCE_SIZE
):
#保留文档结构信息,注意处理hash
def
append
(
documents
=
None
,
sentence_size
:
int
=
SENTENCE_SIZE
):
# 保留文档结构信息,注意处理hash
if
documents
is
None
:
documents
=
[]
effect_documents
=
[]
effect_documents
=
[]
last_doc
=
documents
[
0
]
last_doc
=
documents
[
0
]
for
doc
in
documents
[
1
:]:
for
doc
in
documents
[
1
:]:
last_hash
=
""
if
"next_hash"
not
in
last_doc
.
metadata
else
last_doc
.
metadata
[
"next_hash"
]
last_hash
=
""
if
"next_hash"
not
in
last_doc
.
metadata
else
last_doc
.
metadata
[
"next_hash"
]
doc_hash
=
""
if
"next_hash"
not
in
doc
.
metadata
else
doc
.
metadata
[
"next_hash"
]
doc_hash
=
""
if
"next_hash"
not
in
doc
.
metadata
else
doc
.
metadata
[
"next_hash"
]
if
len
(
last_doc
.
page_content
)
+
len
(
doc
.
page_content
)
<=
sentence_size
and
last_hash
==
doc_hash
:
if
len
(
last_doc
.
page_content
)
+
len
(
doc
.
page_content
)
<=
sentence_size
and
last_hash
==
doc_hash
:
last_doc
.
page_content
=
last_doc
.
page_content
+
doc
.
page_content
last_doc
.
page_content
=
last_doc
.
page_content
+
doc
.
page_content
continue
continue
else
:
else
:
...
@@ -58,28 +67,31 @@ def append(documents:List[Document] = [],sentence_size:int = SENTENCE_SIZE):#保
...
@@ -58,28 +67,31 @@ def append(documents:List[Document] = [],sentence_size:int = SENTENCE_SIZE):#保
effect_documents
.
append
(
last_doc
)
effect_documents
.
append
(
last_doc
)
return
effect_documents
return
effect_documents
def
split
(
documents
:
List
[
Document
]
=
[],
sentence_size
:
int
=
SENTENCE_SIZE
):
#保留文档结构信息,注意处理hash
def
split
(
documents
=
None
,
sentence_size
:
int
=
SENTENCE_SIZE
):
# 保留文档结构信息,注意处理hash
if
documents
is
None
:
documents
=
[]
effect_documents
=
[]
effect_documents
=
[]
for
doc
in
documents
:
for
doc
in
documents
:
if
len
(
doc
.
page_content
)
>
sentence_size
:
if
len
(
doc
.
page_content
)
>
sentence_size
:
words_list
=
re
.
split
(
r'·-·'
,
doc
.
page_content
.
replace
(
"。"
,
"。·-·"
)
.
replace
(
"
\n
"
,
"
\n
·-·"
))
#
插入分隔符,分割
words_list
=
re
.
split
(
r'·-·'
,
doc
.
page_content
.
replace
(
"。"
,
"。·-·"
)
.
replace
(
"
\n
"
,
"
\n
·-·"
))
#
插入分隔符,分割
document
=
Document
(
page_content
=
""
,
metadata
=
copy
.
deepcopy
(
doc
.
metadata
))
document
=
Document
(
page_content
=
""
,
metadata
=
copy
.
deepcopy
(
doc
.
metadata
))
first
=
True
first
=
True
for
word
in
words_list
:
for
word
in
words_list
:
if
len
(
document
.
page_content
)
+
len
(
word
)
<
sentence_size
:
if
len
(
document
.
page_content
)
+
len
(
word
)
<
sentence_size
:
document
.
page_content
+=
word
document
.
page_content
+=
word
else
:
else
:
if
len
(
document
.
page_content
.
replace
(
" "
,
""
)
.
replace
(
"
\n
"
,
""
))
>
0
:
if
len
(
document
.
page_content
.
replace
(
" "
,
""
)
.
replace
(
"
\n
"
,
""
))
>
0
:
if
first
:
if
first
:
first
=
False
first
=
False
else
:
else
:
effect_documents
[
-
1
]
.
metadata
[
"next_doc"
]
=
document
.
page_content
effect_documents
[
-
1
]
.
metadata
[
"next_doc"
]
=
document
.
page_content
effect_documents
.
append
(
document
)
effect_documents
.
append
(
document
)
document
=
Document
(
page_content
=
word
,
metadata
=
copy
.
deepcopy
(
doc
.
metadata
))
document
=
Document
(
page_content
=
word
,
metadata
=
copy
.
deepcopy
(
doc
.
metadata
))
if
len
(
document
.
page_content
.
replace
(
" "
,
""
)
.
replace
(
"
\n
"
,
""
))
>
0
:
if
len
(
document
.
page_content
.
replace
(
" "
,
""
)
.
replace
(
"
\n
"
,
""
))
>
0
:
if
first
:
if
first
:
first
=
False
pass
else
:
else
:
effect_documents
[
-
1
]
.
metadata
[
"next_doc"
]
=
document
.
page_content
effect_documents
[
-
1
]
.
metadata
[
"next_doc"
]
=
document
.
page_content
effect_documents
.
append
(
document
)
effect_documents
.
append
(
document
)
...
@@ -87,10 +99,12 @@ def split(documents:List[Document] = [],sentence_size:int = SENTENCE_SIZE): #保
...
@@ -87,10 +99,12 @@ def split(documents:List[Document] = [],sentence_size:int = SENTENCE_SIZE): #保
effect_documents
.
append
(
doc
)
effect_documents
.
append
(
doc
)
return
effect_documents
return
effect_documents
def
load_file
(
filepath
,
sentence_size
=
SENTENCE_SIZE
,
using_zh_title_enhance
=
ZH_TITLE_ENHANCE
,
mode
:
str
=
None
,
**
kwargs
):
def
load_file
(
filepath
,
sentence_size
=
SENTENCE_SIZE
,
using_zh_title_enhance
=
ZH_TITLE_ENHANCE
,
mode
:
str
=
None
,
**
kwargs
):
print
(
"load_file"
,
filepath
)
print
(
"load_file"
,
filepath
)
if
filepath
.
lower
()
.
endswith
(
".md"
):
if
filepath
.
lower
()
.
endswith
(
".md"
):
loader
=
UnstructuredFileLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
loader
=
UnstructuredFileLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
docs
=
loader
.
load
()
docs
=
loader
.
load
()
elif
filepath
.
lower
()
.
endswith
(
".txt"
):
elif
filepath
.
lower
()
.
endswith
(
".txt"
):
loader
=
TextLoader
(
filepath
,
autodetect_encoding
=
True
,
**
kwargs
)
loader
=
TextLoader
(
filepath
,
autodetect_encoding
=
True
,
**
kwargs
)
...
@@ -100,15 +114,15 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T
...
@@ -100,15 +114,15 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T
loader
=
CSVLoader
(
filepath
,
**
kwargs
)
loader
=
CSVLoader
(
filepath
,
**
kwargs
)
docs
=
loader
.
load
()
docs
=
loader
.
load
()
elif
filepath
.
lower
()
.
endswith
(
".pdf"
):
elif
filepath
.
lower
()
.
endswith
(
".pdf"
):
loader
=
UnstructuredPDFLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
loader
=
UnstructuredPDFLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
textsplitter
=
ChineseTextSplitter
(
pdf
=
True
,
sentence_size
=
sentence_size
)
textsplitter
=
ChineseTextSplitter
(
pdf
=
True
,
sentence_size
=
sentence_size
)
docs
=
loader
.
load_and_split
(
textsplitter
)
docs
=
loader
.
load_and_split
(
textsplitter
)
elif
filepath
.
lower
()
.
endswith
(
".docx"
):
elif
filepath
.
lower
()
.
endswith
(
".docx"
):
loader
=
UnstructuredWordDocumentLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
loader
=
UnstructuredWordDocumentLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
textsplitter
=
ChineseTextSplitter
(
pdf
=
False
,
sentence_size
=
sentence_size
)
textsplitter
=
ChineseTextSplitter
(
pdf
=
False
,
sentence_size
=
sentence_size
)
docs
=
loader
.
load_and_split
(
textsplitter
)
docs
=
loader
.
load_and_split
(
textsplitter
)
else
:
else
:
loader
=
UnstructuredFileLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
loader
=
UnstructuredFileLoader
(
filepath
,
mode
=
mode
or
"elements"
,
**
kwargs
)
textsplitter
=
ChineseTextSplitter
(
pdf
=
False
,
sentence_size
=
sentence_size
)
textsplitter
=
ChineseTextSplitter
(
pdf
=
False
,
sentence_size
=
sentence_size
)
docs
=
loader
.
load_and_split
(
text_splitter
=
textsplitter
)
docs
=
loader
.
load_and_split
(
text_splitter
=
textsplitter
)
if
using_zh_title_enhance
:
if
using_zh_title_enhance
:
...
@@ -116,6 +130,7 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T
...
@@ -116,6 +130,7 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_T
write_check_file
(
filepath
,
docs
)
write_check_file
(
filepath
,
docs
)
return
docs
return
docs
def
write_check_file
(
filepath
,
docs
):
def
write_check_file
(
filepath
,
docs
):
folder_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
filepath
),
"tmp_files"
)
folder_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
filepath
),
"tmp_files"
)
if
not
os
.
path
.
exists
(
folder_path
):
if
not
os
.
path
.
exists
(
folder_path
):
...
@@ -129,6 +144,7 @@ def write_check_file(filepath, docs):
...
@@ -129,6 +144,7 @@ def write_check_file(filepath, docs):
fout
.
write
(
'
\n
'
)
fout
.
write
(
'
\n
'
)
fout
.
close
()
fout
.
close
()
def
get_files_in_directory
(
directory
):
def
get_files_in_directory
(
directory
):
file_paths
=
[]
file_paths
=
[]
for
root
,
dirs
,
files
in
os
.
walk
(
directory
):
for
root
,
dirs
,
files
in
os
.
walk
(
directory
):
...
@@ -137,21 +153,29 @@ def get_files_in_directory(directory):
...
@@ -137,21 +153,29 @@ def get_files_in_directory(directory):
file_paths
.
append
(
file_path
)
file_paths
.
append
(
file_path
)
return
file_paths
return
file_paths
#自定义pdf load部分
def
__checkV
(
strings
:
str
):
# 自定义pdf load部分
def
__checkV
(
strings
:
str
):
lines
=
len
(
strings
.
splitlines
())
lines
=
len
(
strings
.
splitlines
())
if
(
lines
>
3
and
len
(
strings
.
replace
(
" "
,
""
))
/
lines
<
15
)
:
if
lines
>
3
and
len
(
strings
.
replace
(
" "
,
""
))
/
lines
<
15
:
return
False
return
False
return
True
return
True
def
__isTitle
(
strings
:
str
):
return
len
(
strings
.
splitlines
())
==
1
and
len
(
strings
)
>
0
and
strings
.
endswith
(
"
\n
"
)
def
__appendPara
(
strings
:
str
):
def
__isTitle
(
strings
:
str
):
return
strings
.
replace
(
".
\n
"
,
"^_^"
)
.
replace
(
"。
\n
"
,
"^-^"
)
.
replace
(
"?
\n
"
,
"?^-^"
)
.
replace
(
"?
\n
"
,
"?^-^"
)
.
replace
(
"
\n
"
,
""
)
.
replace
(
"^_^"
,
".
\n
"
)
.
replace
(
"^-^"
,
"。
\n
"
)
.
replace
(
"?^-^"
,
"?
\n
"
)
.
replace
(
"?^-^"
,
"?
\n
"
)
return
len
(
strings
.
splitlines
())
==
1
and
len
(
strings
)
>
0
and
strings
.
endswith
(
"
\n
"
)
def
__check_fs_ff
(
line_ff_fs_s
,
fs
,
ff
):
#若当前行有上一行一样的字体、字号文字,则返回相同的。默认返回最长文本的字体和字号
def
__appendPara
(
strings
:
str
):
return
strings
.
replace
(
".
\n
"
,
"^_^"
)
.
replace
(
"。
\n
"
,
"^-^"
)
.
replace
(
"?
\n
"
,
"?^-^"
)
.
replace
(
"?
\n
"
,
"?^-^"
)
.
replace
(
"
\n
"
,
""
)
.
replace
(
"^_^"
,
".
\n
"
)
.
replace
(
"^-^"
,
"。
\n
"
)
.
replace
(
"?^-^"
,
"?
\n
"
)
.
replace
(
"?^-^"
,
"?
\n
"
)
def
__check_fs_ff
(
line_ff_fs_s
,
fs
,
ff
):
# 若当前行有上一行一样的字体、字号文字,则返回相同的。默认返回最长文本的字体和字号
re_fs
=
line_ff_fs_s
[
-
1
][
0
][
-
1
]
re_fs
=
line_ff_fs_s
[
-
1
][
0
][
-
1
]
re_ff
=
line_ff_fs_s
[
-
1
][
1
][
-
1
]
if
line_ff_fs_s
[
-
1
][
1
]
else
None
re_ff
=
line_ff_fs_s
[
-
1
][
1
][
-
1
]
if
line_ff_fs_s
[
-
1
][
1
]
else
None
max_len
=
0
max_len
=
0
for
ff_fs
in
line_ff_fs_s
:
#
寻找最长文本字体和字号
for
ff_fs
in
line_ff_fs_s
:
#
寻找最长文本字体和字号
c_max
=
max
(
list
(
map
(
int
,
ff_fs
[
0
])))
c_max
=
max
(
list
(
map
(
int
,
ff_fs
[
0
])))
if
max_len
<
ff_fs
[
2
]
or
(
max_len
==
ff_fs
[
2
]
and
c_max
>
int
(
re_fs
)):
if
max_len
<
ff_fs
[
2
]
or
(
max_len
==
ff_fs
[
2
]
and
c_max
>
int
(
re_fs
)):
max_len
=
ff_fs
[
2
]
max_len
=
ff_fs
[
2
]
...
@@ -163,122 +187,132 @@ def __check_fs_ff(line_ff_fs_s,fs,ff): #若当前行有上一行一样的字
...
@@ -163,122 +187,132 @@ def __check_fs_ff(line_ff_fs_s,fs,ff): #若当前行有上一行一样的字
re_fs
=
fs
re_fs
=
fs
re_ff
=
ff
re_ff
=
ff
break
break
return
int
(
re_fs
),
re_ff
return
int
(
re_fs
),
re_ff
def
append_document
(
snippets1
:
List
[
Document
],
title
:
str
,
content
:
str
,
callbacks
,
font_size
,
page_num
,
metadate
,
need_append
:
bool
=
False
):
def
append_document
(
snippets1
:
List
[
Document
],
title
:
str
,
content
:
str
,
callbacks
,
font_size
,
page_num
,
metadate
,
need_append
:
bool
=
False
):
if
callbacks
:
if
callbacks
:
for
cb
in
callbacks
:
for
cb
in
callbacks
:
if
isinstance
(
cb
,
BaseCallback
):
if
isinstance
(
cb
,
BaseCallback
):
if
cb
.
filter
(
title
,
content
):
if
cb
.
filter
(
title
,
content
):
return
return
if
need_append
and
len
(
snippets1
)
>
0
:
if
need_append
and
len
(
snippets1
)
>
0
:
ps
=
snippets1
.
pop
()
ps
=
snippets1
.
pop
()
snippets1
.
append
(
Document
(
page_content
=
ps
.
page_content
+
title
,
metadata
=
ps
.
metadata
))
snippets1
.
append
(
Document
(
page_content
=
ps
.
page_content
+
title
,
metadata
=
ps
.
metadata
))
else
:
else
:
doc_metadata
=
{
"font-size"
:
font_size
,
"page_number"
:
page_num
}
doc_metadata
=
{
"font-size"
:
font_size
,
"page_number"
:
page_num
}
doc_metadata
.
update
(
metadate
)
doc_metadata
.
update
(
metadate
)
snippets1
.
append
(
Document
(
page_content
=
title
+
content
,
metadata
=
doc_metadata
))
snippets1
.
append
(
Document
(
page_content
=
title
+
content
,
metadata
=
doc_metadata
))
'''
Extract a PDF document and split it by title and content; each chunk's page number is the page its title appears on.
The split text is re-split by sentence_size; all sub-chunks inherit the parent chunk's page number.
'''


def __pdf_loader(filepath: str, sentence_size: int = 0, metadata=None, callbacks=None):
    if not filepath.lower().endswith(".pdf"):
        raise ValueError("file is not pdf document")
    loader = PDFMinerPDFasHTMLLoader(filepath)
    documents = loader.load()
    soup = BeautifulSoup(documents[0].page_content, 'html.parser')
    content = soup.find_all('div')
    cur_fs = None  # current font-size
    last_fs = None  # previous chunk's font-size
    cur_ff = None  # current font-family
    cur_text = ''
    fs_increasing = False  # the next line's font got bigger: treat it as a title and split here
    last_text = ''
    last_page_num = 1  # previous page number; combined with page_split it decides the current chunk's page
    page_num = 1  # initial page number
    page_change = False  # page switched
    page_split = False  # whether any text was split off on this page
    last_is_title = False  # whether the previous chunk was a title
    snippets: List[Document] = []
    filename = os.path.basename(filepath)
    if metadata:
        metadata.update({'source': filepath, 'filename': filename, 'filetype': 'application/pdf'})
    else:
        metadata = {'source': filepath, 'filename': filename, 'filetype': 'application/pdf'}
    for c in content:
        divs = c.get('style')
        if re.match(r"^(Page|page)", c.text):  # detect the current page number
            match = re.match(r"^(page|Page)\s+(\d+)", c.text)
            if match:
                if page_split:  # if text was split on this page, switch pages; otherwise keep the current chunk's starting page
                    last_page_num = page_num
                page_num = match.group(2)
                if len(last_text) + len(cur_text) == 0:  # page turned while the buffer is empty: previous page number becomes the current one
                    last_page_num = page_num
                page_change = True
                page_split = False
            continue
        if re.findall('writing-mode:(.*?);', divs) == ['False'] or re.match(r'^[0-9\s\n]+$', c.text) or re.match(r"^第\s+\d+\s+页$", c.text):  # hidden text, pure digits, or a "第 N 页" footer
            continue
        if len(c.text.replace("\n", "").replace(" ", "")) <= 1:  # drop lines with at most one meaningful character
            continue
        sps = c.find_all('span')
        if not sps:
            continue
        line_ff_fs_s = []  # spans with more than one meaningful character
        line_ff_fs_s2 = []  # spans with exactly one meaningful character
        for sp in sps:  # one line may contain spans with different styles
            sp_len = len(sp.text.replace("\n", "").replace(" ", ""))
            if sp_len > 0:
                st = sp.get('style')
                if st:
                    ff_fs = (re.findall('font-size:(\d+)px', st), re.findall('font-family:(.*?);', st),
                             len(sp.text.replace("\n", "").replace(" ", "")))
                    if sp_len == 1:  # filter spans holding a single meaningful character
                        line_ff_fs_s2.append(ff_fs)
                    else:
                        line_ff_fs_s.append(ff_fs)
        if len(line_ff_fs_s) == 0:  # if empty, fall back to the single-character spans
            if len(line_ff_fs_s2) > 0:
                line_ff_fs_s = line_ff_fs_s2
            else:
                if len(c.text) > 0:
                    page_change = False
                continue
        fs, ff = __check_fs_ff(line_ff_fs_s, cur_fs, cur_ff)
        if not cur_ff:
            cur_ff = ff
        if not cur_fs:
            cur_fs = fs
        if abs(fs - cur_fs) <= 1 and ff == cur_ff:  # neither style nor size changed
            cur_text += c.text
            cur_fs = fs
            page_change = False
            if len(cur_text.splitlines()) > 3:  # after several consecutive lines, fs_increasing no longer applies
                fs_increasing = False
        else:
            if page_change and cur_fs > fs + 1:  # page turned and the font got smaller: most likely a header, skip c.text (may occasionally drop a real line)
                page_change = False
                continue
            if last_is_title:  # the previous chunk was a title
                if __isTitle(cur_text) or fs_increasing:  # consecutive titles, or the font-increase flag is set
                    last_text = last_text + cur_text
                    last_is_title = True
                    fs_increasing = False
                else:
                    append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
                                    page_num if page_split else last_page_num, metadata)
                    page_split = True
                    last_text = ''
                    last_is_title = False
                    fs_increasing = int(fs) > int(cur_fs)  # the font got bigger
            else:
                if len(last_text) > 0 and __checkV(last_text):  # filter some text
                    # merge chunks that cross a page boundary or have only a few lines
                    append_document(snippets, __appendPara(last_text), "", callbacks, last_fs,
                                    page_num if page_split else last_page_num, metadata,
                                    need_append=len(last_text.splitlines()) <= 2 or page_change)
                    page_split = True
                last_text = cur_text
                last_is_title = __isTitle(last_text) or fs_increasing
...
@@ -290,7 +324,8 @@ def __pdf_loader(filepath: str, sentence_size: int = 0, metadata=None, callbacks=None):
            cur_ff = ff
            cur_text = c.text
            page_change = False
    append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
                    page_num if page_split else last_page_num, metadata)
    if sentence_size > 0:
        return split(snippets, sentence_size)
    return snippets
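A rough usage sketch for the loader above. The PDF path and the metadata dict are placeholders, and the call goes through the module-private __pdf_loader directly only for illustration; the arguments follow the signature shown in this diff.

    # Split a PDF into title+content snippets, re-split to at most 512 characters per chunk.
    docs = __pdf_loader("./docs/sample.pdf", sentence_size=512, metadata={"category": "manual"}, callbacks=None)
    for d in docs[:3]:
        # each Document carries the title page's number and the detected font size
        print(d.metadata.get("page_number"), d.metadata.get("font-size"), d.page_content[:40])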
src/loader/zh_title_enhance.py  View file @ 493cdd59
...
@@ -33,7 +33,7 @@ def is_possible_title(
        title_max_word_length: int = 20,
        non_alpha_threshold: float = 0.5,
) -> bool:
-    """Checks to see if the text passes all of the checks for a valid title.
+    """Checks to see if the text passes all the checks for a valid title.

    Parameters
    ----------
...
src/pgdb/chat/c_db.py  View file @ 493cdd59
import psycopg2
from psycopg2 import OperationalError, InterfaceError


class UPostgresDB:
    """
    psycopg2.connect(
        dsn       # connection parameters, given either as keyword arguments or as a DSN string
        host      # host name of the database server
...
@@ -18,8 +19,9 @@ class UPostgresDB:
        sslkey    # private key file name
        sslcert   # certificate file name
    )
    """

    def __init__(self, host, database, user, password, port=5432):
        self.host = host
        self.database = database
        self.user = user
...
@@ -35,7 +37,7 @@ class UPostgresDB:
                database=self.database,
                user=self.user,
                password=self.password,
                port=self.port
            )
            self.cur = self.conn.cursor()
        except Exception as e:
...
@@ -89,7 +91,6 @@ class UPostgresDB:
                print(f"重新执行sql语句再次出现错误: {type(e).__name__}: {e}")
                self.conn.rollback()

    def search(self, query, params=None):
        if self.conn is None or self.conn.closed:
            self.connect()
...
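A minimal connection sketch for the wrapper above; the credentials are placeholders, and only the method names visible in this diff (connect, search, execute_args and the rollback-and-retry behavior) are assumed.

    db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="***", port=5432)
    db.connect()                       # opens the connection and cursor
    db.search("SELECT 1")              # reconnects first if the connection is missing or closed
    # execute_args rolls back and retries once on errors, as the hunks above show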
src/pgdb/chat/c_user_table.py  View file @ 493cdd59
from .c_db import UPostgresDB
import json

TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
...
@@ -13,13 +14,14 @@ COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
"""


class CUser:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db

    def insert(self, value):
        query = f"INSERT INTO c_user(user_id, account, password) VALUES (%s,%s,%s)"
-        self.db.execute_args(query, ((value[0], value[1], value[2])))
+        self.db.execute_args(query, (value[0], value[1], value[2]))

    def create_table(self):
        query = TABLE_USER
...
src/pgdb/chat/chat_table.py  View file @ 493cdd59
from .c_db import UPostgresDB
import json

TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
...
@@ -17,6 +18,7 @@ COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
"""


class Chat:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db
...
@@ -24,7 +26,7 @@ class Chat:
    # insert a row
    def insert(self, value):
        query = f"INSERT INTO chat(chat_id, user_id, info, deleted) VALUES (%s,%s,%s,%s)"
-        self.db.execute_args(query, ((value[0], value[1], value[2], value[3])))
+        self.db.execute_args(query, (value[0], value[1], value[2], value[3]))

    # create the table
    def create_table(self):
...
src/pgdb/chat/turn_qa_table.py  View file @ 493cdd59
from .c_db import UPostgresDB
import json

TABLE_CHAT = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
...
@@ -21,6 +22,7 @@ COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
"""


class TurnQa:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db
...
@@ -28,7 +30,7 @@ class TurnQa:
    # insert a row
    def insert(self, value):
        query = f"INSERT INTO turn_qa(turn_id, chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s,%s)"
-        self.db.execute_args(query, ((value[0], value[1], value[2], value[3], value[4], value[5])))
+        self.db.execute_args(query, (value[0], value[1], value[2], value[3], value[4], value[5]))

    # create the table
    def create_table(self):
...
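Taken together, the three chat tables are used roughly as below (mirroring test/chat_table_test.py later in this commit); the IDs are made-up values and the column order matches the INSERT statements above.

    c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="***", port=5432)
    chat, user, turn = Chat(db=c_db), CUser(db=c_db), TurnQa(db=c_db)
    for t in (chat, user, turn):
        t.create_table()                                   # note: the DDL drops and recreates each table
    user.insert(["u-001", "zhangsan", "123456"])           # user_id, account, password
    chat.insert(["c-001", "u-001", "demo chat", 0])        # chat_id, user_id, info, deleted
    turn.insert(["t-001", "c-001", "你好", "你好!", 1, 1])  # turn_id, chat_id, question, answer, turn_number, is_last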
src/pgdb/knowledge/callback.py  View file @ 493cdd59
...
@@ -4,22 +4,24 @@ from os import path
sys.path.append("../")
from abc import ABC, abstractmethod
import json
from typing import List, Any, Tuple, Dict
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64


class DocumentCallback(ABC):
    @abstractmethod
    # document processing before storing into the vector store --
    def before_store(self, docstore: PgSqlDocstore, documents):
        pass

    @abstractmethod
    # document processing after a vector-store search -- used to rebuild structure
    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:  # post-search document processing
        pass


class DefaultDocumentCallback(DocumentCallback):
    def before_store(self, docstore: PgSqlDocstore, documents):
        output_doc = []
        for doc in documents:
            if "next_doc" in doc.metadata:
...
@@ -27,22 +29,24 @@ class DefaultDocumentCallback(DocumentCallback):
                doc.metadata.pop("next_doc")
            output_doc.append(doc)
        return output_doc

    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:  # post-search document processing
        output_doc: List[Tuple[Document, float]] = []
        exist_hash = []
        for doc, score in documents:
            print(exist_hash)
            dochash = str2hash_base64(doc.page_content)
            if dochash in exist_hash:
                continue
            else:
                exist_hash.append(dochash)
            output_doc.append((doc, score))
            if len(output_doc) > number:
                return output_doc
            fordoc = doc
            while "next_hash" in fordoc.metadata:
                if len(fordoc.metadata["next_hash"]) > 0:
                    if fordoc.metadata["next_hash"] in exist_hash:
                        break
                    else:
...
@@ -50,7 +54,7 @@ class DefaultDocumentCallback(DocumentCallback):
                        content = docstore.TXT_DOC.search(fordoc.metadata["next_hash"])
                        if content:
                            fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
                            output_doc.append((fordoc, score))
                            if len(output_doc) > number:
                                return output_doc
                else:
...
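A sketch of a custom callback written against the interface above. The short-text filtering rule is invented for illustration; only the before_store/after_search hooks themselves come from this diff.

    class ShortTextFilterCallback(DocumentCallback):
        def before_store(self, docstore: PgSqlDocstore, documents):
            # drop very short chunks before they reach the vector store
            return [d for d in documents if len(d.page_content) >= 10]

        def after_search(self, docstore: PgSqlDocstore, documents, number: int = 1000):
            # keep the default ordering, just cap the result count
            return documents[:number]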
src/pgdb/knowledge/k_db.py  View file @ 493cdd59
import psycopg2


class PostgresDB:
    """
    psycopg2.connect(
        dsn       # connection parameters, given either as keyword arguments or as a DSN string
        host      # host name of the database server
...
@@ -17,8 +18,9 @@ class PostgresDB:
        sslkey    # private key file name
        sslcert   # certificate file name
    )
    """

    def __init__(self, host, database, user, password, port=5432):
        self.host = host
        self.database = database
        self.user = user
...
@@ -33,7 +35,7 @@ class PostgresDB:
                database=self.database,
                user=self.user,
                password=self.password,
                port=self.port
            )
            self.cur = self.conn.cursor()
...
@@ -44,6 +46,7 @@ class PostgresDB:
        except Exception as e:
            print(f"An error occurred: {e}")
            self.conn.rollback()

    def execute_args(self, query, args):
        try:
            self.cur.execute(query, args)
...
src/pgdb/knowledge/pgsqldocstore.py  View file @ 493cdd59
import sys
from os import path

# this effectively adds the current directory to PYTHONPATH
sys.path.append(path.dirname(path.abspath(__file__)))
from typing import List, Union, Dict, Optional
from langchain.docstore.base import AddableMixin, Docstore
from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector
import json, hashlib, base64
from langchain.schema import Document


-def str2hash_base64(input: str) -> str:
-    # return f"%s" % hash(input)
-    return base64.b64encode(hashlib.sha1(input.encode()).digest()).decode()
+def str2hash_base64(inp: str) -> str:
+    # return f"%s" % hash(input)
+    return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()


class PgSqlDocstore(Docstore, AddableMixin):
    host: str
    dbname: str
    username: str
    password: str
    port: str
    '''
    Note: __getstate__/__setstate__ are overridden so that langchain's pickle-based serialization works;
    the returned dict contains the pgsql connection info.
    '''

    def __getstate__(self):
        return {"host": self.host, "dbname": self.dbname, "username": self.username, "password": self.password,
                "port": self.port}

    def __setstate__(self, info):
        self.__init__(info)

    def __init__(self, info: dict, reset: bool = False):
        self.host = info["host"]
        self.dbname = info["dbname"]
        self.username = info["username"]
        self.password = info["password"]
        self.port = info["port"] if "port" in info else "5432"
        self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password, port=self.port)
        self.TXT_DOC = TxtDoc(self.pgdb)
        self.VEC_TXT = TxtVector(self.pgdb)
        if reset:
...
@@ -48,12 +50,15 @@ class PgSqlDocstore(Docstore, AddableMixin):
            self.VEC_TXT.drop_table()
            self.TXT_DOC.create_table()
            self.VEC_TXT.create_table()

    def __sub_init__(self):
        if not self.pgdb.conn:
            self.pgdb.connect()

    '''
    Look up the paragraph text behind a vector in the local database and return it wrapped in a Document.
    '''

    def search(self, search: str) -> Union[str, Document]:
        if not self.pgdb.conn:
            self.__sub_init__()
...
@@ -63,40 +68,44 @@ class PgSqlDocstore(Docstore, AddableMixin):
            return Document(page_content=content[0], metadata=json.loads(content[1]))
        else:
            return Document()

    '''
    Delete the texts behind the given vectors from the local database, in batch.
    '''

    def delete(self, ids: List) -> None:
        if not self.pgdb.conn:
            self.__sub_init__()
        pids = []
-        for id in ids:
-            anwser = self.VEC_TXT.search(id)
+        for item in ids:
+            anwser = self.VEC_TXT.search(item)
            pids.append(anwser[0])
        self.VEC_TXT.delete(ids)
        self.TXT_DOC.delete(pids)

    '''
    Add vectors and texts to the local database.
    [vector_id, Document(page_content=question, metadata=dict(paragraph=paragraph text))]
    '''

    def add(self, texts: Dict[str, Document]) -> None:
        # for vec,doc in texts.items():
        #     paragraph_id = self.TXT_DOC.insert(doc.metadata["paragraph"])
        #     self.VEC_TXT.insert(vector_id=vec,paragraph_id=paragraph_id,text=doc.page_content)
        if not self.pgdb.conn:
            self.__sub_init__()
        paragraph_hashs = []  # hash,text
        paragraph_txts = []
        vec_inserts = []
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            print(txt_hash)
            vec_inserts.append((vec, doc.page_content, txt_hash))
            if txt_hash not in paragraph_hashs:
                paragraph_hashs.append(txt_hash)
                paragraph = doc.metadata["paragraph"]
                doc.metadata.pop("paragraph")
                paragraph_txts.append((txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False)))
        # print(paragraph_txts)
        self.TXT_DOC.insert(paragraph_txts)
        self.VEC_TXT.insert(vec_inserts)
...
@@ -105,7 +114,7 @@ class PgSqlDocstore(Docstore, AddableMixin):


class InMemorySecondaryDocstore(Docstore, AddableMixin):
    """Simple in memory docstore in the form of a dict."""

    def __init__(self, _dict: Optional[Dict[str, Document]] = None, _sec_dict: Optional[Dict[str, Document]] = None):
        """Initialize with dict."""
        self._dict = _dict if _dict is not None else {}
        self._sec_dict = _sec_dict if _sec_dict is not None else {}
...
@@ -126,14 +135,14 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
        dict1 = {}
        dict_sec = {}
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            metadata = doc.metadata
            paragraph = metadata.pop('paragraph')
            # metadata.update({"paragraph_id":txt_hash})
            metadata['paragraph_id'] = txt_hash
            dict_sec[txt_hash] = Document(page_content=paragraph, metadata=metadata)
            dict1[vec] = Document(page_content=doc.page_content, metadata={'paragraph_id': txt_hash})
        self._dict = {**self._dict, **dict1}
        self._sec_dict = {**self._sec_dict, **dict_sec}
...
@@ -143,7 +152,7 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
        if not overlapping:
            raise ValueError(f"Tried to delete ids that does not exist: {ids}")
        for _id in ids:
-            self._sec_dict.pop(self._dict[id].metadata['paragraph_id'])
+            self._sec_dict.pop(self._dict[_id].metadata['paragraph_id'])
            self._dict.pop(_id)

    def search(self, search: str) -> Union[str, Document]:
...
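A rough round-trip sketch for PgSqlDocstore, assuming a reachable PostgreSQL instance. The connection dict keys follow __init__ above, the add() payload shape follows the docstring ([vector_id, Document(page_content=question, metadata={'paragraph': ...})]), and Document is the langchain.schema import already used in this file; the ids and texts are invented.

    info = {"host": "localhost", "dbname": "lae", "username": "postgres", "password": "***", "port": "5432"}
    store = PgSqlDocstore(info, reset=True)      # reset drops and recreates txt_doc / vec_txt
    store.add({
        "vec-0001": Document(page_content="什么是征信业务?",
                             metadata={"paragraph": "征信业务是指……", "page_number": 3}),
    })
    doc = store.search("vec-0001")               # resolves the vector id back to its paragraph
    print(doc.page_content, doc.metadata)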
src/pgdb/knowledge/similarity.py  View file @ 493cdd59
import os, sys
import re, time
from os import path

sys.path.append("../")
import copy
from typing import List, OrderedDict, Any, Optional, Tuple, Dict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore
from langchain.vectorstores.faiss import FAISS, dependable_faiss_import
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore
from langchain.embeddings.huggingface import (
...
@@ -22,43 +22,54 @@ from langchain.callbacks.manager import (
)
from src.loader import load
from langchain.embeddings.base import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class EmbeddingFactory:
    def __init__(self, path: str):
        self.path = path
        self.embedding = HuggingFaceEmbeddings(model_name=path)

    def get_embedding(self):
        return self.embedding


-def GetEmbding(path: str) -> Embeddings:
-    # return HuggingFaceEmbeddings(model_name=path)
-    return EmbeddingFactory(path).get_embedding()
+def GetEmbding(_path: str) -> Embeddings:
+    # return HuggingFaceEmbeddings(model_name=path)
+    return EmbeddingFactory(_path).get_embedding()
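The singleton decorator means the HuggingFace model is loaded only once per process; a quick sketch with a placeholder model path:

    e1 = GetEmbding("/models/text2vec-base-chinese")
    e2 = GetEmbding("/models/text2vec-base-chinese")
    assert e1 is e2   # same EmbeddingFactory instance, so the weights are loaded a single time

Note that the cache in singleton() is keyed by the class, not by its arguments, so a later call with a different path still returns the first embedding that was created.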
import operator
from langchain.vectorstores.utils import DistanceStrategy
import numpy as np


class RE_FAISS(FAISS):
    # deduplicate results while keeping metadata
-    def _tuple_deduplication(self, tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
+    @staticmethod
+    def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        deduplicated_dict = OrderedDict()
        for doc, scores in tuple_input:
            page_content = doc.page_content
            metadata = doc.metadata
            if page_content not in deduplicated_dict:
                deduplicated_dict[page_content] = (metadata, scores)
        deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1])
                                  for key, value in deduplicated_dict.items()]
        return deduplicated_documents

    def similarity_search_with_score_by_vector(
            self,
            embedding: List[float],
...
@@ -107,8 +118,9 @@ class RE_FAISS(FAISS):
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
        return docs[:k]

    def max_marginal_relevance_search_by_vector(
            self,
            embedding: List[float],
...
@@ -141,50 +153,61 @@ class RE_FAISS(FAISS):
        docs_and_scores = self._tuple_deduplication(docs_and_scores)
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
        return [doc for doc, _ in docs_and_scores]


def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
             is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
-    embeddings = GetEmbding(path=embedding_model_name)
+    embeddings = GetEmbding(_path=embedding_model_name)
    docstore1: PgSqlDocstore = None
    if is_pgsql:
        if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
            docstore1 = PgSqlDocstore(info, reset=reset)
    else:
        docstore1 = InMemorySecondaryDocstore()
    if not path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
    if store_path is None or len(store_path) <= 0 or not path.exists(path.join(store_path, index_name + ".faiss")) or reset:
        print("create new faiss")
        index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))  # sized to the embedding vector dimension
        return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore1, index_to_docstore_id={})
    else:
        print("load_local faiss")
        _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings)
        if docstore1 and is_pgsql:  # if the external connection info changed, refresh the docstore
            _faiss.docstore = docstore1
        return _faiss


class VectorStore_FAISS(FAISS):
    def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None,
                 is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False,
                 doc_callback: DocumentCallback = DefaultDocumentCallback()):
        self.info = info
        self.embedding_model_name = embedding_model_name
        self.store_path = path.join(store_path, index_name)
        if not path.exists(self.store_path):
            os.makedirs(self.store_path, exist_ok=True)
        self.index_name = index_name
        self.show_number = show_number
        self.search_number = self.show_number * 3
        self.threshold = threshold
        self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
                               is_pgsql=is_pgsql, reset=reset)
        self.doc_callback = doc_callback

    def get_text_similarity_with_score(self, text: str, **kwargs):
        score_threshold = (1 - self.threshold) * math.sqrt(2)
        docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
                                                        score_threshold=score_threshold,
                                                        doc_callback=self.doc_callback, **kwargs)
        return [doc for doc, similarity in docs][:self.show_number]

    def get_text_similarity(self, text: str, **kwargs):
        docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback, **kwargs)
        return docs[:self.show_number]

    # # deduplicate while keeping metadata
...
@@ -199,22 +222,25 @@ class VectorStore_FAISS(FAISS):
    #     deduplicated_documents = [Document(page_content=key,metadata=value) for key, value in deduplicated_dict.items()]
    #     return deduplicated_documents

-    def _join_document(self, docs: List[Document]) -> str:
+    @staticmethod
+    def _join_document(docs: List[Document]) -> str:
        print(docs)
        return "".join([doc.page_content for doc in docs])

-    def get_local_doc(self, docs: List[Document]):
+    @staticmethod
+    def get_local_doc(docs: List[Document]):
        ans = []
        for doc in docs:
            ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
                        "filename": doc.metadata["filename"]})
        return ans

    # def _join_document_location(self, docs:List[Document]) -> str:

    # persist to local disk
    def _save_local(self):
        self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)

    # add documents
    # Document {
    #     page_content  paragraph
...
@@ -222,10 +248,10 @@ class VectorStore_FAISS(FAISS):
    #         page  page number
    #     }
    # }
    def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
        list_of_documents: List[Document] = []
        if self.doc_callback:
            new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
        if need_split:
            for doc in new_docs:
                words_list = re.split(pattern, doc.page_content)
...
@@ -240,8 +266,14 @@ class VectorStore_FAISS(FAISS):
        else:
            list_of_documents = new_docs
        self._faiss.add_documents(list_of_documents)

-    def _add_documents_from_dir(self, filepaths=[], load_kwargs: Optional[dict] = {"mode": "paged"}):
-        self._add_documents(load.loads(filepaths, **load_kwargs))
+    def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
+        if load_kwargs is None:
+            load_kwargs = {"mode": "paged"}
+        if filepaths is None:
+            filepaths = []
+        self._add_documents(load.loads(filepaths, **load_kwargs))

    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """
        Return VectorStoreRetriever initialized from this VectorStore.
...
@@ -303,11 +335,11 @@ class VectorStore_FAISS(FAISS):
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        elif "similarity_score_threshold" == kwargs["search_type"]:
            default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._faiss._get_retriever_tags())
        print(kwargs)
...
@@ -316,20 +348,21 @@ class VectorStore_FAISS(FAISS):


class VectorStoreRetriever_FAISS(VectorStoreRetriever):
    search_k = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "k" in self.search_kwargs:
            self.search_k = self.search_kwargs["k"]
        self.search_kwargs["k"] = self.search_k * 2

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]

    async def _aget_relevant_documents(
            self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._aget_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]
\ No newline at end of file
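A usage sketch tying the pieces above together, modelled on test/k_store_test.py later in this commit; the model path, store path, and query text are placeholders, and Document is the langchain.schema import already present in this file.

    vs = VectorStore_FAISS(
        embedding_model_name="/models/text2vec-base-chinese",
        store_path="./faiss_store",
        index_name="index",
        info={"host": "localhost", "dbname": "lae", "username": "postgres", "password": "***", "port": "5432"},
        is_pgsql=True,
        show_number=5,
        reset=False,
    )
    vs._add_documents([Document(page_content="征信业务管理办法……", metadata={"page_number": 1, "filename": "demo.pdf"})])
    vs._save_local()
    print(vs._join_document(vs.get_text_similarity("征信业务有什么情况")))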
src/pgdb/knowledge/txt_doc_table.py  View file @ 493cdd59
from .k_db import PostgresDB

# paragraph_id BIGSERIAL primary key,
TABLE_TXT_DOC = """
create table txt_doc (
    hash varchar(40) primary key,
...
@@ -11,6 +12,8 @@ TABLE_TXT_DOC_HASH_INDEX = """
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
"""


# CREATE UNIQUE INDEX idx_name ON your_table (column_name);
class TxtDoc:
    def __init__(self, db: PostgresDB) -> None:
...
@@ -21,19 +24,20 @@ class TxtDoc:
        args = []
        for value in texts:
            value = list(value)
            query += "(%s,%s,%s),"
            args.extend(value)
        query = query[:len(query) - 1]
        query += f"ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
        self.db.execute_args(query, args)

    def delete(self, ids):
-        for id in ids:
-            query = f"delete FROM txt_doc WHERE hash = %s" % (id)
+        for item in ids:
+            query = f"delete FROM txt_doc WHERE hash = %s" % item
            self.db.execute(query)

-    def search(self, id):
+    def search(self, item):
        query = "SELECT text,matadate FROM txt_doc WHERE hash = %s"
-        self.db.execute_args(query, [id])
+        self.db.execute_args(query, [item])
        answer = self.db.fetchall()
        if len(answer) > 0:
            return answer[0]
...
@@ -60,4 +64,3 @@ class TxtDoc:
        query = "DROP TABLE txt_doc"
        self.db.format(query)
        print("drop table txt_doc ok")
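For reference, with two rows the insert() above ends up executing one multi-VALUES upsert roughly like the comment below; the leading INSERT fragment and column list sit in the elided part of the hunk, so they are assumed from the SELECT in search(), and the sample tuples follow the (hash, text, metadata-json) order that pgsqldocstore.py passes in.

    # Roughly:
    #   INSERT INTO txt_doc(hash, text, matadate) VALUES (%s,%s,%s),(%s,%s,%s)
    #   ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;
    # (the trailing comma of the last VALUES group is stripped by query[:len(query) - 1])
    txt_doc.insert([("hash-a", "第一段文本", "{}"), ("hash-b", "第二段文本", "{}")])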
src/pgdb/knowledge/vec_txt_table.py  View file @ 493cdd59
from .k_db import PostgresDB

TABLE_VEC_TXT = """
CREATE TABLE vec_txt (
    vector_id varchar(36) PRIMARY KEY,
...
@@ -6,7 +7,9 @@ CREATE TABLE vec_txt (
    paragraph_id varchar(40) not null
)
"""


# 025a9bee-2eb2-47f5-9722-525e05a0442b
class TxtVector:
    def __init__(self, db: PostgresDB) -> None:
        self.db = db
...
@@ -16,19 +19,21 @@ class TxtVector:
        args = []
        for value in vectors:
            value = list(value)
            query += "(%s,%s,%s),"
            args.extend(value)
        query = query[:len(query) - 1]
        query += f"ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text,paragraph_id = EXCLUDED.paragraph_id;"
        # query += ";"
        self.db.execute_args(query, args)

    def delete(self, ids):
-        for id in ids:
-            query = f"delete FROM vec_txt WHERE vector_id = '%s'" % (id,)
+        for item in ids:
+            query = f"delete FROM vec_txt WHERE vector_id = '%s'" % (item,)
            self.db.execute(query)

    def search(self, search: str):
        query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
-        self.db.execute_args(query,[search])
+        self.db.execute_args(query, [search])
        answer = self.db.fetchall()
        print(answer)
        return answer[0]
...
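One design note: delete() above still interpolates the id into the SQL string with %, while search() passes it as a bound parameter. A possible variant (not what this commit does) would route delete through execute_args as well, letting psycopg2 handle quoting:

    def delete(self, ids):
        # same effect, but with the id bound as a parameter instead of string interpolation
        for item in ids:
            self.db.execute_args("DELETE FROM vec_txt WHERE vector_id = %s", [item])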
test/chat_table_test.py  View file @ 493cdd59
import sys

sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa

"""Tests the connection to the chat-related database tables."""
(the commit wraps the previously module-level statements in a test() entry point)


def test():
    c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
    chat = Chat(db=c_db)
    c_user = CUser(db=c_db)
    turn_qa = TurnQa(db=c_db)
    chat.create_table()
    c_user.create_table()
    turn_qa.create_table()
    # chat_id, user_id, info, deleted
    chat.insert(["3333", "1111", "没有info", 0])
    # user_id, account, password
    c_user.insert(["111", "zhangsan", "111111"])
    # turn_id, chat_id, question, answer, turn_number, is_last
    turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])


if __name__ == "main":
    test()
\ No newline at end of file
test/k_store_test.py  View file @ 493cdd59
import sys

sys.path.append("../")
import time
-from src.loader.load import loads_path
+from src.loader.load import loads_path, loads
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    VEC_DB_DBNAME,
...
@@ -18,24 +18,27 @@ from src.config.consts import (
from src.loader.callback import BaseCallback


# When the returned text contains "思考题" (study questions), ignore it by default.
class localCallback(BaseCallback):
    def filter(self, title: str, content: str) -> bool:
        if len(title + content) == 0:
            return True
        return (len(title + content) / (len(title.splitlines()) + len(content.splitlines())) < 20) or "思考题" in title


"""Tests ingesting material into storage (pgsql and faiss)."""


def test_faiss_from_dir():
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=True)
    docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[localCallback()])
    print(len(docs))
    last_doc = None
    docs1 = []
...
@@ -45,7 +48,8 @@ def test_faiss_from_dir():
            continue
        if "font-size" not in doc.metadata or "page_number" not in doc.metadata:
            continue
        if doc.metadata["font-size"] == last_doc.metadata["font-size"] and doc.metadata["page_number"] == \
                last_doc.metadata["page_number"] and len(doc.page_content) + len(last_doc.page_content) < 512 / 4 * 3:
            last_doc.page_content += doc.page_content
        else:
            docs1.append(last_doc)
...
@@ -56,17 +60,21 @@ def test_faiss_from_dir():
    print(len(docs))
    print(vecstore_faiss._faiss.index.ntotal)
    for i in range(0, len(docs), 300):
        vecstore_faiss._add_documents(docs[i:i + 300 if i + 300 < len(docs) else len(docs)], need_split=True)
    print(vecstore_faiss._faiss.index.ntotal)
    vecstore_faiss._save_local()


"""Tests querying the faiss vector store."""


def test_faiss_load():
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("征信业务有什么情况")))
...
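A small aside on the ingestion loop above: the conditional slice end is redundant, because Python slicing already clamps at the sequence length, so the batching step can be written as below with identical behavior.

    for i in range(0, len(docs), 300):
        vecstore_faiss._add_documents(docs[i:i + 300], need_split=True)   # slicing clamps at len(docs)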