文靖昊 / LAE · Commits

Commit 493cdd59, authored Apr 25, 2024 by 陈正乐
Parent: 9d8ee0af
Commit message: 代码格式化 (code formatting)

Showing 28 changed files with 654 additions and 572 deletions (+654 -572)
Changed files:

src/config/consts.py                  +9   -10
src/llm/__init__.py                   +1   -2
src/llm/baichuan.py                   +13  -22
src/llm/chatglm.py                    +57  -55
src/llm/chatglm_openapi.py            +14  -14
src/llm/ernie.py                      +24  -31
src/llm/ernie_sdk.py                  +14  -8
src/llm/ernie_with_sdk.py             +28  -29
src/llm/loader.py                     +14  -13
src/llm/spark.py                      +25  -32
src/llm/wrapper.py                    +8   -13
src/loader/callback.py                +3   -3
src/loader/chinese_text_splitter.py   +3   -3
src/loader/config.py                  +1   -2
src/loader/load.py                    +131 -96
src/loader/zh_title_enhance.py        +1   -1
src/pgdb/chat/c_db.py                 +18  -17
src/pgdb/chat/c_user_table.py         +4   -3
src/pgdb/chat/chat_table.py           +4   -3
src/pgdb/chat/turn_qa_table.py        +4   -3
src/pgdb/knowledge/callback.py        +22  -19
src/pgdb/knowledge/k_db.py            +20  -17
src/pgdb/knowledge/pgsqldocstore.py   +42  -34
src/pgdb/knowledge/similarity.py      +114 -82
src/pgdb/knowledge/txt_doc_table.py   +14  -11
src/pgdb/knowledge/vec_txt_table.py   +14  -10
test/chat_table_test.py               +21  -15
test/k_store_test.py                  +31  -24
src/config/consts.py

@@ -2,19 +2,19 @@
# =============================
# 资料存储数据库配置
# =============================
VEC_DB_HOST = 'localhost'
VEC_DB_DBNAME = 'lae'
VEC_DB_USER = 'postgres'
VEC_DB_PASSWORD = 'chenzl'
VEC_DB_PORT = '5432'

# =============================
# 聊天相关数据库配置
# =============================
CHAT_DB_HOST = 'localhost'
CHAT_DB_DBNAME = 'laechat'
CHAT_DB_USER = 'postgres'
CHAT_DB_PASSWORD = 'chenzl'
CHAT_DB_PORT = '5432'

# =============================
# 向量化模型路径配置

@@ -41,4 +41,4 @@ INDEX_NAME = 'know'
# =============================
# 知识相关资料配置
# =============================
KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\work\\llm_gjjs\\兴火燎原知识库\\兴火燎原知识库\\law\\pdf'
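For orientation, a minimal sketch (my addition, not part of this commit) of feeding the chat-database constants into psycopg2, the driver used by the pgdb helpers later in this diff. Import paths assume the src package layout shown above.

# Hypothetical sketch: open a connection from the CHAT_DB_* constants.
import psycopg2

from src.config.consts import (
    CHAT_DB_HOST, CHAT_DB_DBNAME, CHAT_DB_USER, CHAT_DB_PASSWORD, CHAT_DB_PORT,
)

conn = psycopg2.connect(
    host=CHAT_DB_HOST,
    database=CHAT_DB_DBNAME,
    user=CHAT_DB_USER,
    password=CHAT_DB_PASSWORD,
    port=CHAT_DB_PORT,
)
with conn.cursor() as cur:
    cur.execute("SELECT 1")   # sanity check that the settings are valid
    print(cur.fetchone())
conn.close()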
src/llm/__init__.py

"""各种大模型提供的服务"""
src/llm/baichuan.py

import os
from typing import Dict, Optional, List
from langchain.llms.base import BaseLLM, LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
from pydantic import root_validator


class BaichuanLLM(LLM):
    model_name: str = "baichuan-inc/Baichuan-13B-Chat"
    quantization_bit: Optional[int] = None
    tokenizer: AutoTokenizer = None
    model: AutoModel = None

    def _llm_type(self) -> str:
        return "chatglm_local"

-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(self, values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")
        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, ...

@@ -43,7 +40,7 @@ class BaichuanLLM(LLM):
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"]).cuda()
        else:
            model = model.half().cuda()
        model = model.eval()

@@ -51,14 +48,9 @@ class BaichuanLLM(LLM):
        values["model"] = model
        return values

-    def _call(self, prompt: str,
-              stop: Optional[List[str]] = None,
-              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
-        message = []
-        message.append({"role": "user", "content": prompt})
-        resp = self.model.chat(self.tokenizer, message)
+    def _call(self, prompt: str,
+              stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None,
+              **kwargs) -> str:
+        message = [{"role": "user", "content": prompt}]
+        resp = self.model.chat(self.tokenizer, message)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp
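A hedged usage sketch (my addition): driving BaichuanLLM through the LangChain LLM call interface. With @root_validator removed above, validate_environment is no longer invoked automatically, so this assumes model and tokenizer have been populated before the call.

# Hypothetical sketch: requires the Baichuan weights locally and a CUDA device.
from src.llm.baichuan import BaichuanLLM

llm = BaichuanLLM(model_name="baichuan-inc/Baichuan-13B-Chat", quantization_bit=8)
# _call() wraps the prompt as [{"role": "user", "content": prompt}] and calls model.chat(tokenizer, ...)
print(llm("用一句话介绍你自己"))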
src/llm/chatglm.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import aiohttp
import asyncio

# 启动llm的缓存
# langchain.llm_cache = InMemoryCache()

@@ -26,17 +27,17 @@ class ChatGLMLocLLM(LLM):
    tokenizer: AutoTokenizer = None
    model: AutoModel = None

    def _llm_type(self) -> str:
        return "chatglm_local"

    # @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    @staticmethod
+    def validate_environment(values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")
        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        # model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        if values["pre_seq_len"]:

@@ -56,7 +57,7 @@ class ChatGLMLocLLM(LLM):
            model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()
        if values["pre_seq_len"]:
            # P-tuning v2
            model = model.half().cuda()

@@ -64,7 +65,7 @@ class ChatGLMLocLLM(LLM):
        if values["quantization_bit"]:
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"])
        model = model.eval()

@@ -72,30 +73,26 @@ class ChatGLMLocLLM(LLM):
        values["model"] = model
        return values

-    def _call(self, prompt: str,
-              stop: Optional[List[str]] = None,
-              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
-        resp, his = self.model.chat(self.tokenizer, prompt)
+    def _call(self, prompt: str,
+              stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None,
+              **kwargs) -> str:
+        resp, his = self.model.chat(self.tokenizer, prompt)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp


class ChatGLMSerLLM(LLM):
    # 模型服务url
    url: str = "http://127.0.0.1:8000"
    chat_history: dict = []
    out_stream: bool = False
    cache: bool = False

    @property
    def _llm_type(self) -> str:
        return "chatglm3-6b"

    def get_num_tokens(self, text: str) -> int:
        resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
        if resp.status_code == 200:
            resp_json = resp.json()
            predictions = resp_json['response']

@@ -103,28 +100,29 @@ class ChatGLMSerLLM(LLM):
            return predictions
        else:
            return len(text)

-    def convert_data(self, data):
+    @staticmethod
+    def convert_data(data):
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
        """构造请求体
        """
        # self.chat_history.append({"role": "user", "content": prompt})
        query = {
            "prompt": prompt,
            "history": self.chat_history,
            "max_length": 4096,
            "top_p": 0.7,
            "temperature": temperature
        }
        return query

    @classmethod
-    def _post(self, url: str,
+    def _post(cls, url: str,
              query: Dict) -> Any:
        """POST请求
        """

@@ -135,51 +133,55 @@ class ChatGLMSerLLM(LLM):
                             headers=_headers,
                             timeout=300)
        return resp

-    async def _post_stream(self, url: str,
-                           query: Dict,
-                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-                           stream=False) -> Any:
+    @staticmethod
+    async def _post_stream(url: str,
+                           query: Dict,
+                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+                           stream=False) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}
        async with aiohttp.ClientSession() as sess:
            async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
                if response.status == 200:
                    if stream and not run_manager:
                        print('not callable')
                    if run_manager:
-                        for callable in run_manager.get_sync().handlers:
-                            await callable.on_llm_start(None, None)
+                        for _callable in run_manager.get_sync().handlers:
+                            await _callable.on_llm_start(None, None)
                    async for chunk in response.content.iter_any():
                        # 处理每个块的数据
                        if chunk and run_manager:
-                            for callable in run_manager.get_sync().handlers:
-                                # print(chunk.decode("utf-8"),end="")
-                                await callable.on_llm_new_token(chunk.decode("utf-8"))
+                            for _callable in run_manager.get_sync().handlers:
+                                # print(chunk.decode("utf-8"),end="")
+                                await _callable.on_llm_new_token(chunk.decode("utf-8"))
                    if run_manager:
-                        for callable in run_manager.get_sync().handlers:
-                            await callable.on_llm_end(None)
+                        for _callable in run_manager.get_sync().handlers:
+                            await _callable.on_llm_end(None)
                else:
                    raise ValueError(f'glm 请求异常,http code:{response.status}')

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
        # display("==============================")
        # display(query)
        # post
        if stream or self.out_stream:
            async def _post_stream():
                await self._post_stream(url=self.url + "/stream",
                                        query=query,
                                        run_manager=run_manager,
                                        stream=stream or self.out_stream)
            asyncio.run(_post_stream())
            return ''
        else:
            resp = self._post(url=self.url,
                              query=query)
            if resp.status_code == 200:
                resp_json = resp.json()

@@ -189,18 +191,19 @@ class ChatGLMSerLLM(LLM):
                return predictions
        else:
            raise ValueError(f'glm 请求异常,http code:{resp.status_code}')

    async def _acall(self,
                     prompt: str,
                     stop: Optional[List[str]] = None,
                     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                     **kwargs: Any,
                     ) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
        await self._post_stream(url=self.url + "/stream",
                                query=query,
                                run_manager=run_manager,
                                stream=self.out_stream)
        return ''

    @property
    def _identifying_params(self) -> Mapping[str, Any]:

@@ -209,4 +212,4 @@ class ChatGLMSerLLM(LLM):
        _param_dict = {
            "url": self.url
        }
        return _param_dict
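For reference, a small sketch (my addition, not part of the diff) of the request body that _construct_query assembles and how the remote ChatGLM service could be exercised directly with requests. The response field name follows the /tokens handler above; adjust if the service differs.

# Hypothetical sketch: call the ChatGLM HTTP service the same way ChatGLMSerLLM does.
import requests

url = "http://127.0.0.1:8000"           # default service url from the class above
query = {
    "prompt": "你好",
    "history": [],                       # ChatGLMSerLLM keeps this in self.chat_history
    "max_length": 4096,
    "top_p": 0.7,
    "temperature": 0.95,
}
resp = requests.post(url, json=query, headers={"Content_Type": "application/json"}, timeout=300)
if resp.status_code == 200:
    print(resp.json()["response"])       # same field get_num_tokens reads from /tokens
else:
    raise ValueError(f"glm 请求异常,http code:{resp.status_code}")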
src/llm/chatglm_openapi.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain.llms.base import BaseLLM, LLM
from langchain_openai import OpenAI
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun


class ChatGLMSerLLM(OpenAI):
    def get_token_ids(self, text: str) -> List[int]:
        if self.model_name.__contains__("chatglm"):
            ## 发起http请求,获取token_ids
            url = f"{self.openai_api_base}/num_tokens"
            query = {"prompt": text, "model": self.model_name}
            _headers = {"Content_Type": "application/json",
                        "Authorization": "chatglm " + self.openai_api_key}
            resp = self._post(url=url, query=query, headers=_headers)
            if resp.status_code == 200:
                resp_json = resp.json()
                print(resp_json)

@@ -30,10 +31,10 @@ class ChatGLMSerLLM(OpenAI):
                ## predictions字符串转int
                return [int(predictions)]
        return [len(text)]

    @classmethod
-    def _post(self, url: str,
+    def _post(cls, url: str,
              query: Dict, headers: Dict) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}

@@ -43,4 +44,4 @@ class ChatGLMSerLLM(OpenAI):
                             json=query,
                             headers=_headers,
                             timeout=300)
        return resp
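A hedged sketch (my addition) of constructing this OpenAI-compatible wrapper; the base URL, key, and model value are placeholders and mirror the openai_api_base / openai_api_key / model_name fields used in get_token_ids above.

# Hypothetical sketch: point the wrapper at an OpenAI-compatible ChatGLM endpoint.
from src.llm.chatglm_openapi import ChatGLMSerLLM

llm = ChatGLMSerLLM(
    model="chatglm3-6b",
    openai_api_base="http://127.0.0.1:8000/v1",
    openai_api_key="EMPTY",
)
print(llm.get_token_ids("你好"))   # for chatglm models this POSTs to {openai_api_base}/num_tokens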
src/llm/ernie.py

@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks

@@ -15,6 +15,7 @@ from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_m
logger = logging.getLogger(__name__)


class ModelType(Enum):
    ERNIE = "ernie"
    ERNIE_LITE = "ernie-lite"

@@ -25,7 +26,7 @@ class ModelType(Enum):
    LLAMA2_13B = "llama2-13b"
    LLAMA2_70B = "llama2-70b"
    QFCN_LLAMA2_7B = "qfcn-llama2-7b"
    BLOOMZ_7B = "bloomz-7b"


MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"

@@ -43,6 +44,7 @@ MODEL_SERVICE_Suffix = {
    ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}


class ErnieLLM(LLM):
    """
    ErnieLLM is a LLM that uses Ernie to generate text.

@@ -52,27 +54,23 @@ class ErnieLLM(LLM):
    access_token: Optional[str] = ""

-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
        model_name = ModelType(get_from_dict_or_env(values, "model_name", "model_name", str(ModelType.ERNIE)))
        access_token = get_from_dict_or_env(values, "access_token", "ERNIE_ACCESS_TOKEN", "")
        if not access_token:
            raise ValueError("No access token provided.")
        values["model_name"] = model_name
        values["access_token"] = access_token
        return values

-    def _call(self, prompt: str,
-              stop: Optional[List[str]] = None,
-              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
+    def _call(self, prompt: str,
+              stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None,
+              **kwargs) -> str:
        request = CompletionRequest(messages=[Message("user", prompt)])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
        try:
            # 你的代码

@@ -81,9 +79,8 @@ class ErnieLLM(LLM):
            return response
        except Exception as e:
            # 处理异常
            print("exception:", e)
            return e.__str__()

    @property
    def _llm_type(self) -> str:

@@ -94,28 +91,25 @@ class ErnieLLM(LLM):
    # return {
    #     "name": "ernie",
    # }


def _get_model_service_url(model_name) -> str:
    # print("_get_model_service_url model_name: ",model_name)
    return MODEL_SERVICE_BASE_URL + MODEL_SERVICE_Suffix[model_name]


class ErnieChat(LLM):
    model_name: ModelType
    access_token: str
    prefix_messages: List = Field(default_factory=list)
    id: str = ""

-    def _call(self, prompt: str,
-              stop: Optional[List[str]] = None,
-              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
+    def _call(self, prompt: str,
+              stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None,
+              **kwargs) -> str:
        msg = user_message(prompt)
        request = CompletionRequest(messages=self.prefix_messages + [msg])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
        try:
            # 你的代码
            response = bot.get_response().result

@@ -127,11 +121,11 @@ class ErnieChat(LLM):
        except Exception as e:
            # 处理异常
            raise e

    def _get_id(self) -> str:
        return self.id

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ernie"
src/llm/ernie_sdk.py

from dataclasses import asdict, dataclass
from typing import List

@@ -7,27 +5,32 @@ from pydantic import BaseModel, Field
from enum import Enum


class MessageRole(str, Enum):
    USER = "user"
    BOT = "assistant"


@dataclass
class Message:
    role: str
    content: str


@dataclass
class CompletionRequest:
    messages: List[Message]
    stream: bool = False
    user: str = ""


@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class CompletionResponse:
    id: str

@@ -42,12 +45,14 @@ class CompletionResponse:
    is_safe: bool = False
    is_truncated: bool = False


class ErrorResponse(BaseModel):
    error_code: int = Field(...)
    error_msg: str = Field(...)
    id: str = Field(...)


-class ErnieBot():
+class ErnieBot:
    url: str
    access_token: str
    request: CompletionRequest

@@ -64,17 +69,19 @@ class ErnieBot():
        headers = {'Content-Type': 'application/json'}
        params = {'access_token': self.access_token}
        request_dict = asdict(self.request)
        response = requests.post(self.url, params=params, data=json.dumps(request_dict), headers=headers)
        # print(response.json())
        try:
            return CompletionResponse(**response.json())
        except Exception as e:
            print(e)
            raise Exception(response.json())


def user_message(prompt: str) -> Message:
    return Message(MessageRole.USER, prompt)


def bot_message(prompt: str) -> Message:
    return Message(MessageRole.BOT, prompt)
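A hedged sketch (my addition): building a request with the dataclasses above and sending it. The service URL and token are placeholders; ernie.py obtains the URL through _get_model_service_url and calls get_response() the same way.

# Hypothetical sketch: construct and send a single-turn completion request.
from src.llm.ernie_sdk import CompletionRequest, ErnieBot, user_message

request = CompletionRequest(messages=[user_message("你好")])
bot = ErnieBot("<model-service-url>", "<access-token>", request)
response = bot.get_response()        # returns a CompletionResponse on success
print(response.result)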
src/llm/ernie_with_sdk.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import qianfan
from qianfan import ChatCompletion

# 启动llm的缓存
# langchain.llm_cache = InMemoryCache()


class ChatERNIESerLLM(LLM):
    # 模型服务url
    chat_completion: ChatCompletion = None
    # url: str = "http://127.0.0.1:8000"
    chat_history: dict = []
    out_stream: bool = False
    cache: bool = False
    model_name: str = "ERNIE-Bot"

    # def __init__(self):
    #     self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")

    @property
    def _llm_type(self) -> str:
        return self.model_name

    def get_num_tokens(self, text: str) -> int:
        return len(text)

-    def convert_data(self, data):
+    @staticmethod
+    def convert_data(data):
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        resp = self.chat_completion.do(model=self.model_name, messages=[{
            "role": "user", "content": prompt
        }])

@@ -54,26 +54,26 @@ class ChatERNIESerLLM(LLM):
        return resp.body["result"]

    async def _post_stream(self,
                           query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                           stream=False) -> Any:
        """POST请求
        """
        async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
            assert r.code == 200
            if run_manager:
-                for callable in run_manager.get_sync().handlers:
-                    await callable.on_llm_new_token(r.body["result"])
+                for _callable in run_manager.get_sync().handlers:
+                    await _callable.on_llm_new_token(r.body["result"])

    async def _acall(self,
                     prompt: str,
                     stop: Optional[List[str]] = None,
                     run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                     **kwargs: Any,
                     ) -> str:
        await self._post_stream(query={"role": "user", "content": prompt}, stream=True, run_manager=run_manager)
        return ''
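A hedged usage sketch (my addition): the commented-out __init__ above shows the same qianfan.ChatCompletion construction; the ak/sk values here are placeholders.

# Hypothetical sketch: instantiate the qianfan-backed wrapper and run one call.
import qianfan
from src.llm.ernie_with_sdk import ChatERNIESerLLM

chat = qianfan.ChatCompletion(ak="<your-ak>", sk="<your-sk>")
llm = ChatERNIESerLLM(chat_completion=chat, model_name="ERNIE-Bot")
print(llm("你好"))   # _call -> chat_completion.do(...) and returns resp.body["result"]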
src/llm/loader.py

@@ -4,6 +4,7 @@ import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
from peft import PeftModel


class ModelLoader:
    def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
        self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

@@ -23,18 +24,19 @@ class ModelLoader:
    def models(self):
        return self.model, self.tokenizer

    def collator(self):
        return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)

-    def load_lora(self, ckpt_path, name="default"):
-        # 训练时节约GPU占用
-        peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
-        self.model = peft_loaded.merge_and_unload()
+    def load_lora(self, ckpt_path, name="default"):
+        # 训练时节约GPU占用
+        _peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
+        self.model = _peft_loaded.merge_and_unload()
        print(f"Load LoRA model successfully!")

-    def load_loras(self, ckpt_paths, name="default"):
-        if len(ckpt_paths) == 0:
+    def load_loras(self, ckpt_paths, name="default"):
+        global peft_loaded
+        if len(ckpt_paths) == 0:
            return
        first = True
        for name, path in ckpt_paths.items():

@@ -42,12 +44,12 @@ class ModelLoader:
            if first:
                peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=name)
                first = False
            else:
                peft_loaded.load_adapter(path, adapter_name=name)
        peft_loaded.set_adapter(name)
        self.model = peft_loaded

    def load_prefix(self, ckpt_path):
        prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():

@@ -56,4 +58,3 @@ class ModelLoader:
        self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        self.model.transformer.prefix_encoder.float()
        print(f"Load prefix model successfully!")
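A hedged sketch (my addition) of the loader's intended flow: load a base checkpoint, merge a LoRA adapter, then retrieve the model and tokenizer. The model name and adapter path are placeholders.

# Hypothetical sketch: base model plus a merged LoRA adapter.
from src.llm.loader import ModelLoader

loader = ModelLoader("THUDM/chatglm3-6b")          # pre_seq_len=0, prefix_projection=False by default
loader.load_lora("./output/lora-checkpoint")        # merges the adapter into the base model
model, tokenizer = loader.models()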
src/llm/spark.py

@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks

@@ -16,18 +16,17 @@ from .xinghuo.ws import SparkAPI
logger = logging.getLogger(__name__)

text = []


# length = 0
-def getText(role, content):
-    jsoncon = {}
-    jsoncon["role"] = role
-    jsoncon["content"] = content
+def getText(role, content):
+    jsoncon = {"role": role, "content": content}
    text.append(jsoncon)
    return text


def getlength(text):
    length = 0
    for content in text:

@@ -36,11 +35,12 @@ def getlength(text):
        length += leng
    return length


-def checklen(text):
-    while (getlength(text) > 8000):
-        del text[0]
-    return text
+def checklen(_text):
+    while getlength(_text) > 8000:
+        del _text[0]
+    return _text


class SparkLLM(LLM):
    """

@@ -62,16 +62,16 @@ class SparkLLM(LLM):
        None,
        description="version",
    )
    api: SparkAPI = Field(
        None,
        description="api",
    )

-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
        appid = get_from_dict_or_env(values, "appid", "XH_APPID", "")
        api_key = get_from_dict_or_env(values, "api_key", "XH_API_KEY", "")
        api_secret = get_from_dict_or_env(values, "api_secret", "XH_API_SECRET", "")

@@ -84,23 +84,19 @@ class SparkLLM(LLM):
            raise ValueError("No api_key provided.")
        if not api_secret:
            raise ValueError("No api_secret provided.")
        values["appid"] = appid
        values["api_key"] = api_key
        values["api_secret"] = api_secret
        api = SparkAPI(appid, api_key, api_secret, version)
        values["api"] = api
        return values

-    def _call(self, prompt: str,
-              stop: Optional[List[str]] = None,
-              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
+    def _call(self, prompt: str,
+              stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None,
+              **kwargs) -> str:
        question = self.getText("user", prompt)
        try:
            # 你的代码
            # SparkApi.main(self.appid,self.api_key,self.api_secret,self.Spark_url,self.domain,question)

@@ -109,20 +105,18 @@ class SparkLLM(LLM):
            return response
        except Exception as e:
            # 处理异常
            print("exception:", e)
            raise e

    def getText(self, role, content):
        text = []
        jsoncon = {}
        jsoncon["role"] = role
        jsoncon["content"] = content
        text.append(jsoncon)
        return text

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "xinghuo"
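A hedged sketch (my addition): SparkLLM's validator reads appid, api_key, and api_secret from the supplied values or from the XH_APPID / XH_API_KEY / XH_API_SECRET environment variables. As with BaichuanLLM above, the decorator removal means validate_environment may need to be wired up before the fields are populated.

# Hypothetical sketch: credential values are placeholders.
import os
from src.llm.spark import SparkLLM

os.environ["XH_APPID"] = "<appid>"
os.environ["XH_API_KEY"] = "<api-key>"
os.environ["XH_API_SECRET"] = "<api-secret>"
llm = SparkLLM()
print(llm("你好"))   # _call wraps the prompt with getText("user", prompt) and queries SparkAPI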
src/llm/wrapper.py

from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from typing import Dict, List, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer


class WrapperLLM(LLM):
    tokenizer: PreTrainedTokenizer = None
    model: PreTrainedModel = None

-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
        if values.get("model") is None:

@@ -19,16 +19,12 @@ class WrapperLLM(LLM):
            raise ValueError("No tokenizer provided.")
        return values

-    def _call(self, prompt: str,
-              stop: Optional[List[str]] = None,
-              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
+    def _call(self, prompt: str,
+              stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None,
+              **kwargs) -> str:
        resp, his = self.model.chat(self.tokenizer, prompt)
        return resp

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "wrapper"
src/loader/callback.py

from abc import ABC, abstractmethod


class BaseCallback(ABC):
    @abstractmethod
    def filter(self, title: str, content: str) -> bool:
        # return True舍弃当前段落
        pass
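A hedged example (my addition) of how this hook could be implemented; per the comment above, returning True discards the current paragraph. The filtering rule itself is illustrative.

# Hypothetical sketch: drop very short or table-of-contents style paragraphs.
from src.loader.callback import BaseCallback


class DropShortParagraphs(BaseCallback):
    def filter(self, title: str, content: str) -> bool:
        # True means "discard this paragraph" (see BaseCallback).
        return len(content.strip()) < 10 or title.startswith("目录")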
src/loader/chinese_text_splitter.py

@@ -25,7 +25,7 @@ class ChineseTextSplitter(CharacterTextSplitter):
            sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:   ##此处需要进一步优化逻辑
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub('\s', " ", text)

@@ -56,6 +56,6 @@ class ChineseTextSplitter(CharacterTextSplitter):
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
-            id = ls.index(ele)
-            ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
+            _id = ls.index(ele)
+            ls = ls[:_id] + [i for i in ele1_ls if i] + ls[_id + 1:]
        return ls
src/loader/config.py

# 文本分句长度
SENTENCE_SIZE = 100

ZH_TITLE_ENHANCE = False
src/loader/load.py

import os, copy
from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader, UnstructuredPDFLoader, \
    UnstructuredWordDocumentLoader, PDFMinerPDFasHTMLLoader
from .config import SENTENCE_SIZE, ZH_TITLE_ENHANCE
from .chinese_text_splitter import ChineseTextSplitter
from .zh_title_enhance import zh_title_enhance
from langchain.schema import Document
from typing import List, Dict, Optional
from src.loader.callback import BaseCallback
import re
from bs4 import BeautifulSoup


def load(filepath, mode: str = None, sentence_size: int = 0, metadata=None, callbacks=None, **kwargs):
    r"""
    加载文档,参数说明
    mode:文档切割方式,"single", "elements", "paged"

@@ -19,37 +21,44 @@ def load(filepath, mode: str = None, sentence_size: int = 0, metadata=None, callbacks=None, **kwargs):
    kwargs
    """
    if filepath.lower().endswith(".md"):
        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
    elif filepath.lower().endswith(".txt"):
        loader = TextLoader(filepath, autodetect_encoding=True, **kwargs)
    elif filepath.lower().endswith(".csv"):
        loader = CSVLoader(filepath, **kwargs)
    elif filepath.lower().endswith(".pdf"):
        # loader = UnstructuredPDFLoader(filepath, mode=mode or "elements",**kwargs)
        # 使用自定义pdf loader
        return __pdf_loader(filepath, sentence_size=sentence_size, metadata=metadata, callbacks=callbacks)
    elif filepath.lower().endswith(".docx") or filepath.lower().endswith(".doc"):
        loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements", **kwargs)
    else:
        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
    if sentence_size > 0:
        return split(loader.load(), sentence_size)
    return loader.load()


def loads_path(path: str, **kwargs):
    return loads(get_files_in_directory(path), **kwargs)


def loads(filepaths, **kwargs):
    default_kwargs = {"mode": "paged"}
    default_kwargs.update(**kwargs)
    documents = [load(filepath=file, **default_kwargs) for file in filepaths]
    return [item for sublist in documents for item in sublist]


-def append(documents: List[Document] = [], sentence_size: int = SENTENCE_SIZE):  # 保留文档结构信息,注意处理hash
+def append(documents=None, sentence_size: int = SENTENCE_SIZE):
+    # 保留文档结构信息,注意处理hash
+    if documents is None:
+        documents = []
    effect_documents = []
    last_doc = documents[0]
    for doc in documents[1:]:
        last_hash = "" if "next_hash" not in last_doc.metadata else last_doc.metadata["next_hash"]
        doc_hash = "" if "next_hash" not in doc.metadata else doc.metadata["next_hash"]
        if len(last_doc.page_content) + len(doc.page_content) <= sentence_size and last_hash == doc_hash:
            last_doc.page_content = last_doc.page_content + doc.page_content
            continue
        else:

@@ -58,28 +67,31 @@ def append(documents=None, sentence_size: int = SENTENCE_SIZE):
            effect_documents.append(last_doc)
    return effect_documents


-def split(documents: List[Document] = [], sentence_size: int = SENTENCE_SIZE):  # 保留文档结构信息,注意处理hash
+def split(documents=None, sentence_size: int = SENTENCE_SIZE):
+    # 保留文档结构信息,注意处理hash
+    if documents is None:
+        documents = []
    effect_documents = []
    for doc in documents:
        if len(doc.page_content) > sentence_size:
            words_list = re.split(r'·-·', doc.page_content.replace("。", "。·-·").replace("\n", "\n·-·"))  # 插入分隔符,分割
            document = Document(page_content="", metadata=copy.deepcopy(doc.metadata))
            first = True
            for word in words_list:
                if len(document.page_content) + len(word) < sentence_size:
                    document.page_content += word
                else:
                    if len(document.page_content.replace(" ", "").replace("\n", "")) > 0:
                        if first:
                            first = False
                        else:
                            effect_documents[-1].metadata["next_doc"] = document.page_content
                        effect_documents.append(document)
                    document = Document(page_content=word, metadata=copy.deepcopy(doc.metadata))
            if len(document.page_content.replace(" ", "").replace("\n", "")) > 0:
                if first:
                    first = False
                    pass
                else:
                    effect_documents[-1].metadata["next_doc"] = document.page_content
                effect_documents.append(document)

@@ -87,10 +99,12 @@ def split(documents=None, sentence_size: int = SENTENCE_SIZE):
        else:
            effect_documents.append(doc)
    return effect_documents


def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE, mode: str = None, **kwargs):
    print("load_file", filepath)
    if filepath.lower().endswith(".md"):
        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
        docs = loader.load()
    elif filepath.lower().endswith(".txt"):
        loader = TextLoader(filepath, autodetect_encoding=True, **kwargs)

@@ -100,15 +114,15 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE, mode: str = None, **kwargs):
        loader = CSVLoader(filepath, **kwargs)
        docs = loader.load()
    elif filepath.lower().endswith(".pdf"):
        loader = UnstructuredPDFLoader(filepath, mode=mode or "elements", **kwargs)
        textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
        docs = loader.load_and_split(textsplitter)
    elif filepath.lower().endswith(".docx"):
        loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements", **kwargs)
        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = loader.load_and_split(textsplitter)
    else:
        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = loader.load_and_split(text_splitter=textsplitter)
    if using_zh_title_enhance:

@@ -116,6 +130,7 @@ def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE, mode: str = None, **kwargs):
        write_check_file(filepath, docs)
    return docs


def write_check_file(filepath, docs):
    folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
    if not os.path.exists(folder_path):

@@ -128,7 +143,8 @@ def write_check_file(filepath, docs):
        fout.write(str(i))
        fout.write('\n')
    fout.close()


def get_files_in_directory(directory):
    file_paths = []
    for root, dirs, files in os.walk(directory):

@@ -137,21 +153,29 @@ def get_files_in_directory(directory):
            file_paths.append(file_path)
    return file_paths


# 自定义pdf load部分
def __checkV(strings: str):
    lines = len(strings.splitlines())
-    if (lines > 3 and len(strings.replace(" ", "")) / lines < 15):
+    if lines > 3 and len(strings.replace(" ", "")) / lines < 15:
        return False
    return True


def __isTitle(strings: str):
    return len(strings.splitlines()) == 1 and len(strings) > 0 and strings.endswith("\n")


def __appendPara(strings: str):
    return strings.replace(".\n", "^_^").replace("。\n", "^-^").replace("?\n", "?^-^").replace("?\n", "?^-^") \
        .replace("\n", "").replace("^_^", ".\n").replace("^-^", "。\n").replace("?^-^", "?\n").replace("?^-^", "?\n")


def __check_fs_ff(line_ff_fs_s, fs, ff):  # 若当前行有上一行一样的字体、字号文字,则返回相同的。默认返回最长文本的字体和字号
    re_fs = line_ff_fs_s[-1][0][-1]
    re_ff = line_ff_fs_s[-1][1][-1] if line_ff_fs_s[-1][1] else None
    max_len = 0
    for ff_fs in line_ff_fs_s:  # 寻找最长文本字体和字号
        c_max = max(list(map(int, ff_fs[0])))
        if max_len < ff_fs[2] or (max_len == ff_fs[2] and c_max > int(re_fs)):
            max_len = ff_fs[2]

@@ -163,123 +187,133 @@ def __check_fs_ff(line_ff_fs_s, fs, ff):
            re_fs = fs
            re_ff = ff
            break
    return int(re_fs), re_ff


def append_document(snippets1: List[Document], title: str, content: str, callbacks, font_size, page_num, metadate,
                    need_append: bool = False):
    if callbacks:
        for cb in callbacks:
            if isinstance(cb, BaseCallback):
                if cb.filter(title, content):
                    return
    if need_append and len(snippets1) > 0:
        ps = snippets1.pop()
        snippets1.append(Document(page_content=ps.page_content + title, metadata=ps.metadata))
    else:
        doc_metadata = {"font-size": font_size, "page_number": page_num}
        doc_metadata.update(metadate)
        snippets1.append(Document(page_content=title + content, metadata=doc_metadata))


'''
提取pdf文档,按标题和内容进行分割,文档的页码按标题所在页码为准
分割后的文本按sentence_size值再次分割,分割的文本的页码均属于父文本的页码
'''


def __pdf_loader(filepath: str, sentence_size: int = 0, metadata=None, callbacks=None):
    if not filepath.lower().endswith(".pdf"):
        raise ValueError("file is not pdf document")
    loader = PDFMinerPDFasHTMLLoader(filepath)
    documents = loader.load()
    soup = BeautifulSoup(documents[0].page_content, 'html.parser')
    content = soup.find_all('div')

    cur_fs = None  # 当前文本font-size
    last_fs = None  # 上一段文本font-size
    cur_ff = None  # 当前文本风格
    cur_text = ''
    fs_increasing = False  # 下一行字体变大,判断为标题,从此处分割
    last_text = ''
    last_page_num = 1  # 上一页页码 根据page_split判断当前文本页码
    page_num = 1  # 初始页码
    page_change = False  # 页面切换
    page_split = False  # 页面是否出现文本分割
    last_is_title = False  # 上一个文本是否是标题
    snippets: List[Document] = []
    filename = os.path.basename(filepath)
    if metadata:
        metadata.update({'source': filepath, 'filename': filename, 'filetype': 'application/pdf'})
    else:
        metadata = {'source': filepath, 'filename': filename, 'filetype': 'application/pdf'}
    for c in content:
        divs = c.get('style')
        if re.match(r"^(Page|page)", c.text):  # 检测当前页的页码
            match = re.match(r"^(page|Page)\s+(\d+)", c.text)
            if match:
                if page_split:  # 如果有文本分割,则换页,没有则保持当前文本起始页码
                    last_page_num = page_num
                    page_num = match.group(2)
                if len(last_text) + len(cur_text) == 0:  # 如果翻页且文本为空,上一页页码为当前页码
                    last_page_num = page_num
                page_change = True
                page_split = False
            continue
        if re.findall('writing-mode:(.*?);', divs) == ['False'] or re.match(r'^[0-9\s\n]+$', c.text) \
                or re.match(r"^第\s+\d+\s+页$", c.text):  # 如果不显示或者纯数字
            continue
        if len(c.text.replace("\n", "").replace(" ", "")) <= 1:  # 去掉有效字符小于1的行
            continue
        sps = c.find_all('span')
        if not sps:
            continue
        line_ff_fs_s = []  # 有效字符大于1的集合
        line_ff_fs_s2 = []  # 有效字符为1的集合
        for sp in sps:  # 如果一行中有多个不同样式的
            sp_len = len(sp.text.replace("\n", "").replace(" ", ""))
            if sp_len > 0:
                st = sp.get('style')
                if st:
                    ff_fs = (re.findall('font-size:(\d+)px', st), re.findall('font-family:(.*?);', st),
                             len(sp.text.replace("\n", "").replace(" ", "")))
                    if sp_len == 1:  # 过滤一个有效字符的span
                        line_ff_fs_s2.append(ff_fs)
                    else:
                        line_ff_fs_s.append(ff_fs)
        if len(line_ff_fs_s) == 0:  # 如果为空,则以一个有效字符span为准
            if len(line_ff_fs_s2) > 0:
                line_ff_fs_s = line_ff_fs_s2
            else:
                if len(c.text) > 0:
                    page_change = False
                continue
        fs, ff = __check_fs_ff(line_ff_fs_s, cur_fs, cur_ff)
        if not cur_ff:
            cur_ff = ff
        if not cur_fs:
            cur_fs = fs
        if abs(fs - cur_fs) <= 1 and ff == cur_ff:  # 风格和字体都没改变
            cur_text += c.text
            cur_fs = fs
            page_change = False
            if len(cur_text.splitlines()) > 3:  # 连续多行则fs_increasing不再生效
                fs_increasing = False
        else:
            if page_change and cur_fs > fs + 1:  # 翻页,(字体变小) 大概率是页眉,跳过c.text。-----有可能切掉一行文本
                page_change = False
                continue
            if last_is_title:  # 如果上一个为title
                if __isTitle(cur_text) or fs_increasing:  # 连续多个title 或者 有变大标识的
                    last_text = last_text + cur_text
                    last_is_title = True
                    fs_increasing = False
                else:
                    append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
                                    page_num if page_split else last_page_num, metadata)
                    page_split = True
                    last_text = ''
                    last_is_title = False
                    fs_increasing = int(fs) > int(cur_fs)  # 字体变大
            else:
                if len(last_text) > 0 and __checkV(last_text):  # 过滤部分文本
                    # 将跨页的两段或者行数较少的文本合并
                    append_document(snippets, __appendPara(last_text), "", callbacks, last_fs,
                                    page_num if page_split else last_page_num, metadata,
                                    need_append=len(last_text.splitlines()) <= 2 or page_change)
                    page_split = True
                last_text = cur_text
                last_is_title = __isTitle(last_text) or fs_increasing
                fs_increasing = int(fs) > int(cur_fs)

@@ -288,9 +322,10 @@ def __pdf_loader(filepath: str, sentence_size: int = 0, metadata=None, callbacks=None):
            last_fs = cur_fs
            cur_fs = fs
            cur_ff = ff
            cur_text = c.text
            page_change = False
    append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
                    page_num if page_split else last_page_num, metadata)
    if sentence_size > 0:
        return split(snippets, sentence_size)
    return snippets
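A hedged usage sketch (my addition), following the load() docstring above; file paths are placeholders and the PDF branch goes through the custom __pdf_loader.

# Hypothetical sketch: load a single PDF and a whole directory of documents.
from src.loader.load import load, loads_path

docs = load("./data/sample.pdf", sentence_size=100, metadata={"kb": "law"})
all_docs = loads_path("./data", mode="paged", sentence_size=100)
print(len(docs), len(all_docs))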
src/loader/zh_title_enhance.py

@@ -33,7 +33,7 @@ def is_possible_title(
        title_max_word_length: int = 20,
        non_alpha_threshold: float = 0.5,
) -> bool:
-    """Checks to see if the text passes all of the checks for a valid title.
+    """Checks to see if the text passes all the checks for a valid title.

    Parameters
    ----------
src/pgdb/chat/c_db.py

import psycopg2
from psycopg2 import OperationalError, InterfaceError


class UPostgresDB:
-    '''
+    """
    psycopg2.connect(
        dsn          #指定连接参数。可以使用参数形式或 DSN 形式指定。
        host         #指定连接数据库的主机名。

@@ -18,8 +19,9 @@ class UPostgresDB:
        sslkey       #指定私钥文件名。
        sslcert      #指定公钥文件名。
    )
-    '''
+    """

    def __init__(self, host, database, user, password, port=5432):
        self.host = host
        self.database = database
        self.user = user

@@ -35,7 +37,7 @@ class UPostgresDB:
                database=self.database,
                user=self.user,
                password=self.password,
                port=self.port
            )
            self.cur = self.conn.cursor()
        except Exception as e:

@@ -45,7 +47,7 @@ class UPostgresDB:
        try:
            if self.conn is None or self.conn.closed:
                self.connect()
            self.cur.execute(query)
            self.conn.commit()
        except InterfaceError as e:
            print(f"数据库连接已经关闭: {e}")

@@ -53,8 +55,8 @@ class UPostgresDB:
            print(f"数据库连接出现问题: {e}")
            self.connect()
            self.retry_execute(query)
        except Exception as e:
            print(f"执行sql语句出现错误: {e}")
            self.conn.rollback()

    def retry_execute(self, query):

@@ -69,16 +71,16 @@ class UPostgresDB:
        try:
            if self.conn is None or self.conn.closed:
                self.connect()
            self.cur.execute(query, args)
            self.conn.commit()
        except InterfaceError as e:
            print(f"数据库连接已经关闭: {e}")
        except OperationalError as e:
            print(f"数据库操作出现问题: {e}")
            self.connect()
            self.retry_execute_args(query, args)
        except Exception as e:
            print(f"执行sql语句出现错误: {e}")
            self.conn.rollback()

    def retry_execute_args(self, query, args):

@@ -89,7 +91,6 @@ class UPostgresDB:
            print(f"重新执行sql语句再次出现错误: {type(e).__name__}: {e}")
            self.conn.rollback()

    def search(self, query, params=None):
        if self.conn is None or self.conn.closed:
            self.connect()

@@ -97,7 +98,7 @@ class UPostgresDB:
    def fetchall(self):
        return self.cur.fetchall()

    def fetchone(self):
        return self.cur.fetchone()

@@ -109,8 +110,8 @@ class UPostgresDB:
        try:
            if self.conn is None or self.conn.closed:
                self.connect()
            self.cur.execute(query)
            self.conn.commit()
        except Exception as e:
            print(f"An error occurred: {e}")
            self.conn.rollback()
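A hedged usage sketch (my addition): wiring the chat-database constants from consts.py into this helper and reading back rows; the SELECT columns follow the c_user INSERT statement below.

# Hypothetical sketch: open the chat database with UPostgresDB and run a query.
from src.pgdb.chat.c_db import UPostgresDB
from src.config.consts import CHAT_DB_HOST, CHAT_DB_DBNAME, CHAT_DB_USER, CHAT_DB_PASSWORD, CHAT_DB_PORT

db = UPostgresDB(CHAT_DB_HOST, CHAT_DB_DBNAME, CHAT_DB_USER, CHAT_DB_PASSWORD, port=CHAT_DB_PORT)
db.search("SELECT user_id, account FROM c_user")   # connects lazily if needed
print(db.fetchall())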
src/pgdb/chat/c_user_table.py

from .c_db import UPostgresDB
import json

TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (

@@ -13,14 +14,15 @@ COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
"""


class CUser:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db

    def insert(self, value):
        query = f"INSERT INTO c_user(user_id, account, password) VALUES (%s,%s,%s)"
-        self.db.execute_args(query, ((value[0], value[1], value[2])))
+        self.db.execute_args(query, (value[0], value[1], value[2]))

    def create_table(self):
        query = TABLE_USER
        self.db.execute(query)
src/pgdb/chat/chat_table.py

from .c_db import UPostgresDB
import json

TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (

@@ -17,6 +18,7 @@ COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
"""


class Chat:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db

@@ -24,9 +26,9 @@ class Chat:
    # 插入数据
    def insert(self, value):
        query = f"INSERT INTO chat(chat_id, user_id, info, deleted) VALUES (%s,%s,%s,%s)"
-        self.db.execute_args(query, ((value[0], value[1], value[2], value[3])))
+        self.db.execute_args(query, (value[0], value[1], value[2], value[3]))

    # 创建表
    def create_table(self):
        query = TABLE_CHAT
        self.db.execute(query)
src/pgdb/chat/turn_qa_table.py

from .c_db import UPostgresDB
import json

TABLE_CHAT = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (

@@ -21,6 +22,7 @@ COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
"""


class TurnQa:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db

@@ -28,9 +30,9 @@ class TurnQa:
    # 插入数据
    def insert(self, value):
        query = f"INSERT INTO turn_qa(turn_id, chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s,%s)"
-        self.db.execute_args(query, ((value[0], value[1], value[2], value[3], value[4], value[5])))
+        self.db.execute_args(query, (value[0], value[1], value[2], value[3], value[4], value[5]))

    # 创建表
    def create_table(self):
        query = TABLE_CHAT
        self.db.execute(query)
src/pgdb/knowledge/callback.py View file @ 493cdd59

import os, sys
from os import path

sys.path.append("../")
from abc import ABC, abstractmethod
import json
from typing import List, Any, Tuple, Dict
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64


class DocumentCallback(ABC):
    @abstractmethod
    # 向量库储存前文档处理--
    def before_store(self, docstore: PgSqlDocstore, documents):
        pass

    @abstractmethod
    # 向量库查询后文档处理--用于结构建立
    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:  # 向量库查询后文档处理
        pass


class DefaultDocumentCallback(DocumentCallback):
    def before_store(self, docstore: PgSqlDocstore, documents):
        output_doc = []
        for doc in documents:
            if "next_doc" in doc.metadata:
...
@@ -27,22 +29,24 @@ class DefaultDocumentCallback(DocumentCallback):
                doc.metadata.pop("next_doc")
            output_doc.append(doc)
        return output_doc

    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:  # 向量库查询后文档处理
        output_doc: List[Tuple[Document, float]] = []
        exist_hash = []
        for doc, score in documents:
            print(exist_hash)
            dochash = str2hash_base64(doc.page_content)
            if dochash in exist_hash:
                continue
            else:
                exist_hash.append(dochash)
            output_doc.append((doc, score))
            if len(output_doc) > number:
                return output_doc
            fordoc = doc
            while "next_hash" in fordoc.metadata:
                if len(fordoc.metadata["next_hash"]) > 0:
                    if fordoc.metadata["next_hash"] in exist_hash:
                        break
                    else:
...
@@ -50,11 +54,11 @@ class DefaultDocumentCallback(DocumentCallback):
                        content = docstore.TXT_DOC.search(fordoc.metadata["next_hash"])
                        if content:
                            fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
                            output_doc.append((fordoc, score))
                            if len(output_doc) > number:
                                return output_doc
                        else:
                            break
                else:
                    break
        return output_doc
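after_search deduplicates hits by a content hash and then follows the next_hash chain stored in each document's metadata, pulling the following paragraphs out of the txt_doc table. The chain walk is easiest to see in isolation; a minimal sketch with an in-memory stand-in for docstore.TXT_DOC.search (which in the real code returns a (text, metadata-json) row) — PARAGRAPHS and follow_chain are illustrative names only:

import json
from langchain.schema import Document

# hypothetical paragraph store: hash -> (text, metadata json), mimicking TxtDoc.search
PARAGRAPHS = {
    "h2": ("第二段", json.dumps({"next_hash": "h3"})),
    "h3": ("第三段", json.dumps({})),
}


def follow_chain(first: Document):
    out = [first]
    doc = first
    while "next_hash" in doc.metadata and doc.metadata["next_hash"]:
        row = PARAGRAPHS.get(doc.metadata["next_hash"])
        if not row:
            break
        doc = Document(page_content=row[0], metadata=json.loads(row[1]))
        out.append(doc)
    return out


hit = Document(page_content="第一段", metadata={"next_hash": "h2"})
print([d.page_content for d in follow_chain(hit)])  # ['第一段', '第二段', '第三段']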
src/pgdb/knowledge/k_db.py View file @ 493cdd59

import psycopg2


class PostgresDB:
    """
    psycopg2.connect(
        dsn      #指定连接参数。可以使用参数形式或 DSN 形式指定。
        host     #指定连接数据库的主机名。
...
@@ -17,8 +18,9 @@ class PostgresDB:
        sslkey   #指定私钥文件名。
        sslcert  #指定公钥文件名。
    )
    """

    def __init__(self, host, database, user, password, port=5432):
        self.host = host
        self.database = database
        self.user = user
...
@@ -28,28 +30,29 @@ class PostgresDB:
        self.cur = None

    def connect(self):
        self.conn = psycopg2.connect(
            host=self.host,
            database=self.database,
            user=self.user,
            password=self.password,
            port=self.port
        )
        self.cur = self.conn.cursor()

    def execute(self, query):
        try:
            self.cur.execute(query)
            self.conn.commit()
        except Exception as e:
            print(f"An error occurred: {e}")
            self.conn.rollback()

    def execute_args(self, query, args):
        try:
            self.cur.execute(query, args)
            self.conn.commit()
        except Exception as e:
            print(f"An error occurred: {e}")
            self.conn.rollback()

    def search(self, query, params=None):
...
@@ -64,8 +67,8 @@ class PostgresDB:
    def format(self, query):
        try:
            self.cur.execute(query)
            self.conn.commit()
        except Exception as e:
            print(f"An error occurred: {e}")
            self.conn.rollback()
src/pgdb/knowledge/pgsqldocstore.py View file @ 493cdd59

import sys
from os import path

# 这里相当于把当前目录添加到pythonpath中
sys.path.append(path.dirname(path.abspath(__file__)))
from typing import List, Union, Dict, Optional
from langchain.docstore.base import AddableMixin, Docstore
from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector
import json, hashlib, base64
from langchain.schema import Document


-def str2hash_base64(input: str) -> str:
-    # return f"%s" % hash(input)
-    return base64.b64encode(hashlib.sha1(input.encode()).digest()).decode()
+def str2hash_base64(inp: str) -> str:
+    # return f"%s" % hash(input)
+    return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()


class PgSqlDocstore(Docstore, AddableMixin):
    host: str
    dbname: str
    username: str
    password: str
    port: str
    '''
    说明,重写__getstate__,__setstate__,适用于langchain的序列化存储,基于pickle进行存储。返回数组包含pgsql连接信息。
    '''

    def __getstate__(self):
        return {"host": self.host, "dbname": self.dbname, "username": self.username, "password": self.password,
                "port": self.port}

    def __setstate__(self, info):
        self.__init__(info)

    def __init__(self, info: dict, reset: bool = False):
        self.host = info["host"]
        self.dbname = info["dbname"]
        self.username = info["username"]
        self.password = info["password"]
-        self.port = info["port"] if "port" in info else "5432";
+        self.port = info["port"] if "port" in info else "5432"
        self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password, port=self.port)
        self.TXT_DOC = TxtDoc(self.pgdb)
        self.VEC_TXT = TxtVector(self.pgdb)
        if reset:
...
@@ -48,12 +50,15 @@ class PgSqlDocstore(Docstore,AddableMixin):
            self.VEC_TXT.drop_table()
            self.TXT_DOC.create_table()
            self.VEC_TXT.create_table()

    def __sub_init__(self):
        if not self.pgdb.conn:
            self.pgdb.connect()

    '''
    从本地库中查找向量对应的文本段落,封装成Document返回
    '''

    def search(self, search: str) -> Union[str, Document]:
        if not self.pgdb.conn:
            self.__sub_init__()
...
@@ -63,40 +68,44 @@ class PgSqlDocstore(Docstore,AddableMixin):
            return Document(page_content=content[0], metadata=json.loads(content[1]))
        else:
            return Document()

    '''
    从本地库中删除向量对应的文本,批量删除
    '''

    def delete(self, ids: List) -> None:
        if not self.pgdb.conn:
            self.__sub_init__()
        pids = []
-        for id in ids:
-            anwser = self.VEC_TXT.search(id)
+        for item in ids:
+            anwser = self.VEC_TXT.search(item)
            pids.append(anwser[0])
        self.VEC_TXT.delete(ids)
        self.TXT_DOC.delete(pids)

    '''
    向本地库添加向量和文本信息
    [vector_id,Document(page_content=问题, metadata=dict(paragraph=段落文本))]
    '''

    def add(self, texts: Dict[str, Document]) -> None:
        # for vec,doc in texts.items():
        #     paragraph_id = self.TXT_DOC.insert(doc.metadata["paragraph"])
        #     self.VEC_TXT.insert(vector_id=vec,paragraph_id=paragraph_id,text=doc.page_content)
        if not self.pgdb.conn:
            self.__sub_init__()
        paragraph_hashs = []  # hash,text
        paragraph_txts = []
        vec_inserts = []
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            print(txt_hash)
            vec_inserts.append((vec, doc.page_content, txt_hash))
            if txt_hash not in paragraph_hashs:
                paragraph_hashs.append(txt_hash)
                paragraph = doc.metadata["paragraph"]
                doc.metadata.pop("paragraph")
                paragraph_txts.append((txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False)))
        # print(paragraph_txts)
        self.TXT_DOC.insert(paragraph_txts)
        self.VEC_TXT.insert(vec_inserts)
...
@@ -105,7 +114,7 @@ class PgSqlDocstore(Docstore,AddableMixin):


class InMemorySecondaryDocstore(Docstore, AddableMixin):
    """Simple in memory docstore in the form of a dict."""

    def __init__(self, _dict: Optional[Dict[str, Document]] = None, _sec_dict: Optional[Dict[str, Document]] = None):
        """Initialize with dict."""
        self._dict = _dict if _dict is not None else {}
        self._sec_dict = _sec_dict if _sec_dict is not None else {}
...
@@ -123,19 +132,19 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
        if overlapping:
            raise ValueError(f"Tried to add ids that already exist: {overlapping}")
        self._dict = {**self._dict, **texts}
        dict1 = {}
        dict_sec = {}
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            metadata = doc.metadata
            paragraph = metadata.pop('paragraph')
            # metadata.update({"paragraph_id":txt_hash})
            metadata['paragraph_id'] = txt_hash
            dict_sec[txt_hash] = Document(page_content=paragraph, metadata=metadata)
            dict1[vec] = Document(page_content=doc.page_content, metadata={'paragraph_id': txt_hash})
        self._dict = {**self._dict, **dict1}
        self._sec_dict = {**self._sec_dict, **dict_sec}

    def delete(self, ids: List) -> None:
        """Deleting IDs from in memory dictionary."""
...
@@ -143,7 +152,7 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
        if not overlapping:
            raise ValueError(f"Tried to delete ids that does not exist: {ids}")
        for _id in ids:
-            self._sec_dict.pop(self._dict[id].metadata['paragraph_id'])
+            self._sec_dict.pop(self._dict[_id].metadata['paragraph_id'])
            self._dict.pop(_id)

    def search(self, search: str) -> Union[str, Document]:
...
@@ -159,4 +168,4 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
            return f"ID {search} not found."
        else:
            print(self._dict[search].page_content)
            return self._sec_dict[self._dict[search].metadata['paragraph_id']]
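The docstore splits its data across two tables: txt_doc keeps each paragraph once, keyed by a base64-encoded SHA-1 of its text, and vec_txt maps every vector id to the question text plus that paragraph hash. A minimal sketch of the Dict[str, Document] payload that add() expects; the paragraph/question strings and the uuid key are made-up examples:

import base64
import hashlib
import uuid

from langchain.schema import Document


def str2hash_base64(inp: str) -> str:
    # same scheme as pgsqldocstore: base64-encoded SHA-1 digest of the paragraph text
    return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()


paragraph = "征信机构应当按照规定保存、处理个人信息。"
question = "征信机构如何处理个人信息?"

# shape expected by PgSqlDocstore.add(): {vector_id: Document(question, metadata containing the paragraph)}
payload = {str(uuid.uuid4()): Document(page_content=question, metadata={"paragraph": paragraph})}

print(str2hash_base64(paragraph))  # 28-character key shared by txt_doc.hash and vec_txt.paragraph_id
# a connected PgSqlDocstore instance would then take docstore.add(payload)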
src/pgdb/knowledge/similarity.py View file @ 493cdd59

import os, sys
import re, time
from os import path

sys.path.append("../")
import copy
from typing import List, OrderedDict, Any, Optional, Tuple, Dict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore
from langchain.vectorstores.faiss import FAISS, dependable_faiss_import
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore
from langchain.embeddings.huggingface import (
    HuggingFaceEmbeddings,
)
...
@@ -22,50 +22,61 @@ from langchain.callbacks.manager import (
)
from src.loader import load
from langchain.embeddings.base import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class EmbeddingFactory:
    def __init__(self, path: str):
        self.path = path
        self.embedding = HuggingFaceEmbeddings(model_name=path)

    def get_embedding(self):
        return self.embedding


-def GetEmbding(path: str) -> Embeddings:
-    # return HuggingFaceEmbeddings(model_name=path)
-    return EmbeddingFactory(path).get_embedding()
+def GetEmbding(_path: str) -> Embeddings:
+    # return HuggingFaceEmbeddings(model_name=path)
+    return EmbeddingFactory(_path).get_embedding()


import operator
from langchain.vectorstores.utils import DistanceStrategy
import numpy as np


class RE_FAISS(FAISS):
    # 去重,并保留metadate
-    def _tuple_deduplication(self, tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
+    @staticmethod
+    def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        deduplicated_dict = OrderedDict()
        for doc, scores in tuple_input:
            page_content = doc.page_content
            metadata = doc.metadata
            if page_content not in deduplicated_dict:
                deduplicated_dict[page_content] = (metadata, scores)
        deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in
                                  deduplicated_dict.items()]
        return deduplicated_documents

    def similarity_search_with_score_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            filter: Optional[Dict[str, Any]] = None,
            fetch_k: int = 20,
            **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        faiss = dependable_faiss_import()
        vector = np.array([embedding], dtype=np.float32)
...
@@ -96,7 +107,7 @@ class RE_FAISS(FAISS):
        cmp = (
            operator.ge
            if self.distance_strategy
            in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
            else operator.le
        )
        docs = [
...
@@ -104,19 +115,20 @@ class RE_FAISS(FAISS):
            for doc, similarity in docs
            if cmp(similarity, score_threshold)
        ]
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
        return docs[:k]

    def max_marginal_relevance_search_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            fetch_k: int = 20,
            lambda_mult: float = 0.5,
            filter: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
...
@@ -141,52 +153,63 @@ class RE_FAISS(FAISS):
        docs_and_scores = self._tuple_deduplication(docs_and_scores)
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
        return [doc for doc, _ in docs_and_scores]


def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
             is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
-    embeddings = GetEmbding(path=embedding_model_name)
+    embeddings = GetEmbding(_path=embedding_model_name)
    docstore1: PgSqlDocstore = None
    if is_pgsql:
        if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
            docstore1 = PgSqlDocstore(info, reset=reset)
    else:
        docstore1 = InMemorySecondaryDocstore()
    if not path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
    if store_path is None or len(store_path) <= 0 or not path.exists(
            path.join(store_path, index_name + ".faiss")) or reset:
        print("create new faiss")
        index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))  # 根据embeddings向量维度设置
        return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore1,
                        index_to_docstore_id={})
    else:
        print("load_local faiss")
        _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings)
        if docstore1 and is_pgsql:  # 如果外部参数调整,更新docstore
            _faiss.docstore = docstore1
        return _faiss


class VectorStore_FAISS(FAISS):
    def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None,
                 is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False,
                 doc_callback: DocumentCallback = DefaultDocumentCallback()):
        self.info = info
        self.embedding_model_name = embedding_model_name
        self.store_path = path.join(store_path, index_name)
        if not path.exists(self.store_path):
            os.makedirs(self.store_path, exist_ok=True)
        self.index_name = index_name
        self.show_number = show_number
        self.search_number = self.show_number * 3
        self.threshold = threshold
        self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
                               is_pgsql=is_pgsql, reset=reset)
        self.doc_callback = doc_callback

    def get_text_similarity_with_score(self, text: str, **kwargs):
        score_threshold = (1 - self.threshold) * math.sqrt(2)
        docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
                                                        score_threshold=score_threshold,
                                                        doc_callback=self.doc_callback, **kwargs)
        return [doc for doc, similarity in docs][:self.show_number]

    def get_text_similarity(self, text: str, **kwargs):
        docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback,
                                             **kwargs)
        return docs[:self.show_number]

    # #去重,并保留metadate
    # def _tuple_deduplication(self, tuple_input:List[Document]) -> List[Document]:
    #     deduplicated_dict = OrderedDict()
...
@@ -195,26 +218,29 @@ class VectorStore_FAISS(FAISS):
    #         metadata = doc.metadata
    #         if page_content not in deduplicated_dict:
    #             deduplicated_dict[page_content] = metadata
    #     deduplicated_documents = [Document(page_content=key,metadata=value) for key, value in deduplicated_dict.items()]
    #     return deduplicated_documents

-    def _join_document(self, docs: List[Document]) -> str:
+    @staticmethod
+    def _join_document(docs: List[Document]) -> str:
        print(docs)
        return "".join([doc.page_content for doc in docs])

-    def get_local_doc(self, docs: List[Document]):
+    @staticmethod
+    def get_local_doc(docs: List[Document]):
        ans = []
        for doc in docs:
            ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
                        "filename": doc.metadata["filename"]})
        return ans

    # def _join_document_location(self, docs:List[Document]) -> str:

    # 持久化到本地
    def _save_local(self):
        self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)

    # 添加文档
    # Document {
    #     page_content 段落
...
@@ -222,10 +248,10 @@ class VectorStore_FAISS(FAISS):
    #         page 页码
    #     }
    # }
    def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
        list_of_documents: List[Document] = []
        if self.doc_callback:
            new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
        if need_split:
            for doc in new_docs:
                words_list = re.split(pattern, doc.page_content)
...
@@ -240,8 +266,14 @@ class VectorStore_FAISS(FAISS):
        else:
            list_of_documents = new_docs
        self._faiss.add_documents(list_of_documents)

-    def _add_documents_from_dir(self, filepaths=[], load_kwargs: Optional[dict] = {"mode": "paged"}):
-        self._add_documents(load.loads(filepaths, **load_kwargs))
+    def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
+        if load_kwargs is None:
+            load_kwargs = {"mode": "paged"}
+        if filepaths is None:
+            filepaths = []
+        self._add_documents(load.loads(filepaths, **load_kwargs))

    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """
        Return VectorStoreRetriever initialized from this VectorStore.
...
@@ -303,11 +335,11 @@ class VectorStore_FAISS(FAISS):
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        elif "similarity_score_threshold" == kwargs["search_type"]:
            default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._faiss._get_retriever_tags())
        print(kwargs)
...
@@ -316,20 +348,21 @@ class VectorStore_FAISS(FAISS):


class VectorStoreRetriever_FAISS(VectorStoreRetriever):
    search_k = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "k" in self.search_kwargs:
            self.search_k = self.search_kwargs["k"]
        self.search_kwargs["k"] = self.search_k * 2

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]

    async def _aget_relevant_documents(
            self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._aget_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]
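A note on the score conversion in get_text_similarity_with_score: the cosine-style threshold (0.8 by default) is turned into an L2 cutoff as (1 - threshold) * sqrt(2). For L2-normalized embeddings the exact relation between cosine similarity c and Euclidean distance d is d = sqrt(2 - 2c), so the committed cutoff is noticeably stricter than a direct conversion; whether the HuggingFace model emits normalized vectors, and whether this FAISS/LangChain combination reports d or d² as the score, depends on the versions in use, so the numbers below are only indicative:

import math

threshold = 0.8                                   # default in VectorStore_FAISS
code_cutoff = (1 - threshold) * math.sqrt(2)      # value passed as score_threshold
exact_cutoff = math.sqrt(2 * (1 - threshold))     # d = sqrt(2 - 2c) for unit vectors at c = 0.8

print(round(code_cutoff, 3), round(exact_cutoff, 3))  # 0.283 0.632 -> the code keeps fewer, closer hits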
src/pgdb/knowledge/txt_doc_table.py View file @ 493cdd59

from .k_db import PostgresDB

# paragraph_id BIGSERIAL primary key,
TABLE_TXT_DOC = """
create table txt_doc (
    hash varchar(40) primary key,
...
@@ -11,6 +12,8 @@ TABLE_TXT_DOC_HASH_INDEX = """
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
"""


# CREATE UNIQUE INDEX idx_name ON your_table (column_name);
class TxtDoc:
    def __init__(self, db: PostgresDB) -> None:
...
@@ -21,19 +24,20 @@ class TxtDoc:
        args = []
        for value in texts:
            value = list(value)
            query += "(%s,%s,%s),"
            args.extend(value)
        query = query[:len(query) - 1]
        query += f"ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
        self.db.execute_args(query, args)

-    def delete(self, ids):
-        for id in ids:
-            query = f"delete FROM txt_doc WHERE hash = %s" % (id)
-            self.db.execute_args(query, args)
+    def delete(self, ids):
+        for item in ids:
+            query = f"delete FROM txt_doc WHERE hash = %s" % item
+            self.db.execute(query)

-    def search(self, id):
+    def search(self, item):
        query = "SELECT text,matadate FROM txt_doc WHERE hash = %s"
-        self.db.execute_args(query, [id])
+        self.db.execute_args(query, [item])
        answer = self.db.fetchall()
        if len(answer) > 0:
            return answer[0]
...
@@ -60,4 +64,3 @@ class TxtDoc:
        query = "DROP TABLE txt_doc"
        self.db.format(query)
        print("drop table txt_doc ok")
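One observation on the reworked delete(): the f-string prefix is inert (there is nothing to interpolate), and the % item substitution pastes the hash into the statement without quoting, so a base64 hash ends up unquoted in the SQL and the pattern is injection-prone in general. If that is not intentional, a parameterized variant along these lines (same TxtDoc/PostgresDB interface, not part of this commit) sidesteps both issues; the same applies to TxtVector.delete in the next file:

    def delete(self, ids):
        # let the driver quote and escape the key instead of building SQL by hand
        query = "DELETE FROM txt_doc WHERE hash = %s"
        for item in ids:
            self.db.execute_args(query, (item,))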
src/pgdb/knowledge/vec_txt_table.py View file @ 493cdd59

from .k_db import PostgresDB

TABLE_VEC_TXT = """
CREATE TABLE vec_txt (
    vector_id varchar(36) PRIMARY KEY,
...
@@ -6,7 +7,9 @@ CREATE TABLE vec_txt (
    paragraph_id varchar(40) not null
)
"""


# 025a9bee-2eb2-47f5-9722-525e05a0442b
class TxtVector:
    def __init__(self, db: PostgresDB) -> None:
        self.db = db
...
@@ -16,19 +19,21 @@ class TxtVector:
        args = []
        for value in vectors:
            value = list(value)
            query += "(%s,%s,%s),"
            args.extend(value)
        query = query[:len(query) - 1]
        query += f"ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text,paragraph_id = EXCLUDED.paragraph_id;"
        # query += ";"
        self.db.execute_args(query, args)

-    def delete(self, ids):
-        for id in ids:
-            query = f"delete FROM vec_txt WHERE vector_id = '%s'" % (id,)
-            self.db.execute_args(query, args)
+    def delete(self, ids):
+        for item in ids:
+            query = f"delete FROM vec_txt WHERE vector_id = '%s'" % (item,)
+            self.db.execute(query)

    def search(self, search: str):
        query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
-        self.db.execute_args(query,[search])
+        self.db.execute_args(query, [search])
        answer = self.db.fetchall()
        print(answer)
        return answer[0]
...
@@ -48,4 +53,4 @@ class TxtVector:
        if exists:
            query = "DROP TABLE vec_txt"
            self.db.format(query)
            print("drop table vec_txt ok")
test/chat_table_test.py View file @ 493cdd59

import sys

sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa

"""测试会话相关数据可的连接"""
-c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
-chat = Chat(db=c_db)
-c_user = CUser(db=c_db)
-turn_qa = TurnQa(db=c_db)
-chat.create_table()
-c_user.create_table()
-turn_qa.create_table()
-# chat_id, user_id, info, deleted
-chat.insert(["3333", "1111", "没有info", 0])
-# user_id, account, password
-c_user.insert(["111", "zhangsan", "111111"])
-# turn_id, chat_id, question, answer, turn_number, is_last
-turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])
+
+
+def test():
+    c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
+    chat = Chat(db=c_db)
+    c_user = CUser(db=c_db)
+    turn_qa = TurnQa(db=c_db)
+    chat.create_table()
+    c_user.create_table()
+    turn_qa.create_table()
+    # chat_id, user_id, info, deleted
+    chat.insert(["3333", "1111", "没有info", 0])
+    # user_id, account, password
+    c_user.insert(["111", "zhangsan", "111111"])
+    # turn_id, chat_id, question, answer, turn_number, is_last
+    turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])
+
+
+if __name__ == "main":
+    test()
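One small catch in the new guard: the interpreter sets __name__ to "__main__", not "main", so as committed test() never runs when the file is executed directly. Presumably the intent was:

if __name__ == "__main__":
    test()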
test/k_store_test.py View file @ 493cdd59

import sys

-sys.path.append("../")
-import time
-from src.loader.load import loads_path, loads
+sys.path.append("../")
+from src.loader.load import loads_path
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    VEC_DB_DBNAME,
...
@@ -18,24 +18,27 @@ from src.config.consts import (
from src.loader.callback import BaseCallback


# 当返回值中带有“思考题”字样的时候,默认将其忽略。
class localCallback(BaseCallback):
    def filter(self, title: str, content: str) -> bool:
        if len(title + content) == 0:
            return True
        return (len(title + content) / (len(title.splitlines()) + len(content.splitlines())) < 20) or "思考题" in title


"""测试资料入库(pgsql和faiss)"""


def test_faiss_from_dir():
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=True)
    docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[localCallback()])
    print(len(docs))
    last_doc = None
    docs1 = []
...
@@ -45,7 +48,8 @@ def test_faiss_from_dir():
            continue
        if "font-size" not in doc.metadata or "page_number" not in doc.metadata:
            continue
        if doc.metadata["font-size"] == last_doc.metadata["font-size"] and doc.metadata["page_number"] == \
                last_doc.metadata["page_number"] and len(doc.page_content) + len(last_doc.page_content) < 512 / 4 * 3:
            last_doc.page_content += doc.page_content
        else:
            docs1.append(last_doc)
...
@@ -56,22 +60,26 @@ def test_faiss_from_dir():
    print(len(docs))
    print(vecstore_faiss._faiss.index.ntotal)
    for i in range(0, len(docs), 300):
        vecstore_faiss._add_documents(docs[i:i + 300 if i + 300 < len(docs) else len(docs)], need_split=True)
        print(vecstore_faiss._faiss.index.ntotal)
    vecstore_faiss._save_local()


"""测试faiss向量数据库查询结果"""


def test_faiss_load():
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("征信业务有什么情况")))


if __name__ == "__main__":
    # test_faiss_from_dir()
    test_faiss_load()