文靖昊 / LAE · Commits
Commit 493cdd59, authored Apr 25, 2024 by 陈正乐

    代码格式化 (code formatting)

parent 9d8ee0af
Showing 28 changed files with 287 additions and 255 deletions (+287 −255)
src/config/consts.py                    +8   −8
src/llm/__init__.py                     +0   −0
src/llm/baichuan.py                     +10  −18
src/llm/chatglm.py                      +39  −36
src/llm/chatglm_openapi.py              +9   −8
src/llm/ernie.py                        +15  −21
src/llm/ernie_sdk.py                    +11  −4
src/llm/ernie_with_sdk.py               +14  −15
src/llm/loader.py                       +10  −9
src/llm/spark.py                        +18  −25
src/llm/wrapper.py                      +5   −9
src/loader/callback.py                  +2   −1
src/loader/chinese_text_splitter.py     +2   −2
src/loader/config.py                    +0   −0
src/loader/load.py                      +0   −0
src/loader/zh_title_enhance.py          +1   −1
src/pgdb/chat/c_db.py                   +6   −5
src/pgdb/chat/c_user_table.py           +3   −1
src/pgdb/chat/chat_table.py             +3   −1
src/pgdb/chat/turn_qa_table.py          +3   −1
src/pgdb/knowledge/callback.py          +18  −14
src/pgdb/knowledge/k_db.py              +7   −4
src/pgdb/knowledge/pgsqldocstore.py     +37  −28
src/pgdb/knowledge/similarity.py        +0   −0
src/pgdb/knowledge/txt_doc_table.py     +13  −10
src/pgdb/knowledge/vec_txt_table.py     +13  −8
test/chat_table_test.py                 +21  −15
test/k_store_test.py                    +19  −11
src/config/consts.py

@@ -2,19 +2,19 @@
# =============================
# 资料存储数据库配置
# =============================
VEC_DB_HOST = 'localhost'
VEC_DB_DBNAME = 'lae'
VEC_DB_USER = 'postgres'
VEC_DB_PASSWORD = 'chenzl'
VEC_DB_PORT = '5432'

# =============================
# 聊天相关数据库配置
# =============================
CHAT_DB_HOST = 'localhost'
CHAT_DB_DBNAME = 'laechat'
CHAT_DB_USER = 'postgres'
CHAT_DB_PASSWORD = 'chenzl'
CHAT_DB_PORT = '5432'

# =============================
# 向量化模型路径配置
...
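These constants are plain module-level strings. As a minimal sketch of the obvious consumer (not code from this commit; the wiring shown here is an assumption), they can be passed straight to psycopg2:

    # hypothetical consumer of src/config/consts.py
    import psycopg2
    from src.config.consts import (VEC_DB_HOST, VEC_DB_DBNAME, VEC_DB_USER,
                                   VEC_DB_PASSWORD, VEC_DB_PORT)

    # psycopg2 accepts the port as a string, so the '5432' above works as-is
    conn = psycopg2.connect(host=VEC_DB_HOST, dbname=VEC_DB_DBNAME, user=VEC_DB_USER,
                            password=VEC_DB_PASSWORD, port=VEC_DB_PORT)
    conn.close()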
src/llm/__init__.py (no content changes)
src/llm/baichuan.py

import os
from typing import Dict, Optional, List
from langchain.llms.base import BaseLLM, LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
from pydantic import root_validator


class BaichuanLLM(LLM):
    model_name: str = "baichuan-inc/Baichuan-13B-Chat"
    quantization_bit: Optional[int] = None
...
@@ -19,17 +17,16 @@ class BaichuanLLM(LLM):
    tokenizer: AutoTokenizer = None
    model: AutoModel = None

    def _llm_type(self) -> str:
        return "chatglm_local"

    @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   def validate_environment(self, values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")
        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
...
@@ -43,7 +40,7 @@ class BaichuanLLM(LLM):
            print(f"Quantized to {values['quantization_bit']} bit")
            model = model.quantize(values["quantization_bit"]).cuda()
        else:
            model = model.half().cuda()
        model = model.eval()
...
@@ -51,14 +48,9 @@ class BaichuanLLM(LLM):
        values["model"] = model
        return values

-   def _call(self, prompt: str,
-             stop: Optional[List[str]] = None,
-             run_manager: Optional[CallbackManagerForLLMRun] = None,
-             ) -> str:
-       message = []
-       message.append({"role": "user", "content": prompt})
-       resp = self.model.chat(self.tokenizer, message)
+   def _call(self, prompt: str,
+             stop: Optional[List[str]] = None,
+             run_manager: Optional[CallbackManagerForLLMRun] = None,
+             **kwargs) -> str:
+       message = [{"role": "user", "content": prompt}]
+       resp = self.model.chat(self.tokenizer, message)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp
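Because BaichuanLLM subclasses LangChain's LLM base class, the reworked _call (now tolerant of extra **kwargs) is reached through the base class's normal entry points. A minimal usage sketch, assuming the checkpoint is available locally and an older LangChain where an LLM instance is directly callable:

    # hypothetical usage; loading requires the actual Baichuan weights and a GPU
    llm = BaichuanLLM(model_name="baichuan-inc/Baichuan-13B-Chat", quantization_bit=8)
    print(llm("你好,请介绍一下你自己"))  # dispatched to _call by the LLM base class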
src/llm/chatglm.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import aiohttp
import asyncio

# 启动llm的缓存
# langchain.llm_cache = InMemoryCache()
...
@@ -26,17 +27,17 @@ class ChatGLMLocLLM(LLM):
    tokenizer: AutoTokenizer = None
    model: AutoModel = None

    def _llm_type(self) -> str:
        return "chatglm_local"

    # @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   @staticmethod
+   def validate_environment(values: Dict) -> Dict:
        if not values["model_name"]:
            raise ValueError("No model name provided.")
        model_name = values["model_name"]
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        # model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
        if values["pre_seq_len"]:
...
@@ -72,18 +73,14 @@ class ChatGLMLocLLM(LLM):
        values["model"] = model
        return values

-   def _call(self, prompt: str, stop: Optional[List[str]] = None,
-             run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
+   def _call(self, prompt: str, stop: Optional[List[str]] = None,
+             run_manager: Optional[CallbackManagerForLLMRun] = None,
+             **kwargs) -> str:
        resp, his = self.model.chat(self.tokenizer, prompt)
        # print(f"prompt:{prompt}\nresponse:{resp}\n")
        return resp


class ChatGLMSerLLM(LLM):
    # 模型服务url
    url: str = "http://127.0.0.1:8000"
    chat_history: dict = []
...
@@ -95,7 +92,7 @@ class ChatGLMSerLLM(LLM):
        return "chatglm3-6b"

    def get_num_tokens(self, text: str) -> int:
        resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
        if resp.status_code == 200:
            resp_json = resp.json()
            predictions = resp_json['response']
...
@@ -104,27 +101,28 @@ class ChatGLMSerLLM(LLM):
        else:
            return len(text)

-   def convert_data(self, data):
+   @staticmethod
+   def convert_data(data):
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
        """构造请求体
        """
        # self.chat_history.append({"role": "user", "content": prompt})
        query = {
            "prompt": prompt,
            "history": self.chat_history,
            "max_length": 4096,
            "top_p": 0.7,
            "temperature": temperature
        }
        return query

    @classmethod
-   def _post(self, url: str,
+   def _post(cls, url: str,
              query: Dict) -> Any:
        """POST请求
        """
...
@@ -135,46 +133,50 @@ class ChatGLMSerLLM(LLM):
                             headers=_headers, timeout=300)
        return resp

-   async def _post_stream(self, url: str,
+   @staticmethod
+   async def _post_stream(url: str,
                           query: Dict,
                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
                           stream=False) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}
        async with aiohttp.ClientSession() as sess:
            async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
                if response.status == 200:
                    if stream and not run_manager:
                        print('not callable')
                    if run_manager:
-                       for callable in run_manager.get_sync().handlers:
-                           await callable.on_llm_start(None, None)
+                       for _callable in run_manager.get_sync().handlers:
+                           await _callable.on_llm_start(None, None)
                    async for chunk in response.content.iter_any():
                        # 处理每个块的数据
                        if chunk and run_manager:
-                           for callable in run_manager.get_sync().handlers:
-                               # print(chunk.decode("utf-8"),end="")
-                               await callable.on_llm_new_token(chunk.decode("utf-8"))
+                           for _callable in run_manager.get_sync().handlers:
+                               # print(chunk.decode("utf-8"),end="")
+                               await _callable.on_llm_new_token(chunk.decode("utf-8"))
                    if run_manager:
-                       for callable in run_manager.get_sync().handlers:
-                           await callable.on_llm_end(None)
+                       for _callable in run_manager.get_sync().handlers:
+                           await _callable.on_llm_end(None)
                else:
                    raise ValueError(f'glm 请求异常,http code:{response.status}')

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
        # display("==============================")
        # display(query)
        # post
        if stream or self.out_stream:
            async def _post_stream():
                await self._post_stream(url=self.url + "/stream",
                                        query=query, run_manager=run_manager,
                                        stream=stream or self.out_stream)
            asyncio.run(_post_stream())
            return ''
        else:
...
@@ -197,9 +199,10 @@ class ChatGLMSerLLM(LLM):
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              **kwargs: Any,
              ) -> str:
        query = self._construct_query(prompt=prompt,
                                      temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
        await self._post_stream(url=self.url + "/stream", query=query,
                                run_manager=run_manager, stream=self.out_stream)
        return ''

    @property
...
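The streaming path fans each chunk out to every handler on the run manager via on_llm_new_token. A self-contained sketch of the same aiohttp pattern _post_stream uses (not code from this commit; the endpoint URL and payload shape are placeholders):

    # sketch: consuming a chunked HTTP stream the way _post_stream does
    import asyncio
    import aiohttp

    async def consume(url: str, payload: dict) -> None:
        async with aiohttp.ClientSession() as sess:
            async with sess.post(url, json=payload, timeout=300) as response:
                if response.status != 200:
                    raise ValueError(f"request failed, http code:{response.status}")
                async for chunk in response.content.iter_any():
                    # stand-in for forwarding to on_llm_new_token
                    print(chunk.decode("utf-8"), end="")

    asyncio.run(consume("http://127.0.0.1:8000/stream", {"prompt": "你好", "history": []}))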
src/llm/chatglm_openapi.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain.llms.base import BaseLLM, LLM
from langchain_openai import OpenAI
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun


class ChatGLMSerLLM(OpenAI):
    def get_token_ids(self, text: str) -> List[int]:
...
@@ -20,9 +21,9 @@ class ChatGLMSerLLM(OpenAI):
        ## 发起http请求,获取token_ids
        url = f"{self.openai_api_base}/num_tokens"
        query = {"prompt": text, "model": self.model_name}
        _headers = {"Content_Type": "application/json",
                    "Authorization": "chatglm " + self.openai_api_key}
        resp = self._post(url=url, query=query, headers=_headers)
        if resp.status_code == 200:
            resp_json = resp.json()
            print(resp_json)
...
@@ -32,8 +33,8 @@ class ChatGLMSerLLM(OpenAI):
        return [len(text)]

    @classmethod
-   def _post(self, url: str,
-             query: Dict, headers: Dict) -> Any:
+   def _post(cls, url: str,
+             query: Dict, headers: Dict) -> Any:
        """POST请求
        """
        _headers = {"Content_Type": "application/json"}
...
src/llm/ernie.py

@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
...
@@ -15,6 +15,7 @@ from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_m
logger = logging.getLogger(__name__)


class ModelType(Enum):
    ERNIE = "ernie"
    ERNIE_LITE = "ernie-lite"
...
@@ -25,7 +26,7 @@ class ModelType(Enum):
    LLAMA2_13B = "llama2-13b"
    LLAMA2_70B = "llama2-70b"
    QFCN_LLAMA2_7B = "qfcn-llama2-7b"
    BLOOMZ_7B = "bloomz-7b"


MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"
...
@@ -43,6 +44,7 @@ MODEL_SERVICE_Suffix = {
    ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}


class ErnieLLM(LLM):
    """
    ErnieLLM is a LLM that uses Ernie to generate text.
...
@@ -52,7 +54,7 @@ class ErnieLLM(LLM):
    access_token: Optional[str] = ""

    @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
        model_name = ModelType(get_from_dict_or_env(values, "model_name", "model_name", str(ModelType.ERNIE)))
...
@@ -65,14 +67,10 @@ class ErnieLLM(LLM):
        values["access_token"] = access_token
        return values

-   def _call(self, prompt: str,
-             stop: Optional[List[str]] = None,
-             run_manager: Optional[CallbackManagerForLLMRun] = None,
-             ) -> str:
+   def _call(self, prompt: str,
+             stop: Optional[List[str]] = None,
+             run_manager: Optional[CallbackManagerForLLMRun] = None,
+             **kwargs) -> str:
        request = CompletionRequest(messages=[Message("user", prompt)])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
        try:
            # 你的代码
...
@@ -81,10 +79,9 @@ class ErnieLLM(LLM):
            return response
        except Exception as e:
            # 处理异常
            print("exception:", e)
            return e.__str__()

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
...
@@ -95,9 +92,10 @@ class ErnieLLM(LLM):
    #     "name": "ernie",
    # }


def _get_model_service_url(model_name) -> str:
    # print("_get_model_service_url model_name: ",model_name)
    return MODEL_SERVICE_BASE_URL + MODEL_SERVICE_Suffix[model_name]


class ErnieChat(LLM):
...
@@ -106,16 +104,12 @@ class ErnieChat(LLM):
    prefix_messages: List = Field(default_factory=list)
    id: str = ""

-   def _call(self, prompt: str,
-             stop: Optional[List[str]] = None,
-             run_manager: Optional[CallbackManagerForLLMRun] = None,
-             ) -> str:
+   def _call(self, prompt: str,
+             stop: Optional[List[str]] = None,
+             run_manager: Optional[CallbackManagerForLLMRun] = None,
+             **kwargs) -> str:
        msg = user_message(prompt)
        request = CompletionRequest(messages=self.prefix_messages + [msg])
        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
        try:
            # 你的代码
            response = bot.get_response().result
...
src/llm/ernie_sdk.py

from dataclasses import asdict, dataclass
from typing import List
...
@@ -7,27 +5,32 @@ from pydantic import BaseModel, Field
from enum import Enum


class MessageRole(str, Enum):
    USER = "user"
    BOT = "assistant"


@dataclass
class Message:
    role: str
    content: str


@dataclass
class CompletionRequest:
    messages: List[Message]
    stream: bool = False
    user: str = ""


@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class CompletionResponse:
    id: str
...
@@ -42,12 +45,14 @@ class CompletionResponse:
    is_safe: bool = False
    is_truncated: bool = False


class ErrorResponse(BaseModel):
    error_code: int = Field(...)
    error_msg: str = Field(...)
    id: str = Field(...)


-class ErnieBot():
+class ErnieBot:
    url: str
    access_token: str
    request: CompletionRequest
...
@@ -65,7 +70,7 @@ class ErnieBot():
        headers = {'Content-Type': 'application/json'}
        params = {'access_token': self.access_token}
        request_dict = asdict(self.request)
        response = requests.post(self.url, params=params, data=json.dumps(request_dict), headers=headers)
        # print(response.json())
        try:
            return CompletionResponse(**response.json())
...
@@ -73,8 +78,10 @@ class ErnieBot():
            print(e)
            raise Exception(response.json())


def user_message(prompt: str) -> Message:
    return Message(MessageRole.USER, prompt)


def bot_message(prompt: str) -> Message:
    return Message(MessageRole.BOT, prompt)
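Because CompletionRequest and Message are plain dataclasses, the asdict call in ErnieBot recurses through the nested messages before the JSON POST. A quick sketch of the payload this produces, using the module's own helpers (the printed result is what standard-library behavior gives; MessageRole serializes as its string value because it subclasses str):

    # sketch: the JSON body built from the dataclasses above
    import json
    from dataclasses import asdict

    req = CompletionRequest(messages=[user_message("你好")])
    print(json.dumps(asdict(req), ensure_ascii=False))
    # {"messages": [{"role": "user", "content": "你好"}], "stream": false, "user": ""}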
src/llm/ernie_with_sdk.py

import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import qianfan
from qianfan import ChatCompletion

# 启动llm的缓存
# langchain.llm_cache = InMemoryCache()


class ChatERNIESerLLM(LLM):
    # 模型服务url
    chat_completion: ChatCompletion = None
    # url: str = "http://127.0.0.1:8000"
    chat_history: dict = []
    out_stream: bool = False
    cache: bool = False
    model_name: str = "ERNIE-Bot"

    # def __init__(self):
    #     self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")
...
@@ -32,20 +33,19 @@ class ChatERNIESerLLM(LLM):
    def get_num_tokens(self, text: str) -> int:
        return len(text)

-   def convert_data(self, data):
+   @staticmethod
+   def convert_data(data):
        result = []
        for item in data:
            result.append({'q': item[0], 'a': item[1]})
        return result

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
              stream=False,
              **kwargs: Any) -> str:
        resp = self.chat_completion.do(model=self.model_name, messages=[{
            "role": "user", "content": prompt
        }])
...
@@ -59,11 +59,12 @@ class ChatERNIESerLLM(LLM):
              stream=False) -> Any:
        """POST请求
        """
        async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
            assert r.code == 200
            if run_manager:
-               for callable in run_manager.get_sync().handlers:
-                   await callable.on_llm_new_token(r.body["result"])
+               for _callable in run_manager.get_sync().handlers:
+                   await _callable.on_llm_new_token(r.body["result"])

    async def _acall(self, prompt: str,
...
@@ -74,6 +75,5 @@ class ChatERNIESerLLM(LLM):
        await self._post_stream(query={"role": "user", "content": prompt},
                                stream=True, run_manager=run_manager)
        return ''
\ No newline at end of file
src/llm/loader.py

@@ -4,6 +4,7 @@ import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
from peft import PeftModel


class ModelLoader:
    def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
        self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
...
@@ -27,14 +28,15 @@ class ModelLoader:
    def collator(self):
        return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)

-   def load_lora(self, ckpt_path, name="default"):
-       #训练时节约GPU占用
-       peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
-       self.model = peft_loaded.merge_and_unload()
+   def load_lora(self, ckpt_path, name="default"):
+       # 训练时节约GPU占用
+       _peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
+       self.model = _peft_loaded.merge_and_unload()
        print(f"Load LoRA model successfully!")

-   def load_loras(self, ckpt_paths, name="default"):
-       if len(ckpt_paths) == 0:
+   def load_loras(self, ckpt_paths, name="default"):
+       global peft_loaded
+       if len(ckpt_paths) == 0:
            return
        first = True
        for name, path in ckpt_paths.items():
...
@@ -43,11 +45,11 @@ class ModelLoader:
                peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=name)
                first = False
            else:
                peft_loaded.load_adapter(path, adapter_name=name)
        peft_loaded.set_adapter(name)
        self.model = peft_loaded

    def load_prefix(self, ckpt_path):
        prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
...
@@ -56,4 +58,3 @@ class ModelLoader:
        self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        self.model.transformer.prefix_encoder.float()
        print(f"Load prefix model successfully!")
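A usage sketch for the LoRA helpers (the model name and checkpoint paths are placeholders, not from this commit; merge_and_unload folds a single adapter back into the base weights, while load_loras keeps several named adapters switchable via set_adapter):

    # hypothetical usage of ModelLoader's adapter loading
    loader = ModelLoader("THUDM/chatglm3-6b")
    loader.load_lora("./output/lora-ckpt")              # one adapter, merged into the base model
    loader.load_loras({"qa": "./output/qa-ckpt",
                       "summary": "./output/sum-ckpt"})  # named adapters; the last one becomes active
    model, tokenizer = loader.model, loader.tokenizer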
src/llm/spark.py

@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
...
@@ -16,18 +16,17 @@ from .xinghuo.ws import SparkAPI
logger = logging.getLogger(__name__)

text = []


# length = 0
-def getText(role, content):
-    jsoncon = {}
-    jsoncon["role"] = role
-    jsoncon["content"] = content
+def getText(role, content):
+    jsoncon = {"role": role, "content": content}
    text.append(jsoncon)
    return text


def getlength(text):
    length = 0
    for content in text:
...
@@ -36,10 +35,11 @@ def getlength(text):
        length += leng
    return length


-def checklen(text):
-    while (getlength(text) > 8000):
-        del text[0]
-    return text
+def checklen(_text):
+    while getlength(_text) > 8000:
+        del _text[0]
+    return _text


class SparkLLM(LLM):
...
@@ -68,7 +68,7 @@ class SparkLLM(LLM):
    )

    @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
...
@@ -89,18 +89,14 @@ class SparkLLM(LLM):
        values["api_key"] = api_key
        values["api_secret"] = api_secret
        api = SparkAPI(appid, api_key, api_secret, version)
        values["api"] = api
        return values

-   def _call(self, prompt: str,
-             stop: Optional[List[str]] = None,
-             run_manager: Optional[CallbackManagerForLLMRun] = None,
-             ) -> str:
+   def _call(self, prompt: str,
+             stop: Optional[List[str]] = None,
+             run_manager: Optional[CallbackManagerForLLMRun] = None,
+             **kwargs) -> str:
        question = self.getText("user", prompt)
        try:
            # 你的代码
            # SparkApi.main(self.appid,self.api_key,self.api_secret,self.Spark_url,self.domain,question)
...
@@ -109,10 +105,10 @@ class SparkLLM(LLM):
            return response
        except Exception as e:
            # 处理异常
            print("exception:", e)
            raise e

    def getText(self, role, content):
        text = []
        jsoncon = {}
        jsoncon["role"] = role
...
@@ -124,5 +120,3 @@ class SparkLLM(LLM):
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "xinghuo"
\ No newline at end of file
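The module-level helpers keep a shared history list, and checklen drops the oldest turns until the total content length fits the budget. A standalone sketch of the same trimming rule (assuming, as the elided getlength body suggests, that length is counted over each message's "content" field):

    # sketch: the checklen trimming rule on a synthetic history
    history = [{"role": "user", "content": "x" * 5000},
               {"role": "assistant", "content": "y" * 4000}]

    def total_chars(msgs):
        return sum(len(m["content"]) for m in msgs)

    while total_chars(history) > 8000:  # same 8000-character budget as checklen
        del history[0]

    print(len(history))  # 1 -- the oldest message was dropped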
src/llm/wrapper.py

from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from typing import Dict, List, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer


class WrapperLLM(LLM):
    tokenizer: PreTrainedTokenizer = None
    model: PreTrainedModel = None

    @root_validator()
-   def validate_environment(cls, values: Dict) -> Dict:
+   def validate_environment(self, values: Dict) -> Dict:
        """Validate the environment."""
        # print(values)
        if values.get("model") is None:
...
@@ -19,13 +19,9 @@ class WrapperLLM(LLM):
            raise ValueError("No tokenizer provided.")
        return values

-   def _call(self, prompt: str,
-             stop: Optional[List[str]] = None,
-             run_manager: Optional[CallbackManagerForLLMRun] = None,
-             ) -> str:
+   def _call(self, prompt: str,
+             stop: Optional[List[str]] = None,
+             run_manager: Optional[CallbackManagerForLLMRun] = None,
+             **kwargs) -> str:
        resp, his = self.model.chat(self.tokenizer, prompt)
        return resp

    @property
...
src/loader/callback.py

from abc import ABC, abstractmethod


class BaseCallback(ABC):
    @abstractmethod
    def filter(self, title: str, content: str) -> bool:
        # return True舍弃当前段落
        pass
src/loader/chinese_text_splitter.py

@@ -56,6 +56,6 @@ class ChineseTextSplitter(CharacterTextSplitter):
                    ele_id = ele1_ls.index(ele_ele1)
                    ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
-           id = ls.index(ele)
-           ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
+           _id = ls.index(ele)
+           ls = ls[:_id] + [i for i in ele1_ls if i] + ls[_id + 1:]
        return ls
src/loader/config.py (no content changes)
src/loader/load.py (diff collapsed)
src/loader/zh_title_enhance.py

@@ -33,7 +33,7 @@ def is_possible_title(
        title_max_word_length: int = 20,
        non_alpha_threshold: float = 0.5,
) -> bool:
-   """Checks to see if the text passes all of the checks for a valid title.
+   """Checks to see if the text passes all the checks for a valid title.

    Parameters
    ----------
...
src/pgdb/chat/c_db.py

import psycopg2
from psycopg2 import OperationalError, InterfaceError


class UPostgresDB:
-   '''
+   """
    psycopg2.connect(
        dsn       # 指定连接参数。可以使用参数形式或 DSN 形式指定。
        host      # 指定连接数据库的主机名。
...
@@ -18,8 +19,9 @@ class UPostgresDB:
        sslkey    # 指定私钥文件名。
        sslcert   # 指定公钥文件名。
    )
-   '''
+   """

    def __init__(self, host, database, user, password, port=5432):
        self.host = host
        self.database = database
        self.user = user
...
@@ -35,7 +37,7 @@ class UPostgresDB:
                database=self.database,
                user=self.user,
                password=self.password,
                port=self.port
            )
            self.cur = self.conn.cursor()
        except Exception as e:
...
@@ -89,7 +91,6 @@ class UPostgresDB:
            print(f"重新执行sql语句再次出现错误: {type(e).__name__}: {e}")
            self.conn.rollback()

    def search(self, query, params=None):
        if self.conn is None or self.conn.closed:
            self.connect()
...
src/pgdb/chat/c_user_table.py

from .c_db import UPostgresDB
import json

TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
...
@@ -13,13 +14,14 @@ COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
"""


class CUser:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db

    def insert(self, value):
        query = f"INSERT INTO c_user(user_id, account, password) VALUES (%s,%s,%s)"
-       self.db.execute_args(query, ((value[0], value[1], value[2])))
+       self.db.execute_args(query, (value[0], value[1], value[2]))

    def create_table(self):
        query = TABLE_USER
...
src/pgdb/chat/chat_table.py

from .c_db import UPostgresDB
import json

TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
...
@@ -17,6 +18,7 @@ COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
"""


class Chat:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db
...
@@ -24,7 +26,7 @@ class Chat:
    # 插入数据
    def insert(self, value):
        query = f"INSERT INTO chat(chat_id, user_id, info, deleted) VALUES (%s,%s,%s,%s)"
-       self.db.execute_args(query, ((value[0], value[1], value[2], value[3])))
+       self.db.execute_args(query, (value[0], value[1], value[2], value[3]))

    # 创建表
    def create_table(self):
...
src/pgdb/chat/turn_qa_table.py

from .c_db import UPostgresDB
import json

TABLE_CHAT = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
...
@@ -21,6 +22,7 @@ COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
"""


class TurnQa:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db
...
@@ -28,7 +30,7 @@ class TurnQa:
    # 插入数据
    def insert(self, value):
        query = f"INSERT INTO turn_qa(turn_id, chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s,%s)"
-       self.db.execute_args(query, ((value[0], value[1], value[2], value[3], value[4], value[5])))
+       self.db.execute_args(query, (value[0], value[1], value[2], value[3], value[4], value[5]))

    # 创建表
    def create_table(self):
...
src/pgdb/knowledge/callback.py

@@ -4,22 +4,24 @@ from os import path
sys.path.append("../")
from abc import ABC, abstractmethod
import json
from typing import List, Any, Tuple, Dict
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64


class DocumentCallback(ABC):
    @abstractmethod
    # 向量库储存前文档处理--
    def before_store(self, docstore: PgSqlDocstore, documents):
        pass

    @abstractmethod
    # 向量库查询后文档处理--用于结构建立
    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:
        # 向量库查询后文档处理
        pass


class DefaultDocumentCallback(DocumentCallback):
    def before_store(self, docstore: PgSqlDocstore, documents):
        output_doc = []
        for doc in documents:
            if "next_doc" in doc.metadata:
...
@@ -27,22 +29,24 @@ class DefaultDocumentCallback(DocumentCallback):
                doc.metadata.pop("next_doc")
            output_doc.append(doc)
        return output_doc

    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:
        # 向量库查询后文档处理
        output_doc: List[Tuple[Document, float]] = []
        exist_hash = []
        for doc, score in documents:
            print(exist_hash)
            dochash = str2hash_base64(doc.page_content)
            if dochash in exist_hash:
                continue
            else:
                exist_hash.append(dochash)
            output_doc.append((doc, score))
            if len(output_doc) > number:
                return output_doc
            fordoc = doc
-           while ("next_hash" in fordoc.metadata):
-               if len(fordoc.metadata["next_hash"]) > 0:
+           while "next_hash" in fordoc.metadata:
+               if len(fordoc.metadata["next_hash"]) > 0:
                    if fordoc.metadata["next_hash"] in exist_hash:
                        break
                    else:
...
@@ -50,7 +54,7 @@ class DefaultDocumentCallback(DocumentCallback):
                        content = docstore.TXT_DOC.search(fordoc.metadata["next_hash"])
                        if content:
                            fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
                            output_doc.append((fordoc, score))
                            if len(output_doc) > number:
                                return output_doc
                        else:
...
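after_search deduplicates hits by content hash and then follows next_hash links to pull in the continuation paragraphs of each hit. A toy sketch of that chain walk (not from the commit; a plain dict stands in for the TXT_DOC table lookup):

    # sketch: following next_hash links, with a dict standing in for docstore.TXT_DOC
    from langchain.schema import Document

    store = {
        "h2": ("second part", {"next_hash": "h3"}),
        "h3": ("third part", {}),
    }

    doc = Document(page_content="first part", metadata={"next_hash": "h2"})
    chain = [doc]
    while "next_hash" in chain[-1].metadata and chain[-1].metadata["next_hash"] in store:
        text, meta = store[chain[-1].metadata["next_hash"]]
        chain.append(Document(page_content=text, metadata=meta))

    print(" / ".join(d.page_content for d in chain))  # first part / second part / third part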
src/pgdb/knowledge/k_db.py

import psycopg2


class PostgresDB:
-   '''
+   """
    psycopg2.connect(
        dsn       # 指定连接参数。可以使用参数形式或 DSN 形式指定。
        host      # 指定连接数据库的主机名。
...
@@ -17,8 +18,9 @@ class PostgresDB:
        sslkey    # 指定私钥文件名。
        sslcert   # 指定公钥文件名。
    )
-   '''
+   """

    def __init__(self, host, database, user, password, port=5432):
        self.host = host
        self.database = database
        self.user = user
...
@@ -33,7 +35,7 @@ class PostgresDB:
                database=self.database,
                user=self.user,
                password=self.password,
                port=self.port
            )
            self.cur = self.conn.cursor()
...
@@ -44,6 +46,7 @@ class PostgresDB:
        except Exception as e:
            print(f"An error occurred: {e}")
            self.conn.rollback()

    def execute_args(self, query, args):
        try:
            self.cur.execute(query, args)
...
src/pgdb/knowledge/pgsqldocstore.py

import sys
from os import path

# 这里相当于把当前目录添加到pythonpath中
sys.path.append(path.dirname(path.abspath(__file__)))
from typing import List, Union, Dict, Optional
from langchain.docstore.base import AddableMixin, Docstore
from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector
import json, hashlib, base64
from langchain.schema import Document


-def str2hash_base64(input: str) -> str:
-    # return f"%s" % hash(input)
-    return base64.b64encode(hashlib.sha1(input.encode()).digest()).decode()
+def str2hash_base64(inp: str) -> str:
+    # return f"%s" % hash(input)
+    return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()


class PgSqlDocstore(Docstore, AddableMixin):
    host: str
    dbname: str
    username: str
    password: str
    port: str
    '''
    说明,重写__getstate__,__setstate__,适用于langchain的序列化存储,基于pickle进行存储。返回数组包含pgsql连接信息。
    '''

    def __getstate__(self):
        return {"host": self.host, "dbname": self.dbname, "username": self.username,
                "password": self.password, "port": self.port}

    def __setstate__(self, info):
        self.__init__(info)

    def __init__(self, info: dict, reset: bool = False):
        self.host = info["host"]
        self.dbname = info["dbname"]
        self.username = info["username"]
        self.password = info["password"]
-       self.port = info["port"] if "port" in info else "5432";
+       self.port = info["port"] if "port" in info else "5432"
        self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password, port=self.port)
        self.TXT_DOC = TxtDoc(self.pgdb)
        self.VEC_TXT = TxtVector(self.pgdb)
        if reset:
...
@@ -48,12 +50,15 @@ class PgSqlDocstore(Docstore, AddableMixin):
            self.VEC_TXT.drop_table()
            self.TXT_DOC.create_table()
            self.VEC_TXT.create_table()

    def __sub_init__(self):
        if not self.pgdb.conn:
            self.pgdb.connect()

    '''
    从本地库中查找向量对应的文本段落,封装成Document返回
    '''

    def search(self, search: str) -> Union[str, Document]:
        if not self.pgdb.conn:
            self.__sub_init__()
...
@@ -63,40 +68,44 @@ class PgSqlDocstore(Docstore, AddableMixin):
            return Document(page_content=content[0], metadata=json.loads(content[1]))
        else:
            return Document()

    '''
    从本地库中删除向量对应的文本,批量删除
    '''

    def delete(self, ids: List) -> None:
        if not self.pgdb.conn:
            self.__sub_init__()
        pids = []
-       for id in ids:
-           anwser = self.VEC_TXT.search(id)
+       for item in ids:
+           anwser = self.VEC_TXT.search(item)
            pids.append(anwser[0])
        self.VEC_TXT.delete(ids)
        self.TXT_DOC.delete(pids)

    '''
    向本地库添加向量和文本信息
    [vector_id,Document(page_content=问题, metadata=dict(paragraph=段落文本))]
    '''

    def add(self, texts: Dict[str, Document]) -> None:
        # for vec,doc in texts.items():
        #     paragraph_id = self.TXT_DOC.insert(doc.metadata["paragraph"])
        #     self.VEC_TXT.insert(vector_id=vec,paragraph_id=paragraph_id,text=doc.page_content)
        if not self.pgdb.conn:
            self.__sub_init__()
        paragraph_hashs = []  # hash,text
        paragraph_txts = []
        vec_inserts = []
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            print(txt_hash)
            vec_inserts.append((vec, doc.page_content, txt_hash))
            if txt_hash not in paragraph_hashs:
                paragraph_hashs.append(txt_hash)
                paragraph = doc.metadata["paragraph"]
                doc.metadata.pop("paragraph")
                paragraph_txts.append((txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False)))
        # print(paragraph_txts)
        self.TXT_DOC.insert(paragraph_txts)
        self.VEC_TXT.insert(vec_inserts)
...
@@ -105,7 +114,7 @@ class PgSqlDocstore(Docstore, AddableMixin):

class InMemorySecondaryDocstore(Docstore, AddableMixin):
    """Simple in memory docstore in the form of a dict."""

    def __init__(self, _dict: Optional[Dict[str, Document]] = None,
                 _sec_dict: Optional[Dict[str, Document]] = None):
        """Initialize with dict."""
        self._dict = _dict if _dict is not None else {}
        self._sec_dict = _sec_dict if _sec_dict is not None else {}
...
@@ -126,14 +135,14 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
        dict1 = {}
        dict_sec = {}
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            metadata = doc.metadata
            paragraph = metadata.pop('paragraph')
            # metadata.update({"paragraph_id":txt_hash})
            metadata['paragraph_id'] = txt_hash
            dict_sec[txt_hash] = Document(page_content=paragraph, metadata=metadata)
            dict1[vec] = Document(page_content=doc.page_content, metadata={'paragraph_id': txt_hash})
        self._dict = {**self._dict, **dict1}
        self._sec_dict = {**self._sec_dict, **dict_sec}
...
@@ -143,7 +152,7 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
        if not overlapping:
            raise ValueError(f"Tried to delete ids that does not exist: {ids}")
        for _id in ids:
-           self._sec_dict.pop(self._dict[id].metadata['paragraph_id'])
+           self._sec_dict.pop(self._dict[_id].metadata['paragraph_id'])
            self._dict.pop(_id)

    def search(self, search: str) -> Union[str, Document]:
...
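str2hash_base64 keys both tables: paragraphs are addressed by the base64-encoded SHA-1 of their text, which is what lets add() store a shared paragraph once while many question vectors point at it. A self-contained sketch of the function in isolation:

    # sketch: the content hash used as the txt_doc primary key
    import base64
    import hashlib

    def str2hash_base64(inp: str) -> str:
        return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()

    print(str2hash_base64("同一段落"))  # a 28-char digest; identical text always maps to the same key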
src/pgdb/knowledge/similarity.py (diff collapsed)
src/pgdb/knowledge/txt_doc_table.py

from .k_db import PostgresDB

# paragraph_id BIGSERIAL primary key,
TABLE_TXT_DOC = """
create table txt_doc (
    hash varchar(40) primary key,
...
@@ -11,6 +12,8 @@ TABLE_TXT_DOC_HASH_INDEX = """
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
"""


# CREATE UNIQUE INDEX idx_name ON your_table (column_name);


class TxtDoc:
    def __init__(self, db: PostgresDB) -> None:
...
@@ -21,19 +24,20 @@ class TxtDoc:
        args = []
        for value in texts:
            value = list(value)
            query += "(%s,%s,%s),"
            args.extend(value)
        query = query[:len(query) - 1]
        query += f"ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
        self.db.execute_args(query, args)

-   def delete(self, ids):
-       for id in ids:
-           query = f"delete FROM txt_doc WHERE hash = %s" % (id)
-           self.db.execute(query)
+   def delete(self, ids):
+       for item in ids:
+           query = f"delete FROM txt_doc WHERE hash = %s" % item
+           self.db.execute(query)

-   def search(self, id):
+   def search(self, item):
        query = "SELECT text,matadate FROM txt_doc WHERE hash = %s"
-       self.db.execute_args(query, [id])
+       self.db.execute_args(query, [item])
        answer = self.db.fetchall()
        if len(answer) > 0:
            return answer[0]
...
@@ -60,4 +64,3 @@ class TxtDoc:
        query = "DROP TABLE txt_doc"
        self.db.format(query)
        print("drop table txt_doc ok")
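insert builds one multi-row VALUES statement and leans on PostgreSQL's ON CONFLICT upsert. A sketch of the SQL shape it generates for two rows (the column list lives in the elided prefix above, so the prefix here is an assumption):

    # sketch: the statement shape insert() sends for two rows (INSERT prefix assumed)
    rows = [("hash1", "text1", "{}"), ("hash2", "text2", "{}")]
    query = "INSERT INTO txt_doc VALUES "  # assumed prefix; the real one is elided above
    args = []
    for value in rows:
        query += "(%s,%s,%s),"
        args.extend(value)
    query = query[:-1] + " ON CONFLICT(hash) DO UPDATE SET text = EXCLUDED.text;"
    print(query)
    # INSERT INTO txt_doc VALUES (%s,%s,%s),(%s,%s,%s) ON CONFLICT(hash) DO UPDATE SET text = EXCLUDED.text;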
src/pgdb/knowledge/vec_txt_table.py

from .k_db import PostgresDB

TABLE_VEC_TXT = """
CREATE TABLE vec_txt (
    vector_id varchar(36) PRIMARY KEY,
...
@@ -6,7 +7,9 @@ CREATE TABLE vec_txt (
    paragraph_id varchar(40) not null
)
"""


# 025a9bee-2eb2-47f5-9722-525e05a0442b


class TxtVector:
    def __init__(self, db: PostgresDB) -> None:
        self.db = db
...
@@ -16,19 +19,21 @@ class TxtVector:
        args = []
        for value in vectors:
            value = list(value)
            query += "(%s,%s,%s),"
            args.extend(value)
        query = query[:len(query) - 1]
        query += f"ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text,paragraph_id = EXCLUDED.paragraph_id;"
        # query += ";"
        self.db.execute_args(query, args)

-   def delete(self, ids):
-       for id in ids:
-           query = f"delete FROM vec_txt WHERE vector_id = '%s'" % (id,)
-           self.db.execute_args(query, args)
+   def delete(self, ids):
+       for item in ids:
+           query = f"delete FROM vec_txt WHERE vector_id = '%s'" % (item,)
+           self.db.execute(query)

    def search(self, search: str):
        query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
        self.db.execute_args(query, [search])
        answer = self.db.fetchall()
        print(answer)
        return answer[0]
...
test/chat_table_test.py

import sys

sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa

"""测试会话相关数据库的连接"""


def test():
    c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
    chat = Chat(db=c_db)
    c_user = CUser(db=c_db)
    turn_qa = TurnQa(db=c_db)
    chat.create_table()
    c_user.create_table()
    turn_qa.create_table()
    # chat_id, user_id, info, deleted
    chat.insert(["3333", "1111", "没有info", 0])
    # user_id, account, password
    c_user.insert(["111", "zhangsan", "111111"])
    # turn_id, chat_id, question, answer, turn_number, is_last
    turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])


if __name__ == "__main__":
    test()
test/k_store_test.py

import sys

sys.path.append("../")
import time
-from src.loader.load import loads_path, loads
+from src.loader.load import loads_path
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    VEC_DB_DBNAME,
...
@@ -18,24 +18,27 @@ from src.config.consts import (
from src.loader.callback import BaseCallback

# 当返回值中带有“思考题”字样的时候,默认将其忽略。


class localCallback(BaseCallback):
    def filter(self, title: str, content: str) -> bool:
        if len(title + content) == 0:
            return True
        return (len(title + content) / (len(title.splitlines()) + len(content.splitlines())) < 20) or "思考题" in title


"""测试资料入库(pgsql和faiss)"""


def test_faiss_from_dir():
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME,
              "username": VEC_DB_USER, "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=True)
    docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[localCallback()])
    print(len(docs))
    last_doc = None
    docs1 = []
...
@@ -45,7 +48,8 @@ def test_faiss_from_dir():
            continue
        if "font-size" not in doc.metadata or "page_number" not in doc.metadata:
            continue
        if doc.metadata["font-size"] == last_doc.metadata["font-size"] and doc.metadata["page_number"] == \
                last_doc.metadata["page_number"] and len(doc.page_content) + len(last_doc.page_content) < 512 / 4 * 3:
            last_doc.page_content += doc.page_content
        else:
            docs1.append(last_doc)
...
@@ -56,17 +60,21 @@ def test_faiss_from_dir():
    print(len(docs))
    print(vecstore_faiss._faiss.index.ntotal)
    for i in range(0, len(docs), 300):
        vecstore_faiss._add_documents(docs[i:i + 300 if i + 300 < len(docs) else len(docs)], need_split=True)
    print(vecstore_faiss._faiss.index.ntotal)
    vecstore_faiss._save_local()


"""测试faiss向量数据库查询结果"""


def test_faiss_load():
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME,
              "username": VEC_DB_USER, "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("征信业务有什么情况")))
...