Commit bfd1cd58 by 陈正乐

Build the knowledge base and vector store, and implement document ingestion

parent 669c9c76
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
dist/
eggs/
*.egg-info/
bin/
include/
lib/
local/
man/
share/
pip-wheel-metadata/
htmlcov/
.coverage
.tox/
.pytest_cache/
pytest.ini
# PyCharm
.idea/
# VSCode
.vscode/
# Jupyter Notebook
.ipynb_checkpoints
# Django
*.log
*.pot
*.pyc
local_settings.py
db.sqlite3
db.sqlite3-journal
media
# Flask
instance/
.webassets-cache
# Sphinx documentation
docs/_build/
model/
data/
exam/
aaa/
bbb/
ccc/
.env
faiss/
bitsandbytes==0.41.1
cpm-kernels==1.0.11
fastapi==0.100.0
Flask==2.3.2
jieba==0.42.1
langchain==0.0.278
peft==0.4.0
psycopg2==2.9.7
pydantic==1.10.12
requests==2.31.0
sentence-transformers==2.2.2
torch==2.0.1
transformers==4.31.0
uvicorn==0.23.1
unstructured==0.8.1
bs4==0.0.1
mdtex2html==1.2.0
faiss-gpu==1.7.2 # https://github.com/facebookresearch/faiss/blob/main/INSTALL.md
"""配置相关"""
VEC_DB_HOST = 'localhost'
VEC_DB_DBNAME='lae'
VEC_DB_USER='postgres'
VEC_DB_PASSWORD='chenzl'
VEC_DB_PORT='5432'
EMBEEDING_MODEL_PATH = 'C:\\Users\\15663\\AI\\models\\bge-large-zh-v1.5'
LLM_SERVER_URL = '192.168.10.102:8002'
SIMILARITY_SHOW_NUMBER = 5
SIMILARITY_THRESHOLD = 0.8
FAISS_STORE_PATH = '../faiss'
KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\work\\llm_gjjs\\兴火燎原知识库\\兴火燎原知识库\\law\\pdf'
INDEX_NAME = 'know'
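
if __name__ == "__main__":
    # Connectivity smoke test: a minimal sketch that assumes the PostgreSQL
    # instance described by the constants above is reachable (psycopg2 is
    # pinned in requirements.txt).
    import psycopg2
    conn = psycopg2.connect(host=VEC_DB_HOST, dbname=VEC_DB_DBNAME, user=VEC_DB_USER,
                            password=VEC_DB_PASSWORD, port=VEC_DB_PORT)
    conn.close()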
"""各种大模型提供的服务"""
import os
from typing import Dict, Optional,List
from langchain.llms.base import BaseLLM,LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig,AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
from pydantic import root_validator
class BaichuanLLM(LLM):
model_name: str = "baichuan-inc/Baichuan-13B-Chat"
quantization_bit: Optional[int] = None
tokenizer: AutoTokenizer = None
model: AutoModel = None
    @property
    def _llm_type(self) -> str:
        return "baichuan_local"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
if not values["model_name"]:
raise ValueError("No model name provided.")
model_name = values["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=False,trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
# device_map="auto",
trust_remote_code=True
)
model.generation_config = GenerationConfig.from_pretrained(
model_name
)
if values["quantization_bit"]:
print(f"Quantized to {values['quantization_bit']} bit")
model = model.quantize(values["quantization_bit"]).cuda()
else:
model=model.half().cuda()
model = model.eval()
values["tokenizer"] = tokenizer
values["model"] = model
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
message = []
message.append({"role": "user", "content": prompt})
resp = self.model.chat(self.tokenizer,message)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return resp
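
# Usage sketch: a minimal example, assuming a CUDA GPU and that the
# Baichuan-13B-Chat weights are available locally or from the Hugging Face
# hub; quantization_bit=8 is optional and trades accuracy for memory.
if __name__ == "__main__":
    llm = BaichuanLLM(model_name="baichuan-inc/Baichuan-13B-Chat", quantization_bit=8)
    print(llm("你好"))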
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig
import langchain
from langchain.llms.base import BaseLLM,LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import aiohttp
import asyncio
# Enable the LLM cache:
# langchain.llm_cache = InMemoryCache()
class ChatGLMLocLLM(LLM):
model_name: str = "THUDM/chatglm-6b"
ptuning_checkpoint: str = None
quantization_bit: Optional[int] = None
pre_seq_len: Optional[int] = None
prefix_projection: bool = False
tokenizer: AutoTokenizer = None
model: AutoModel = None
    @property
    def _llm_type(self) -> str:
        return "chatglm_local"
# @root_validator()
def validate_environment(cls, values: Dict) -> Dict:
if not values["model_name"]:
raise ValueError("No model name provided.")
model_name = values["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name ,trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
if values["pre_seq_len"]:
config.pre_seq_len = values["pre_seq_len"]
if values["prefix_projection"]:
config.prefix_projection = values["prefix_projection"]
if values["ptuning_checkpoint"]:
ptuning_checkpoint = values["ptuning_checkpoint"]
print(f"Loading prefix_encoder weight from {ptuning_checkpoint}")
model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(ptuning_checkpoint, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()
if values["pre_seq_len"]:
# P-tuning v2
model = model.half().cuda()
model.transformer.prefix_encoder.float().cuda()
if values["quantization_bit"]:
print(f"Quantized to {values['quantization_bit']} bit")
model = model.quantize(values["quantization_bit"])
model = model.eval()
values["tokenizer"] = tokenizer
values["model"] = model
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
resp,his = self.model.chat(self.tokenizer,prompt)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return resp
class ChatGLMSerLLM(LLM):
    # URL of the model service
url: str = "http://127.0.0.1:8000"
chat_history: dict = []
out_stream: bool = False
cache: bool = False
@property
def _llm_type(self) -> str:
return "chatglm3-6b"
def get_num_tokens(self, text: str) -> int:
resp = self._post(url=self.url+"/tokens",query=self._construct_query(text))
if resp.status_code == 200:
resp_json = resp.json()
predictions = resp_json['response']
# display(self.convert_data(resp_json['history']))
return predictions
else:
return len(text)
def convert_data(self,data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
    def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
        """Build the request body."""
# self.chat_history.append({"role": "user", "content": prompt})
query = {
"prompt": prompt,
"history":self.chat_history,
"max_length": 4096,
"top_p": 0.7,
"temperature": temperature
}
return query
    @classmethod
    def _post(cls, url: str,
              query: Dict) -> Any:
        """Send a POST request."""
        _headers = {"Content-Type": "application/json"}
with requests.session() as sess:
resp = sess.post(url,
json=query,
headers=_headers,
timeout=300)
return resp
async def _post_stream(self, url: str,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,stream=False) -> Any:
"""POST请求
"""
_headers = {"Content_Type": "application/json"}
async with aiohttp.ClientSession() as sess:
async with sess.post(url, json=query,headers=_headers,timeout=300) as response:
if response.status == 200:
if stream and not run_manager:
                        print('stream requested but no callback handler is attached')
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_start(None,None)
async for chunk in response.content.iter_any():
                        # handle each streamed chunk
if chunk and run_manager:
for callable in run_manager.get_sync().handlers:
# print(chunk.decode("utf-8"),end="")
await callable.on_llm_new_token(chunk.decode("utf-8"))
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_end(None)
else:
                    raise ValueError(f'GLM request failed, HTTP code: {response.status}')
def _call(self, prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream = False,
**kwargs: Any) -> str:
query = self._construct_query(prompt=prompt,temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
# display("==============================")
# display(query)
# post
if stream or self.out_stream:
async def _post_stream():
await self._post_stream(url=self.url+"/stream",
query=query,run_manager=run_manager,stream=stream or self.out_stream)
asyncio.run(_post_stream())
return ''
else:
resp = self._post(url=self.url,
query=query)
if resp.status_code == 200:
resp_json = resp.json()
# self.chat_history.append({'q': prompt, 'a': resp_json['response']})
predictions = resp_json['response']
# display(self.convert_data(resp_json['history']))
return predictions
else:
                raise ValueError(f'GLM request failed, HTTP code: {resp.status_code}')
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
query = self._construct_query(prompt=prompt,temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
await self._post_stream(url=self.url+"/stream",
query=query,run_manager=run_manager,stream=self.out_stream)
return ''
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters.
"""
_param_dict = {
"url": self.url
}
return _param_dict
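
# Usage sketch: a minimal example, assuming a ChatGLM HTTP service is
# listening at the given url (LLM_SERVER_URL in src/config/consts.py points
# at such a service).
if __name__ == "__main__":
    llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
    print(llm("你好"))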
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig
import langchain
from langchain.llms.base import BaseLLM,LLM
from langchain_openai import OpenAI
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
class ChatGLMSerLLM(OpenAI):
def get_token_ids(self, text: str) -> List[int]:
        if "chatglm" in self.model_name:
            # Ask the service for the token count over HTTP
            url = f"{self.openai_api_base}/num_tokens"
            query = {"prompt": text, "model": self.model_name}
            _headers = {"Content-Type": "application/json", "Authorization": "chatglm " + self.openai_api_key}
resp = self._post(url=url,query=query,headers= _headers)
if resp.status_code == 200:
resp_json = resp.json()
print(resp_json)
predictions = resp_json['choices'][0]['text']
            # convert the predictions string to an int
return [int(predictions)]
return [len(text)]
    @classmethod
    def _post(cls, url: str,
              query: Dict, headers: Dict) -> Any:
        """Send a POST request."""
        _headers = {"Content-Type": "application/json"}
_headers.update(headers)
with requests.session() as sess:
resp = sess.post(url,
json=query,
headers=_headers,
timeout=300)
return resp
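
# Usage sketch: a minimal example, assuming an OpenAI-compatible ChatGLM
# endpoint; get_token_ids above additionally expects the service to expose a
# /num_tokens route.
if __name__ == "__main__":
    llm = ChatGLMSerLLM(model_name="chatglm3-6b", openai_api_key="EMPTY",
                        openai_api_base="http://127.0.0.1:8000/v1")
    print(llm.get_token_ids("你好"))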
import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM,LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
from enum import Enum
from pydantic import root_validator, Field
from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_message
logger = logging.getLogger(__name__)
class ModelType(Enum):
ERNIE = "ernie"
ERNIE_LITE = "ernie-lite"
SHEETS1 = "sheets1"
SHEETS2 = "sheets2"
SHEET_COMB = "sheet-comb"
LLAMA2_7B = "llama2-7b"
LLAMA2_13B = "llama2-13b"
LLAMA2_70B = "llama2-70b"
QFCN_LLAMA2_7B = "qfcn-llama2-7b"
BLOOMZ_7B="bloomz-7b"
MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"
MODEL_SERVICE_Suffix = {
ModelType.ERNIE: "ai_custom/v1/wenxinworkshop/chat/completions",
ModelType.ERNIE_LITE: "ai_custom/v1/wenxinworkshop/chat/eb-instant",
ModelType.SHEETS1: "ai_custom/v1/wenxinworkshop/chat/besheet",
ModelType.SHEETS2: "ai_custom/v1/wenxinworkshop/chat/besheets2",
ModelType.SHEET_COMB: "ai_custom/v1/wenxinworkshop/chat/sheet_comb1",
ModelType.LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/llama_2_7b",
ModelType.LLAMA2_13B: "ai_custom/v1/wenxinworkshop/chat/llama_2_13b",
ModelType.LLAMA2_70B: "ai_custom/v1/wenxinworkshop/chat/llama_2_70b",
ModelType.QFCN_LLAMA2_7B: "ai_custom/v1/wenxinworkshop/chat/qianfan_chinese_llama_2_7b",
ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}
class ErnieLLM(LLM):
"""
    ErnieLLM is an LLM that uses Ernie to generate text.
"""
model_name: Optional[ModelType] = None
access_token: Optional[str] = ""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
model_name = ModelType(get_from_dict_or_env(values, "model_name", "model_name", str(ModelType.ERNIE)))
access_token = get_from_dict_or_env(values, "access_token", "ERNIE_ACCESS_TOKEN", "")
if not access_token:
raise ValueError("No access token provided.")
values["model_name"] = model_name
values["access_token"] = access_token
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
request = CompletionRequest(messages=[Message("user",prompt)])
bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
        try:
            response = bot.get_response().result
            # print("response: ", response)
            return response
        except Exception as e:
            # On failure, log and return the error message as the completion
            print("exception:", e)
            return str(e)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "ernie"
# def _identifying_params(self) -> Mapping[str, Any]:
# return {
# "name": "ernie",
# }
def _get_model_service_url(model_name) -> str:
# print("_get_model_service_url model_name: ",model_name)
return MODEL_SERVICE_BASE_URL+MODEL_SERVICE_Suffix[model_name]
class ErnieChat(LLM):
model_name: ModelType
access_token: str
prefix_messages: List = Field(default_factory=list)
id: str = ""
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
msg = user_message(prompt)
request = CompletionRequest(messages=self.prefix_messages+[msg])
bot = ErnieBot(_get_model_service_url(self.model_name),self.access_token,request)
        try:
            resp = bot.get_response()
            response = resp.result
            if self.id == "":
                self.id = resp.id
            self.prefix_messages.append(msg)
            self.prefix_messages.append(bot_message(response))
            return response
        except Exception as e:
            raise e
def _get_id(self) -> str:
return self.id
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "ernie"
from dataclasses import asdict, dataclass
from typing import List
from pydantic import BaseModel, Field
from enum import Enum
class MessageRole(str, Enum):
USER = "user"
BOT = "assistant"
@dataclass
class Message:
role: str
content: str
@dataclass
class CompletionRequest:
messages: List[Message]
stream: bool = False
user: str = ""
@dataclass
class Usage:
prompt_tokens: int
completion_tokens: int
total_tokens: int
@dataclass
class CompletionResponse:
id: str
object: str
created: int
result: str
need_clear_history: bool
ban_round: int = 0
sentence_id: int = 0
is_end: bool = False
usage: Usage = None
is_safe: bool = False
is_truncated: bool = False
class ErrorResponse(BaseModel):
error_code: int = Field(...)
error_msg: str = Field(...)
id: str = Field(...)
class ErnieBot():
url: str
access_token: str
request: CompletionRequest
def __init__(self, url: str, access_token: str, request: CompletionRequest):
self.url = url
self.access_token = access_token
self.request = request
def get_response(self) -> CompletionResponse:
import requests
import json
headers = {'Content-Type': 'application/json'}
params = {'access_token': self.access_token}
request_dict = asdict(self.request)
response = requests.post(self.url, params=params,data=json.dumps(request_dict), headers=headers)
# print(response.json())
try:
return CompletionResponse(**response.json())
except Exception as e:
print(e)
raise Exception(response.json())
def user_message(prompt: str) -> Message:
return Message(MessageRole.USER, prompt)
def bot_message(prompt: str) -> Message:
return Message(MessageRole.BOT, prompt)
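
# Usage sketch: a minimal example, assuming a valid Qianfan access token; the
# full service URL is normally built by _get_model_service_url in ernie.py.
if __name__ == "__main__":
    req = CompletionRequest(messages=[user_message("你好")])
    bot = ErnieBot(
        "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions",
        "<ACCESS_TOKEN>",  # placeholder token
        req,
    )
    print(bot.get_response().result)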
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from pydantic import root_validator
from langchain.llms.base import BaseLLM,LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import qianfan
from qianfan import ChatCompletion
# Enable the LLM cache:
# langchain.llm_cache = InMemoryCache()
class ChatERNIESerLLM(LLM):
    # Model service client
chat_completion:ChatCompletion = None
# url: str = "http://127.0.0.1:8000"
chat_history: dict = []
out_stream: bool = False
cache: bool = False
model_name:str = "ERNIE-Bot"
    # def __init__(self):
    #     self.chat_completion = qianfan.ChatCompletion(ak="<AK>", sk="<SK>")  # fill in your Qianfan ak/sk
@property
def _llm_type(self) -> str:
return self.model_name
def get_num_tokens(self, text: str) -> int:
return len(text)
def convert_data(self,data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
def _call(self, prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream = False,
**kwargs: Any) -> str:
resp = self.chat_completion.do(model=self.model_name,messages=[{
"role": "user",
"content": prompt
}])
print(resp)
assert resp.code == 200
return resp.body["result"]
async def _post_stream(self,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream=False) -> Any:
"""POST请求
"""
async for r in await self.chat_completion.ado(model=self.model_name,messages=[query], stream=stream):
assert r.code == 200
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_new_token(r.body["result"])
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
await self._post_stream(query={
"role": "user",
"content": prompt
},stream=True,run_manager=run_manager)
return ''
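
# Usage sketch: a minimal example, assuming valid Qianfan ak/sk credentials,
# mirroring the commented-out __init__ above.
if __name__ == "__main__":
    llm = ChatERNIESerLLM(chat_completion=qianfan.ChatCompletion(ak="<AK>", sk="<SK>"))
    print(llm("你好"))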
import os
import transformers
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
from peft import PeftModel
class ModelLoader:
def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if pre_seq_len is not None and pre_seq_len > 0:
self.config.pre_seq_len = pre_seq_len
self.config.prefix_projection = prefix_projection
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
self.model = AutoModel.from_pretrained(model_name_or_path, config=self.config, trust_remote_code=True).half()
# self.model = self.model.cuda()
self.base_model = self.model
def quantize(self, quantization_bit):
if quantization_bit is not None:
print(f"Quantized to {quantization_bit} bit")
self.model = self.model.quantize(quantization_bit)
return self.model
def models(self):
return self.model, self.tokenizer
def collator(self):
return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)
    def load_lora(self, ckpt_path, name="default"):
        # saves GPU memory during training
        peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
        self.model = peft_loaded.merge_and_unload()
        print("LoRA model loaded successfully!")
def load_loras(self,ckpt_paths,name="default"):
if len(ckpt_paths)==0:
return
first = True
for name, path in ckpt_paths.items():
print(f"Load {name} from {path}")
if first:
peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=name)
first = False
else:
peft_loaded.load_adapter(path,adapter_name=name)
peft_loaded.set_adapter(name)
self.model = peft_loaded
def load_prefix(self,ckpt_path):
prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float()
print(f"Load prefix model successfully!")
import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM,LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
from enum import Enum
from pydantic import root_validator, Field
from .xinghuo import SparkApi
from .xinghuo.ws import SparkAPI
logger = logging.getLogger(__name__)
text =[]
# length = 0
def getText(role,content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
class SparkLLM(LLM):
"""
ErnieLLM is a LLM that uses Ernie to generate text.
"""
appid: str = Field(
None,
description="APPID",
)
api_key: str = Field(
None,
description="API_KEY",
)
api_secret: str = Field(
None,
description="API_SECRET",
)
version: str = Field(
None,
description="version",
)
api: SparkAPI = Field(
None,
description="api",
)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
appid = get_from_dict_or_env(values, "appid", "XH_APPID", "")
api_key = get_from_dict_or_env(values, "api_key", "XH_API_KEY", "")
api_secret = get_from_dict_or_env(values, "api_secret", "XH_API_SECRET", "")
version = values.get("version", "v1")
if not appid:
raise ValueError("No appid provided.")
if not api_key:
raise ValueError("No api_key provided.")
if not api_secret:
raise ValueError("No api_secret provided.")
values["appid"] = appid
values["api_key"] = api_key
values["api_secret"] = api_secret
api=SparkAPI(appid,api_key,api_secret,version)
values["api"]=api
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
        question = self.getText("user", prompt)
        try:
            # SparkApi.main(self.appid, self.api_key, self.api_secret, self.Spark_url, self.domain, question)
            self.api.call(question)
            response = self.api.answer
            return response
        except Exception as e:
            print("exception:", e)
            raise e
def getText(self,role,content):
text = []
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "xinghuo"
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from typing import Dict, List, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer
class WrapperLLM(LLM):
tokenizer: PreTrainedTokenizer = None
model: PreTrainedModel = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
if values.get("model") is None:
raise ValueError("No model provided.")
if values.get("tokenizer") is None:
raise ValueError("No tokenizer provided.")
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
resp,his = self.model.chat(self.tokenizer,prompt)
return resp
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "wrapper"
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
from urllib.parse import urlparse
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import websocket  # uses the websocket-client package
answer = ""
class Ws_Param(object):
    # Initialization
def __init__(self, APPID, APIKey, APISecret, Spark_url):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
self.host = urlparse(Spark_url).netloc
self.path = urlparse(Spark_url).path
self.Spark_url = Spark_url
    # Build the signed websocket URL
    def create_url(self):
        # RFC1123-formatted timestamp
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
        # Assemble the string to sign
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
        # Sign it with HMAC-SHA256
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
        # Collect the auth parameters into a dict
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
        # Append the auth parameters to produce the final URL
        url = self.Spark_url + '?' + urlencode(v)
        # When adapting this demo, print the URL here to compare the generated signature
        # against one produced from the same parameters
return url
# Called on websocket error
def on_error(ws, error):
print("### error:", error)
# Called on websocket close
def on_close(ws,one,two):
print(" ")
# Called on websocket open
def on_open(ws):
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question))
ws.send(data)
# Called on each websocket message
def on_message(ws, message):
# print(message)
data = json.loads(message)
code = data['header']['code']
if code != 0:
        print(f'Request error: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
print(content,end ="")
global answer
answer += content
# print(1)
if status == 2:
ws.close()
def gen_params(appid, domain,question):
"""
通过appid和用户的提问来生成请参数
"""
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
def main(appid, api_key, api_secret, Spark_url,domain, question):
# print("星火:")
global answer
answer = ""
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
ws.appid = appid
ws.question = question
ws.domain = domain
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
import SparkApi
# The credentials below come from the iFlytek console
appid = "XXXXXXXX"      # APPID from the console
api_secret = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"  # APISecret from the console
api_key = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"     # APIKey from the console
# Selects the model version, "general" (v1.5) or "generalv2" (v2.0)
domain = "general"      # v1.5
# domain = "generalv2"  # v2.0
# Cloud service endpoint
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat"    # v1.5 endpoint
# Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat"  # v2.0 endpoint
text =[]
# length = 0
def getText(role,content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
if __name__ == '__main__':
    text.clear()
    while True:
        Input = input("\n" + "Me: ")
        question = checklen(getText("user", Input))
        SparkApi.answer = ""
        print("Spark: ", end="")
        SparkApi.main(appid, api_key, api_secret, Spark_url, domain, question)
        getText("assistant", SparkApi.answer)
        # print(str(text))
import _thread as thread
import base64
import datetime
import hashlib
import hmac
import json
import ssl
from datetime import datetime
from time import mktime
from urllib.parse import urlparse, urlencode
from wsgiref.handlers import format_date_time
import websocket  # uses the websocket-client package
URL_V1_5="ws://spark-api.xf-yun.com/v1.1/chat"
URL_V2="ws://spark-api.xf-yun.com/v2.1/chat"
Domain_V1_5="general"
Domain_V2="generalv2"
class SparkAPI:
def __init__(self, APPID, APIKey, APISecret, Version="v1"):
self.APPID = APPID
self.APIKey = APIKey
self.APISecret = APISecret
if Version == "v1":
self.Spark_url = URL_V1_5
self.domain = Domain_V1_5
elif Version == "v2":
self.Spark_url = URL_V2
self.domain = Domain_V2
self.host = urlparse(self.Spark_url).netloc
self.path = urlparse(self.Spark_url).path
self.answer = ""
def create_url(self):
now = datetime.now()
date = format_date_time(mktime(now.timetuple()))
signature_origin = "host: " + self.host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + self.path + " HTTP/1.1"
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
v = {
"authorization": authorization,
"date": date,
"host": self.host
}
url = self.Spark_url + '?' + urlencode(v)
return url
def on_error(self, ws, error):
print("### error:", error)
def on_close(self, ws, one, two):
print(" ")
def on_open(self, ws):
thread.start_new_thread(self.run, (ws,))
def run(self, ws, *args):
data = json.dumps(self.gen_params(appid=self.APPID, domain=self.domain, question=ws.question))
ws.send(data)
def on_message(self, ws, message):
data = json.loads(message)
code = data['header']['code']
if code != 0:
            print(f'Request error: {code}, {data}')
ws.close()
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
# print(content, end="")
self.answer += content
if status == 2:
ws.close()
def gen_params(self, appid, domain, question):
data = {
"header": {
"app_id": appid,
"uid": "1234"
},
"parameter": {
"chat": {
"domain": domain,
"random_threshold": 0.5,
"max_tokens": 2048,
"auditing": "default"
}
},
"payload": {
"message": {
"text": question
}
}
}
return data
def call(self, question):
self.answer = ""
wsUrl = self.create_url()
websocket.enableTrace(False)
ws = websocket.WebSocketApp(wsUrl, on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, on_open=self.on_open)
ws.question = question
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
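
# Usage sketch: a minimal example with placeholder console credentials;
# Version="v2" selects the generalv2 domain instead.
if __name__ == "__main__":
    api = SparkAPI("<APPID>", "<API_KEY>", "<API_SECRET>", Version="v1")
    api.call([{"role": "user", "content": "你好"}])
    print(api.answer)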
"""资料分割"""
from abc import ABC, abstractmethod
class BaseCallback(ABC):
@abstractmethod
    def filter(self, title: str, content: str) -> bool:  # return True to discard the current paragraph
pass
from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List
from src.loader.config import SENTENCE_SIZE
class ChineseTextSplitter(CharacterTextSplitter):
def __init__(self, pdf: bool = False, sentence_size: int = SENTENCE_SIZE, **kwargs):
super().__init__(**kwargs)
self.pdf = pdf
self.sentence_size = sentence_size
def split_text1(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub(r"\s", " ", text)
            text = text.replace("\n\n", "")
sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))') # del :;
sent_list = []
for ele in sent_sep_pattern.split(text):
if sent_sep_pattern.match(ele) and sent_list:
sent_list[-1] += ele
elif ele:
sent_list.append(ele)
return sent_list
    def split_text(self, text: str) -> List[str]:  # TODO: this logic needs further refinement
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub(r"\s", " ", text)
            text = re.sub("\n\n", "", text)
        text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)  # single-character sentence terminators
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # English ellipsis
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # Chinese ellipsis
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
        # A closing quote only ends a sentence when a terminator precedes it, so the split \n
        # goes after the quote; the rules above deliberately keep closing quotes attached
        text = text.rstrip()  # drop any trailing \n at the end of the segment
        # Semicolons, dashes, and English double quotes are deliberately not used as split
        # points; add them to the rules above if needed
ls = [i for i in text.split("\n") if i]
for ele in ls:
if len(ele) > self.sentence_size:
ele1 = re.sub(r'([,,.]["’”」』]{0,2})([^,,.])', r'\1\n\2', ele)
ele1_ls = ele1.split("\n")
for ele_ele1 in ele1_ls:
if len(ele_ele1) > self.sentence_size:
ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
ele2_ls = ele_ele2.split("\n")
for ele_ele2 in ele2_ls:
if len(ele_ele2) > self.sentence_size:
ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
ele2_id = ele2_ls.index(ele_ele2)
ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
ele2_id + 1:]
ele_id = ele1_ls.index(ele_ele1)
ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
id = ls.index(ele)
ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
return ls
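
# Usage sketch: sentence-level splitting of Chinese text, using the default
# sentence size from src/loader/config.py.
if __name__ == "__main__":
    splitter = ChineseTextSplitter(pdf=False, sentence_size=SENTENCE_SIZE)
    print(splitter.split_text("今天天气很好。我们去公园散步吧!好的,出发。"))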
# Sentence length used when splitting text
SENTENCE_SIZE = 100
ZH_TITLE_ENHANCE = False
from typing import List
from langchain.docstore.document import Document
import re
def under_non_alpha_ratio(text: str, threshold: float = 0.5):
"""Checks if the proportion of non-alpha characters in the text snippet exceeds a given
threshold. This helps prevent text like "-----------BREAK---------" from being tagged
as a title or narrative text. The ratio does not count spaces.
Parameters
----------
text
The input string to test
    threshold
        If the proportion of non-alpha characters exceeds this threshold, the function
        returns True
"""
if len(text) == 0:
return False
alpha_count = len([char for char in text if char.strip() and char.isalpha()])
total_count = len([char for char in text if char.strip()])
try:
ratio = alpha_count / total_count
return ratio < threshold
except:
return False
def is_possible_title(
text: str,
title_max_word_length: int = 20,
non_alpha_threshold: float = 0.5,
) -> bool:
"""Checks to see if the text passes all of the checks for a valid title.
Parameters
----------
text
The input text to check
title_max_word_length
The maximum number of words a title can contain
non_alpha_threshold
The minimum number of alpha characters the text needs to be considered a title
"""
    # Empty text is never a title
if len(text) == 0:
print("Not a title. Text is empty.")
return False
    # Text that ends in punctuation is not a title
ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z"
ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN)
if ENDS_IN_PUNCT_RE.search(text) is not None:
return False
    # The text must not exceed the configured length, 20 by default
# NOTE(robinson) - splitting on spaces here instead of word tokenizing because it
# is less expensive and actual tokenization doesn't add much value for the length check
if len(text) > title_max_word_length:
return False
    # The proportion of non-alphabetic characters must not be too high, otherwise it is not a title
if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
return False
# NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles
if text.endswith((",", ".", ",", "。")):
return False
if text.isnumeric():
print(f"Not a title. Text is all numeric:\n\n{text}") # type: ignore
return False
    # The leading characters (the first 5 by default) should contain a digit
if len(text) < 5:
text_5 = text
else:
text_5 = text[:5]
alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5))))
if not alpha_in_text_5:
return False
return True
def zh_title_enhance(docs: List[Document]):
title = None
if len(docs) > 0:
for doc in docs:
if is_possible_title(doc.page_content):
doc.metadata['category'] = 'cn_Title'
title = doc.page_content
elif title:
doc.page_content = f"下文与({title})有关。{doc.page_content}"
return docs
else:
print("文件不存在")
from .c_db import UPostgresDB
import json
TABLE_CHAT = """
create table chat (
id varchar(1000) primary key,
user_id int,
info text,
chat_type_id int,
create_time date,
history json,
is_delete int,
status int
);
"""
class Chat:
def __init__(self, db: UPostgresDB) -> None:
self.db = db
def insert(self, value):
value[5] = json.dumps(value[5])
query = f"INSERT INTO chat(id,user_id,info,chat_type_id,create_time,history,is_delete,status) VALUES (%s,%s,%s,%s,%s,%s,%s,%s)"
self.db.execute_args(query, ((value[0],value[1],value[2],value[3],value[4],value[5],value[6],value[7])))
def delete_update(self,id):
query = f"UPDATE chat SET is_delete = 1 WHERE id = %s"
self.db.execute_args(query, (id,))
def search(self, id):
query = f"SELECT chat.id,user_id,info,t.name as type,chat.create_time,history,is_delete,status FROM chat left join chat_type t on t.id=chat.chat_type_id WHERE chat.id = %s and is_delete=0 ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def manual_search(self, id):
query = f"SELECT user_id,status,history FROM chat WHERE id = %s and is_delete=0"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def qa_search(self, id):
query = f"SELECT user_id,info,t.name as type,chat.create_time,history,status FROM chat left join chat_type t on t.id=chat.chat_type_id WHERE chat.id = %s and is_delete=0 ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def detail_search(self, id):
query = f"SELECT chat.id,user_id,info,t.name as type,chat.create_time,history,status FROM chat left join chat_type t on t.id=chat.chat_type_id WHERE chat.id = %s and is_delete=0 ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def delete_search(self, id):
query = f"SELECT user_id FROM chat WHERE id = %s and is_delete=0"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def list_chat_search(self, chat_type, user_id):
query = f"SELECT chat.id,info,t.name as type,chat.create_time,history,chat.status FROM chat left join chat_type t on t.id=chat.chat_type_id WHERE t.name = %s AND is_delete=0 AND user_id = %s ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (chat_type, user_id))
answer = self.db.fetchall()
if len(answer) > 0:
return answer
else:
return None
def search_history(self,id):
query = f"SELECT history FROM chat WHERE id = %s ORDER BY chat.create_time DESC;"
self.db.execute_args(query, (id,))
answer = self.db.fetchall()
if len(answer) > 0:
answer[0] = answer[0][0]
return answer[0]
else:
return None
def get_last_q(self):
query = f"SELECT id,create_time,info,chat_type_id FROM chat ORDER BY create_time DESC LIMIT 1 "
self.db.execute(query)
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
def create_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'chat')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if not exists:
query = TABLE_CHAT
self.db.execute(query)
def drop_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'chat')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if exists:
query = "DROP TABLE chat"
self.db.format(query)
print("drop table chat ok")
def update(self, chat_id, history):
history = json.dumps(history)
query = f"UPDATE chat SET history = %s WHERE id = %s"
self.db.execute_args(query, (history,chat_id))
def history_update(self, chat_id, history):
history = json.dumps(history)
query = f"UPDATE chat SET history = %s WHERE id = %s"
self.db.execute_args(query, (history,chat_id))
def search_type_id(self, chat_id):
query = f"select chat_type_id from chat where id = %s"
self.db.execute_args(query, (chat_id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0][0]
else:
return None
def get_chat_status(self, chat_id):
query = f"Select status from chat where id = %s"
self.db.execute_args(query, (chat_id,))
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0][0]
else:
return None
def set_chat_status(self, chat_id, status):
query = f"update chat set status = %s where id = %s"
self.db.execute_args(query, (status,chat_id))
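
# Usage sketch: a minimal example; UPostgresDB is defined in c_db.py outside
# this commit's diff and is assumed here to take the same connection
# arguments as PostgresDB in k_db.py.
if __name__ == "__main__":
    db = UPostgresDB("localhost", "lae", "postgres", "chenzl", port=5432)  # assumed signature
    chat = Chat(db)
    chat.create_table()
    chat.insert(["<chat-id>", 1, "first question", 1, "2024-01-01", [], 0, 0])
    print(chat.get_last_q())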
"""资料存储相关"""
import os, sys
from os import path
sys.path.append("../")
from abc import ABC, abstractmethod
import json
from typing import List,Any,Tuple,Dict
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore,str2hash_base64
class DocumentCallback(ABC):
    @abstractmethod  # document processing before storing into the vector store
    def before_store(self, docstore: PgSqlDocstore, documents):
pass
    @abstractmethod  # document processing after a vector-store search, used to rebuild structure
    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> List[Tuple[Document, float]]:
pass
class DefaultDocumentCallback(DocumentCallback):
def before_store(self,docstore:PgSqlDocstore,documents):
output_doc = []
for doc in documents:
if "next_doc" in doc.metadata:
doc.metadata["next_hash"] = str2hash_base64(doc.metadata["next_doc"])
doc.metadata.pop("next_doc")
output_doc.append(doc)
return output_doc
    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> List[Tuple[Document, float]]:  # document processing after a vector-store search
output_doc:List[Tuple[Document, float]] = []
exist_hash = []
for doc,score in documents:
print(exist_hash)
dochash = str2hash_base64(doc.page_content)
if dochash in exist_hash:
continue
else:
exist_hash.append(dochash)
output_doc.append((doc,score))
if len(output_doc) > number:
return output_doc
fordoc = doc
while ("next_hash" in fordoc.metadata):
if len(fordoc.metadata["next_hash"])>0:
if fordoc.metadata["next_hash"] in exist_hash:
break
else:
exist_hash.append(fordoc.metadata["next_hash"])
content = docstore.TXT_DOC.search(fordoc.metadata["next_hash"])
if content:
fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
output_doc.append((fordoc,score))
if len(output_doc) > number:
return output_doc
else:
break
else:
break
return output_doc
import psycopg2
class PostgresDB:
    '''
    psycopg2.connect(
        dsn                 # connection parameters as a DSN string, or use keyword arguments
        host                # database host name
        dbname              # database name
        user                # user name
        password            # password
        port                # port number
        connection_factory  # factory class for connection objects
        cursor_factory      # factory class for cursor objects
        async_              # whether the connection is asynchronous (default False)
        sslmode             # SSL mode
        sslrootcert         # CA certificate file
        sslkey              # private key file
        sslcert             # client certificate file
    )
    '''
def __init__(self, host, database, user, password,port = 5432):
self.host = host
self.database = database
self.user = user
self.password = password
self.port = port
self.conn = None
self.cur = None
def connect(self):
self.conn = psycopg2.connect(
host=self.host,
database=self.database,
user=self.user,
password=self.password,
port = self.port
)
self.cur = self.conn.cursor()
def execute(self, query):
try:
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
def execute_args(self, query, args):
try:
self.cur.execute(query, args)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
def search(self, query, params=None):
self.cur.execute(query, params)
def fetchall(self):
return self.cur.fetchall()
def close(self):
self.cur.close()
self.conn.close()
def format(self, query):
try:
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
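
# Usage sketch: a minimal example, assuming a local PostgreSQL instance; the
# credentials mirror src/config/consts.py.
if __name__ == "__main__":
    db = PostgresDB("localhost", "lae", "postgres", "chenzl", port=5432)
    db.connect()
    db.search("SELECT 1")
    print(db.fetchall())
    db.close()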
import sys
from os import path
# effectively adds this file's directory to PYTHONPATH
sys.path.append(path.dirname(path.abspath(__file__)))
from typing import List,Union,Dict,Optional
from langchain.docstore.base import AddableMixin, Docstore
from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector
import json,hashlib,base64
from langchain.schema import Document
def str2hash_base64(input:str) -> str:
# return f"%s" % hash(input)
return base64.b64encode(hashlib.sha1(input.encode()).digest()).decode()
class PgSqlDocstore(Docstore,AddableMixin):
host:str
dbname:str
username:str
password:str
port:str
    '''
    __getstate__/__setstate__ are overridden so langchain's pickle-based
    serialization works: the returned dict holds the pgsql connection info.
    '''
def __getstate__(self):
return {"host":self.host,"dbname":self.dbname,"username":self.username,"password":self.password,"port":self.port}
def __setstate__(self, info):
self.__init__(info)
def __init__(self,info:dict,reset:bool = False):
self.host = info["host"]
self.dbname = info["dbname"]
self.username = info["username"]
self.password = info["password"]
self.port = info["port"] if "port" in info else "5432";
self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password,port=self.port)
self.TXT_DOC = TxtDoc(self.pgdb)
self.VEC_TXT = TxtVector(self.pgdb)
if reset:
self.__sub_init__()
self.TXT_DOC.drop_table()
self.VEC_TXT.drop_table()
self.TXT_DOC.create_table()
self.VEC_TXT.create_table()
def __sub_init__(self):
if not self.pgdb.conn:
self.pgdb.connect()
    '''
    Look up the paragraph text for a vector id in the local store and wrap it in a Document.
    '''
    def search(self, search: str) -> Union[str, Document]:
        if not self.pgdb.conn:
            self.__sub_init__()
        answer = self.VEC_TXT.search(search)
        content = self.TXT_DOC.search(answer[0])
        if content:
            return Document(page_content=content[0], metadata=json.loads(content[1]))
        else:
            return Document()
    '''
    Batch-delete the texts behind the given vector ids from the local store.
    '''
    def delete(self, ids: List) -> None:
        if not self.pgdb.conn:
            self.__sub_init__()
        pids = []
        for id in ids:
            answer = self.VEC_TXT.search(id)
            pids.append(answer[0])
        self.VEC_TXT.delete(ids)
        self.TXT_DOC.delete(pids)
    '''
    Add vectors and their texts to the local store.
    [vector_id, Document(page_content=question, metadata=dict(paragraph=paragraph text))]
    '''
def add(self, texts: Dict[str, Document]) -> None:
# for vec,doc in texts.items():
# paragraph_id = self.TXT_DOC.insert(doc.metadata["paragraph"])
# self.VEC_TXT.insert(vector_id=vec,paragraph_id=paragraph_id,text=doc.page_content)
if not self.pgdb.conn:
self.__sub_init__()
paragraph_hashs = [] #hash,text
paragraph_txts = []
vec_inserts = []
for vec,doc in texts.items():
txt_hash = str2hash_base64(doc.metadata["paragraph"])
print(txt_hash)
vec_inserts.append((vec,doc.page_content,txt_hash))
if txt_hash not in paragraph_hashs:
paragraph_hashs.append(txt_hash)
paragraph = doc.metadata["paragraph"]
doc.metadata.pop("paragraph")
paragraph_txts.append((txt_hash,paragraph,json.dumps(doc.metadata,ensure_ascii=False)))
# print(paragraph_txts)
self.TXT_DOC.insert(paragraph_txts)
self.VEC_TXT.insert(vec_inserts)
class InMemorySecondaryDocstore(Docstore, AddableMixin):
"""Simple in memory docstore in the form of a dict."""
def __init__(self, _dict: Optional[Dict[str, Document]] = None,_sec_dict: Optional[Dict[str, Document]] = None):
"""Initialize with dict."""
self._dict = _dict if _dict is not None else {}
self._sec_dict = _sec_dict if _sec_dict is not None else {}
def add(self, texts: Dict[str, Document]) -> None:
"""Add texts to in memory dictionary.
Args:
texts: dictionary of id -> document.
Returns:
None
"""
overlapping = set(texts).intersection(self._dict)
if overlapping:
raise ValueError(f"Tried to add ids that already exist: {overlapping}")
self._dict = {**self._dict, **texts}
dict1 = {}
dict_sec = {}
for vec,doc in texts.items():
txt_hash = str2hash_base64(doc.metadata["paragraph"])
metadata=doc.metadata
paragraph = metadata.pop('paragraph')
# metadata.update({"paragraph_id":txt_hash})
metadata['paragraph_id']=txt_hash
dict_sec[txt_hash] = Document(page_content=paragraph,metadata=metadata)
dict1[vec] = Document(page_content=doc.page_content,metadata={'paragraph_id':txt_hash})
self._dict = {**self._dict, **dict1}
self._sec_dict = {**self._sec_dict, **dict_sec}
def delete(self, ids: List) -> None:
"""Deleting IDs from in memory dictionary."""
overlapping = set(ids).intersection(self._dict)
if not overlapping:
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
        for _id in ids:
            self._sec_dict.pop(self._dict[_id].metadata['paragraph_id'])
            self._dict.pop(_id)
def search(self, search: str) -> Union[str, Document]:
"""Search via direct lookup.
Args:
search: id of a document to search for.
Returns:
Document if found, else error message.
"""
if search not in self._dict:
return f"ID {search} not found."
else:
print(self._dict[search].page_content)
return self._sec_dict[self._dict[search].metadata['paragraph_id']]
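
# Serialization sketch: the __getstate__/__setstate__ overrides above make
# PgSqlDocstore pickle-safe for langchain's FAISS save/load; only the
# connection info is persisted, and the connection is re-opened lazily.
if __name__ == "__main__":
    import pickle
    store = PgSqlDocstore({"host": "localhost", "dbname": "lae",
                           "username": "postgres", "password": "chenzl",
                           "port": "5432"})
    restored = pickle.loads(pickle.dumps(store))
    print(restored.host, restored.port)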
from .k_db import PostgresDB
# paragraph_id BIGSERIAL primary key,
TABLE_TXT_DOC = """
create table txt_doc (
hash varchar(40) primary key,
text text not null,
matadate text
);
"""
TABLE_TXT_DOC_HASH_INDEX = """
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
"""
# CREATE UNIQUE INDEX idx_name ON your_table (column_name);
class TxtDoc:
def __init__(self, db: PostgresDB) -> None:
self.db = db
def insert(self, texts):
query = f"INSERT INTO txt_doc(hash,text,matadate) VALUES "
args = []
for value in texts:
value = list(value)
query+= "(%s,%s,%s),"
args.extend(value)
query = query[:len(query)-1]
query += f"ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
self.db.execute_args(query,args)
    def delete(self, ids):
        for id in ids:
            query = "DELETE FROM txt_doc WHERE hash = %s"
            self.db.execute_args(query, (id,))
def search(self, id):
query = "SELECT text,matadate FROM txt_doc WHERE hash = %s"
self.db.execute_args(query,[id])
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
else:
return None
# return Document(page_content=self.db.fetchall()[0][0], metadata=dict(page=self.db.fetchall()[0][1]))
# answer = self.db.fetchall()[0][0]
# return answer
def create_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'txt_doc')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if not exists:
query = TABLE_TXT_DOC
self.db.execute(query)
# self.db.execute(TABLE_TXT_DOC_HASH_INDEX)
def drop_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'txt_doc')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if exists:
query = "DROP TABLE txt_doc"
self.db.format(query)
print("drop table txt_doc ok")
from .k_db import PostgresDB
TABLE_VEC_TXT = """
CREATE TABLE vec_txt (
vector_id varchar(36) PRIMARY KEY,
text text,
paragraph_id varchar(40) not null
)
"""
# example vector_id: 025a9bee-2eb2-47f5-9722-525e05a0442b
class TxtVector:
def __init__(self, db: PostgresDB) -> None:
self.db = db
def insert(self, vectors):
query = f"INSERT INTO vec_txt(vector_id,text,paragraph_id) VALUES"
args = []
for value in vectors:
value = list(value)
query+= "(%s,%s,%s),"
args.extend(value)
query = query[:len(query)-1]
query += f"ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text,paragraph_id = EXCLUDED.paragraph_id;"
# query += ";"
self.db.execute_args(query,args)
    def delete(self, ids):
        for id in ids:
            query = "DELETE FROM vec_txt WHERE vector_id = %s"
            self.db.execute_args(query, (id,))
def search(self, search: str):
query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
self.db.execute_args(query,[search])
answer = self.db.fetchall()
print(answer)
return answer[0]
def create_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'vec_txt')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if not exists:
query = TABLE_VEC_TXT
self.db.execute(query)
def drop_table(self):
query = f"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'vec_txt')"
self.db.execute(query)
exists = self.db.fetchall()[0][0]
if exists:
query = "DROP TABLE vec_txt"
self.db.format(query)
print("drop table vec_txt ok")
import sys
sys.path.append("../")
import time
from src.loader.load import loads_path,loads
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
VEC_DB_DBNAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
SIMILARITY_SHOW_NUMBER,
KNOWLEDGE_PATH,
INDEX_NAME
)
from src.loader.callback import BaseCallback
# Paragraphs whose title contains "思考题" (review questions), or whose average
# line length is under 20 characters, are filtered out by default.
class localCallback(BaseCallback):
    def filter(self, title: str, content: str) -> bool:
        if len(title + content) == 0:
            return True
        return (len(title + content) / (len(title.splitlines()) + len(content.splitlines())) < 20) or "思考题" in title
def test_faiss_from_dir():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port":VEC_DB_PORT,"host":VEC_DB_HOST,"dbname":VEC_DB_DBNAME,"username":VEC_DB_USER,"password":VEC_DB_PASSWORD},
show_number=3,
reset=True)
docs = loads_path(KNOWLEDGE_PATH,mode="paged",sentence_size=512,callbacks=[localCallback()])
print(len(docs))
last_doc = None
docs1 = []
for doc in docs:
if not last_doc:
last_doc = doc
continue
if "font-size" not in doc.metadata or "page_number" not in doc.metadata:
continue
if doc.metadata["font-size"] == last_doc.metadata["font-size"] and doc.metadata["page_number"] == last_doc.metadata["page_number"] and len(doc.page_content)+len(last_doc.page_content) < 512/4*3:
last_doc.page_content += doc.page_content
else:
docs1.append(last_doc)
last_doc = doc
if last_doc:
docs1.append(last_doc)
docs = docs1
print(len(docs))
print(vecstore_faiss._faiss.index.ntotal)
    for i in range(0, len(docs), 300):
        vecstore_faiss._add_documents(docs[i:i + 300], need_split=True)  # slicing clamps at the end of the list
print(vecstore_faiss._faiss.index.ntotal)
vecstore_faiss._save_local()
def test_faiss_load():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port":VEC_DB_PORT,"host":VEC_DB_HOST,"dbname":VEC_DB_DBNAME,"username":VEC_DB_USER,"password":VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("请介绍一下你理解的国际结算业务")))
if __name__ == "__main__":
test_faiss_from_dir()
test_faiss_load()