Commit 493cdd59 by 陈正乐

Code formatting

parent 9d8ee0af
@@ -2,19 +2,19 @@
# Knowledge-store database configuration
# =============================
VEC_DB_HOST = 'localhost'
VEC_DB_DBNAME='lae'
VEC_DB_USER='postgres'
VEC_DB_PASSWORD='chenzl'
VEC_DB_PORT='5432'
VEC_DB_DBNAME = 'lae'
VEC_DB_USER = 'postgres'
VEC_DB_PASSWORD = 'chenzl'
VEC_DB_PORT = '5432'
# =============================
# Chat database configuration
# =============================
CHAT_DB_HOST = 'localhost'
CHAT_DB_DBNAME='laechat'
CHAT_DB_USER='postgres'
CHAT_DB_PASSWORD='chenzl'
CHAT_DB_PORT='5432'
CHAT_DB_DBNAME = 'laechat'
CHAT_DB_USER = 'postgres'
CHAT_DB_PASSWORD = 'chenzl'
CHAT_DB_PORT = '5432'
# =============================
# Embedding model path configuration
@@ -41,4 +41,4 @@ INDEX_NAME = 'know'
# =============================
# Knowledge corpus configuration
# =============================
KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\work\\llm_gjjs\\兴火燎原知识库\\兴火燎原知识库\\law\\pdf'
\ No newline at end of file
KNOWLEDGE_PATH = 'C:\\Users\\15663\\Desktop\\work\\llm_gjjs\\兴火燎原知识库\\兴火燎原知识库\\law\\pdf'
"""Services provided by various large language models."""
\ No newline at end of file
"""Services provided by various large language models."""
import os
from typing import Dict, Optional,List
from langchain.llms.base import BaseLLM,LLM
from typing import Dict, Optional, List
from langchain.llms.base import BaseLLM, LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig,AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
from pydantic import root_validator
class BaichuanLLM(LLM):
model_name: str = "baichuan-inc/Baichuan-13B-Chat"
quantization_bit: Optional[int] = None
tokenizer: AutoTokenizer = None
model: AutoModel = None
def _llm_type(self) -> str:
return "chatglm_local"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
def validate_environment(self, values: Dict) -> Dict:
if not values["model_name"]:
raise ValueError("No model name provided.")
model_name = values["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=False,trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
@@ -43,7 +40,7 @@ class BaichuanLLM(LLM):
print(f"Quantized to {values['quantization_bit']} bit")
model = model.quantize(values["quantization_bit"]).cuda()
else:
model=model.half().cuda()
model = model.half().cuda()
model = model.eval()
@@ -51,14 +48,9 @@ class BaichuanLLM(LLM):
values["model"] = model
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
message = []
message.append({"role": "user", "content": prompt})
resp = self.model.chat(self.tokenizer,message)
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
message = [{"role": "user", "content": prompt}]
resp = self.model.chat(self.tokenizer, message)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return resp
\ No newline at end of file
return resp
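A minimal usage sketch for the BaichuanLLM class above (illustrative, not part of the commit): it assumes a CUDA GPU and locally cached Baichuan weights, since validate_environment loads the tokenizer and model at construction time.
llm = BaichuanLLM(model_name="baichuan-inc/Baichuan-13B-Chat", quantization_bit=8)
print(llm("Hello, please introduce yourself"))  # LLM.__call__ dispatches to _call above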
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain.llms.base import BaseLLM,LLM
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import aiohttp
import asyncio
# enable the LLM cache
# langchain.llm_cache = InMemoryCache()
@@ -26,17 +27,17 @@ class ChatGLMLocLLM(LLM):
tokenizer: AutoTokenizer = None
model: AutoModel = None
def _llm_type(self) -> str:
return "chatglm_local"
# @root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@staticmethod
def validate_environment(values: Dict) -> Dict:
if not values["model_name"]:
raise ValueError("No model name provided.")
model_name = values["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name ,trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
if values["pre_seq_len"]:
@@ -56,7 +57,7 @@ class ChatGLMLocLLM(LLM):
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True).half().cuda()
if values["pre_seq_len"]:
# P-tuning v2
model = model.half().cuda()
@@ -64,7 +65,7 @@ class ChatGLMLocLLM(LLM):
if values["quantization_bit"]:
print(f"Quantized to {values['quantization_bit']} bit")
model = model.quantize(values["quantization_bit"])
model = model.quantize(values["quantization_bit"])
model = model.eval()
@@ -72,30 +73,26 @@ class ChatGLMLocLLM(LLM):
values["model"] = model
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
resp,his = self.model.chat(self.tokenizer,prompt)
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
resp, his = self.model.chat(self.tokenizer, prompt)
# print(f"prompt:{prompt}\nresponse:{resp}\n")
return resp
class ChatGLMSerLLM(LLM):
# model service URL
url: str = "http://127.0.0.1:8000"
chat_history: dict = []
chat_history: dict = []
out_stream: bool = False
cache: bool = False
@property
def _llm_type(self) -> str:
return "chatglm3-6b"
def get_num_tokens(self, text: str) -> int:
resp = self._post(url=self.url+"/tokens",query=self._construct_query(text))
resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
if resp.status_code == 200:
resp_json = resp.json()
predictions = resp_json['response']
@@ -103,28 +100,29 @@ class ChatGLMSerLLM(LLM):
return predictions
else:
return len(text)
def convert_data(self,data):
@staticmethod
def convert_data(data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
def _construct_query(self, prompt: str,temperature = 0.95) -> Dict:
def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
"""Build the request body
"""
# self.chat_history.append({"role": "user", "content": prompt})
query = {
"prompt": prompt,
"history":self.chat_history,
"history": self.chat_history,
"max_length": 4096,
"top_p": 0.7,
"temperature": temperature
}
return query
@classmethod
def _post(self, url: str,
def _post(cls, url: str,
query: Dict) -> Any:
"""Send a POST request
"""
@@ -135,51 +133,55 @@
headers=_headers,
timeout=300)
return resp
async def _post_stream(self, url: str,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,stream=False) -> Any:
@staticmethod
async def _post_stream(url: str,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, stream=False) -> Any:
"""Send a streaming POST request
"""
_headers = {"Content_Type": "application/json"}
async with aiohttp.ClientSession() as sess:
async with sess.post(url, json=query,headers=_headers,timeout=300) as response:
async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
if response.status == 200:
if stream and not run_manager:
print('not callable')
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_start(None,None)
for _callable in run_manager.get_sync().handlers:
await _callable.on_llm_start(None, None)
async for chunk in response.content.iter_any():
# handle each streamed chunk
if chunk and run_manager:
for callable in run_manager.get_sync().handlers:
for _callable in run_manager.get_sync().handlers:
# print(chunk.decode("utf-8"),end="")
await callable.on_llm_new_token(chunk.decode("utf-8"))
await _callable.on_llm_new_token(chunk.decode("utf-8"))
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_end(None)
for _callable in run_manager.get_sync().handlers:
await _callable.on_llm_end(None)
else:
raise ValueError(f'glm request failed, http code: {response.status}')
def _call(self, prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream = False,
stream=False,
**kwargs: Any) -> str:
query = self._construct_query(prompt=prompt,temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
query = self._construct_query(prompt=prompt,
temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
# display("==============================")
# display(query)
# post
if stream or self.out_stream:
async def _post_stream():
await self._post_stream(url=self.url+"/stream",
query=query,run_manager=run_manager,stream=stream or self.out_stream)
await self._post_stream(url=self.url + "/stream",
query=query, run_manager=run_manager, stream=stream or self.out_stream)
asyncio.run(_post_stream())
return ''
else:
resp = self._post(url=self.url,
query=query)
query=query)
if resp.status_code == 200:
resp_json = resp.json()
@@ -189,18 +191,19 @@ class ChatGLMSerLLM(LLM):
return predictions
else:
raise ValueError(f'glm request failed, http code: {resp.status_code}')
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
query = self._construct_query(prompt=prompt,temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
await self._post_stream(url=self.url+"/stream",
query=query,run_manager=run_manager,stream=self.out_stream)
return ''
query = self._construct_query(prompt=prompt,
temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
await self._post_stream(url=self.url + "/stream",
query=query, run_manager=run_manager, stream=self.out_stream)
return ''
@property
def _identifying_params(self) -> Mapping[str, Any]:
@@ -209,4 +212,4 @@ class ChatGLMSerLLM(LLM):
_param_dict = {
"url": self.url
}
return _param_dict
\ No newline at end of file
return _param_dict
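A hedged sketch of the HTTP-backed ChatGLMSerLLM above; it assumes a ChatGLM API server listening on the given URL and exposing the /tokens and /stream endpoints the class expects.
llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
answer = llm("Hello")                    # blocking POST to the service root
n_tokens = llm.get_num_tokens("Hello")   # POST to <url>/tokens; falls back to len(text)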
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
import torch
from transformers import AutoTokenizer, AutoModel,AutoConfig
from transformers import AutoTokenizer, AutoModel, AutoConfig
import langchain
from langchain.llms.base import BaseLLM,LLM
from langchain.llms.base import BaseLLM, LLM
from langchain_openai import OpenAI
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
class ChatGLMSerLLM(OpenAI):
def get_token_ids(self, text: str) -> List[int]:
if self.model_name.__contains__("chatglm"):
## issue an HTTP request to fetch the token_ids
url = f"{self.openai_api_base}/num_tokens"
query = {"prompt": text,"model": self.model_name}
_headers = {"Content_Type": "application/json","Authorization": "chatglm "+self.openai_api_key}
resp = self._post(url=url,query=query,headers= _headers)
query = {"prompt": text, "model": self.model_name}
_headers = {"Content_Type": "application/json", "Authorization": "chatglm " + self.openai_api_key}
resp = self._post(url=url, query=query, headers=_headers)
if resp.status_code == 200:
resp_json = resp.json()
print(resp_json)
@@ -30,10 +31,10 @@ class ChatGLMSerLLM(OpenAI):
## convert the predictions string to an int
return [int(predictions)]
return [len(text)]
@classmethod
def _post(self, url: str,
query: Dict,headers: Dict) -> Any:
def _post(cls, url: str,
query: Dict, headers: Dict) -> Any:
"""Send a POST request
"""
_headers = {"Content_Type": "application/json"}
@@ -43,4 +44,4 @@ class ChatGLMSerLLM(OpenAI):
json=query,
headers=_headers,
timeout=300)
return resp
\ No newline at end of file
return resp
@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM,LLM
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
@@ -15,6 +15,7 @@ from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_m
logger = logging.getLogger(__name__)
class ModelType(Enum):
ERNIE = "ernie"
ERNIE_LITE = "ernie-lite"
@@ -25,7 +26,7 @@ class ModelType(Enum):
LLAMA2_13B = "llama2-13b"
LLAMA2_70B = "llama2-70b"
QFCN_LLAMA2_7B = "qfcn-llama2-7b"
BLOOMZ_7B="bloomz-7b"
BLOOMZ_7B = "bloomz-7b"
MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"
@@ -43,6 +44,7 @@ MODEL_SERVICE_Suffix = {
ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
}
class ErnieLLM(LLM):
"""
ErnieLLM is a LLM that uses Ernie to generate text.
@@ -52,27 +54,23 @@ class ErnieLLM(LLM):
access_token: Optional[str] = ""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
def validate_environment(self, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
model_name = ModelType(get_from_dict_or_env(values, "model_name", "model_name", str(ModelType.ERNIE)))
access_token = get_from_dict_or_env(values, "access_token", "ERNIE_ACCESS_TOKEN", "")
if not access_token:
raise ValueError("No access token provided.")
values["model_name"] = model_name
values["access_token"] = access_token
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
request = CompletionRequest(messages=[Message("user",prompt)])
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
request = CompletionRequest(messages=[Message("user", prompt)])
bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
try:
# your code here
@@ -81,9 +79,8 @@ class ErnieLLM(LLM):
return response
except Exception as e:
# handle the exception
print("exception:",e)
print("exception:", e)
return e.__str__()
@property
def _llm_type(self) -> str:
@@ -94,28 +91,25 @@ class ErnieLLM(LLM):
# return {
# "name": "ernie",
# }
def _get_model_service_url(model_name) -> str:
# print("_get_model_service_url model_name: ",model_name)
return MODEL_SERVICE_BASE_URL+MODEL_SERVICE_Suffix[model_name]
return MODEL_SERVICE_BASE_URL + MODEL_SERVICE_Suffix[model_name]
class ErnieChat(LLM):
model_name: ModelType
access_token: str
access_token: str
prefix_messages: List = Field(default_factory=list)
id: str = ""
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
msg = user_message(prompt)
request = CompletionRequest(messages=self.prefix_messages+[msg])
bot = ErnieBot(_get_model_service_url(self.model_name),self.access_token,request)
request = CompletionRequest(messages=self.prefix_messages + [msg])
bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
try:
# your code here
response = bot.get_response().result
@@ -127,11 +121,11 @@ class ErnieChat(LLM):
except Exception as e:
# handle the exception
raise e
def _get_id(self) -> str:
return self.id
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "ernie"
\ No newline at end of file
return "ernie"
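An illustrative sketch for ErnieLLM above; the token value is a placeholder supplied through the ERNIE_ACCESS_TOKEN environment variable that validate_environment checks.
os.environ["ERNIE_ACCESS_TOKEN"] = "<your-access-token>"  # placeholder
llm = ErnieLLM(model_name=ModelType.ERNIE)
print(llm("Hello"))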
from dataclasses import asdict, dataclass
from typing import List
@@ -7,27 +5,32 @@ from pydantic import BaseModel, Field
from enum import Enum
class MessageRole(str, Enum):
USER = "user"
BOT = "assistant"
@dataclass
class Message:
role: str
content: str
@dataclass
class CompletionRequest:
messages: List[Message]
stream: bool = False
user: str = ""
@dataclass
class Usage:
prompt_tokens: int
completion_tokens: int
total_tokens: int
@dataclass
class CompletionResponse:
id: str
@@ -42,12 +45,14 @@ class CompletionResponse:
is_safe: bool = False
is_truncated: bool = False
class ErrorResponse(BaseModel):
error_code: int = Field(...)
error_msg: str = Field(...)
id: str = Field(...)
class ErnieBot():
class ErnieBot:
url: str
access_token: str
request: CompletionRequest
@@ -64,17 +69,19 @@ class ErnieBot():
headers = {'Content-Type': 'application/json'}
params = {'access_token': self.access_token}
request_dict = asdict(self.request)
response = requests.post(self.url, params=params,data=json.dumps(request_dict), headers=headers)
request_dict = asdict(self.request)
response = requests.post(self.url, params=params, data=json.dumps(request_dict), headers=headers)
# print(response.json())
try:
return CompletionResponse(**response.json())
except Exception as e:
print(e)
raise Exception(response.json())
def user_message(prompt: str) -> Message:
return Message(MessageRole.USER, prompt)
def bot_message(prompt: str) -> Message:
return Message(MessageRole.BOT, prompt)
\ No newline at end of file
return Message(MessageRole.BOT, prompt)
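A sketch of calling the SDK above directly; the URL and token are placeholders, and ErnieBot's constructor (folded out of this diff) is assumed to take url, access_token, and request positionally, as its callers in ernie.py do.
req = CompletionRequest(messages=[user_message("Hello")])
bot = ErnieBot("<model-service-url>", "<access-token>", req)
print(bot.get_response().result)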
import os
import requests
from typing import Dict, Optional,List,Any,Mapping,Iterator
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
from langchain.llms.base import BaseLLM,LLM
from langchain.llms.base import BaseLLM, LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import qianfan
from qianfan import ChatCompletion
# enable the LLM cache
# langchain.llm_cache = InMemoryCache()
class ChatERNIESerLLM(LLM):
# model service URL
chat_completion:ChatCompletion = None
chat_completion: ChatCompletion = None
# url: str = "http://127.0.0.1:8000"
chat_history: dict = []
chat_history: dict = []
out_stream: bool = False
cache: bool = False
model_name:str = "ERNIE-Bot"
model_name: str = "ERNIE-Bot"
# def __init__(self):
# self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")
@property
def _llm_type(self) -> str:
return self.model_name
def get_num_tokens(self, text: str) -> int:
return len(text)
def convert_data(self,data):
@staticmethod
def convert_data(data):
result = []
for item in data:
result.append({'q': item[0], 'a': item[1]})
return result
def _call(self, prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream = False,
stream=False,
**kwargs: Any) -> str:
resp = self.chat_completion.do(model=self.model_name,messages=[{
resp = self.chat_completion.do(model=self.model_name, messages=[{
"role": "user",
"content": prompt
}])
@@ -54,26 +54,26 @@ class ChatERNIESerLLM(LLM):
return resp.body["result"]
async def _post_stream(self,
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream=False) -> Any:
query: Dict,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream=False) -> Any:
"""Send a streaming POST request
"""
async for r in await self.chat_completion.ado(model=self.model_name,messages=[query], stream=stream):
async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
assert r.code == 200
if run_manager:
for callable in run_manager.get_sync().handlers:
await callable.on_llm_new_token(r.body["result"])
for _callable in run_manager.get_sync().handlers:
await _callable.on_llm_new_token(r.body["result"])
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
await self._post_stream(query={
"role": "user",
"content": prompt
},stream=True,run_manager=run_manager)
"role": "user",
"content": prompt
}, stream=True, run_manager=run_manager)
return ''
\ No newline at end of file
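An illustrative construction of ChatERNIESerLLM above, mirroring the commented-out constructor hint in the class; the ak/sk values are placeholders.
llm = ChatERNIESerLLM(chat_completion=qianfan.ChatCompletion(ak="<ak>", sk="<sk>"))
print(llm("Hello"))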
@@ -4,6 +4,7 @@ import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
from peft import PeftModel
class ModelLoader:
def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
@@ -23,18 +24,19 @@ class ModelLoader:
def models(self):
return self.model, self.tokenizer
def collator(self):
return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)
def load_lora(self,ckpt_path,name="default"):
#save GPU memory during training
peft_loaded = PeftModel.from_pretrained(self.base_model,ckpt_path,adapter_name=name)
self.model = peft_loaded.merge_and_unload()
def load_lora(self, ckpt_path, name="default"):
# save GPU memory during training
_peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
self.model = _peft_loaded.merge_and_unload()
print(f"Load LoRA model successfully!")
def load_loras(self,ckpt_paths,name="default"):
if len(ckpt_paths)==0:
def load_loras(self, ckpt_paths, name="default"):
global peft_loaded
if len(ckpt_paths) == 0:
return
first = True
for name, path in ckpt_paths.items():
@@ -42,12 +44,12 @@ class ModelLoader:
if first:
peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=name)
first = False
else:
peft_loaded.load_adapter(path,adapter_name=name)
else:
peft_loaded.load_adapter(path, adapter_name=name)
peft_loaded.set_adapter(name)
self.model = peft_loaded
def load_prefix(self,ckpt_path):
def load_prefix(self, ckpt_path):
prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
@@ -56,4 +58,3 @@ class ModelLoader:
self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
self.model.transformer.prefix_encoder.float()
print(f"Load prefix model successfully!")
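A hedged sketch of driving ModelLoader above; the model name and adapter path are placeholders, and models() is used as the plain method shown in this diff.
loader = ModelLoader("THUDM/chatglm3-6b")
loader.load_lora("output/lora-ckpt")   # hypothetical adapter directory
model, tokenizer = loader.models()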
@@ -2,7 +2,7 @@ import logging
import os
from typing import Any, Dict, List, Mapping, Optional
from langchain.llms.base import BaseLLM,LLM
from langchain.llms.base import BaseLLM, LLM
from langchain.schema import LLMResult
from langchain.utils import get_from_dict_or_env
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
@@ -16,18 +16,17 @@ from .xinghuo.ws import SparkAPI
logger = logging.getLogger(__name__)
text = []
text =[]
# length = 0
def getText(role,content):
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
def getText(role, content):
jsoncon = {"role": role, "content": content}
text.append(jsoncon)
return text
def getlength(text):
length = 0
for content in text:
@@ -36,11 +35,12 @@ def getlength(text):
length += leng
return length
def checklen(text):
while (getlength(text) > 8000):
del text[0]
return text
def checklen(_text):
while getlength(_text) > 8000:
del _text[0]
return _text
class SparkLLM(LLM):
"""
@@ -62,16 +62,16 @@ class SparkLLM(LLM):
None,
description="version",
)
api: SparkAPI = Field(
api: SparkAPI = Field(
None,
description="api",
)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
def validate_environment(self, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
appid = get_from_dict_or_env(values, "appid", "XH_APPID", "")
api_key = get_from_dict_or_env(values, "api_key", "XH_API_KEY", "")
api_secret = get_from_dict_or_env(values, "api_secret", "XH_API_SECRET", "")
@@ -84,23 +84,19 @@ class SparkLLM(LLM):
raise ValueError("No api_key provided.")
if not api_secret:
raise ValueError("No api_secret provided.")
values["appid"] = appid
values["api_key"] = api_key
values["api_secret"] = api_secret
api=SparkAPI(appid,api_key,api_secret,version)
values["api"]=api
api = SparkAPI(appid, api_key, api_secret, version)
values["api"] = api
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
question = self.getText("user",prompt)
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
question = self.getText("user", prompt)
try:
# your code here
# SparkApi.main(self.appid,self.api_key,self.api_secret,self.Spark_url,self.domain,question)
@@ -109,20 +105,18 @@ class SparkLLM(LLM):
return response
except Exception as e:
# handle the exception
print("exception:",e)
print("exception:", e)
raise e
def getText(self,role,content):
def getText(self, role, content):
text = []
jsoncon = {}
jsoncon["role"] = role
jsoncon["content"] = content
text.append(jsoncon)
return text
return text
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "xinghuo"
\ No newline at end of file
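A short sketch for SparkLLM above; it assumes the XH_APPID, XH_API_KEY, and XH_API_SECRET environment variables that validate_environment reads have been set, and the version value is illustrative.
llm = SparkLLM(version=2.0)
print(llm("Hello"))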
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from typing import Dict, List, Optional
from typing import Dict, List, Optional
from transformers import PreTrainedModel, PreTrainedTokenizer
class WrapperLLM(LLM):
tokenizer: PreTrainedTokenizer = None
model: PreTrainedModel = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
def validate_environment(self, values: Dict) -> Dict:
"""Validate the environment."""
# print(values)
if values.get("model") is None:
@@ -19,16 +19,12 @@ class WrapperLLM(LLM):
raise ValueError("No tokenizer provided.")
return values
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
resp,his = self.model.chat(self.tokenizer,prompt)
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
resp, his = self.model.chat(self.tokenizer, prompt)
return resp
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "wrapper"
\ No newline at end of file
return "wrapper"
from abc import ABC, abstractmethod
class BaseCallback(ABC):
@abstractmethod
def filter(self,title:str,content:str) -> bool: #returning True discards the current paragraph
pass
\ No newline at end of file
def filter(self, title: str, content: str) -> bool:  # returning True discards the current paragraph
pass
@@ -25,7 +25,7 @@ class ChineseTextSplitter(CharacterTextSplitter):
sent_list.append(ele)
return sent_list
def split_text(self, text: str) -> List[str]: ## the logic here needs further refinement
def split_text(self, text: str) -> List[str]:  ## the logic here needs further refinement
if self.pdf:
text = re.sub(r"\n{3,}", r"\n", text)
text = re.sub('\s', " ", text)
@@ -56,6 +56,6 @@ class ChineseTextSplitter(CharacterTextSplitter):
ele_id = ele1_ls.index(ele_ele1)
ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
id = ls.index(ele)
ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
_id = ls.index(ele)
ls = ls[:_id] + [i for i in ele1_ls if i] + ls[_id + 1:]
return ls
# sentence split length
SENTENCE_SIZE = 100
ZH_TITLE_ENHANCE = False
\ No newline at end of file
ZH_TITLE_ENHANCE = False
@@ -33,7 +33,7 @@ def is_possible_title(
title_max_word_length: int = 20,
non_alpha_threshold: float = 0.5,
) -> bool:
"""Checks to see if the text passes all of the checks for a valid title.
"""Checks to see if the text passes all the checks for a valid title.
Parameters
----------
import psycopg2
from psycopg2 import OperationalError, InterfaceError
class UPostgresDB:
'''
"""
psycopg2.connect(
dsn      # connection parameters, as keyword arguments or a DSN string
host     # hostname of the database server
@@ -18,8 +19,9 @@
sslkey   # private-key file name
sslcert  # certificate file name
)
'''
def __init__(self, host, database, user, password,port = 5432):
"""
def __init__(self, host, database, user, password, port=5432):
self.host = host
self.database = database
self.user = user
@@ -35,7 +37,7 @@ class UPostgresDB:
database=self.database,
user=self.user,
password=self.password,
port = self.port
port=self.port
)
self.cur = self.conn.cursor()
except Exception as e:
@@ -45,7 +47,7 @@ class UPostgresDB:
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query)
self.cur.execute(query)
self.conn.commit()
except InterfaceError as e:
print(f"Database connection already closed: {e}")
@@ -53,8 +55,8 @@ class UPostgresDB:
print(f"Problem with the database connection: {e}")
self.connect()
self.retry_execute(query)
except Exception as e:
print(f"Error executing SQL statement: {e}")
except Exception as e:
print(f"Error executing SQL statement: {e}")
self.conn.rollback()
def retry_execute(self, query):
@@ -69,16 +71,16 @@ class UPostgresDB:
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query, args)
self.conn.commit()
self.cur.execute(query, args)
self.conn.commit()
except InterfaceError as e:
print(f"Database connection already closed: {e}")
except OperationalError as e:
print(f"Problem with the database operation: {e}")
self.connect()
self.retry_execute_args(query, args)
except Exception as e:
print(f"Error executing SQL statement: {e}")
except Exception as e:
print(f"Error executing SQL statement: {e}")
self.conn.rollback()
def retry_execute_args(self, query, args):
@@ -89,7 +91,6 @@ class UPostgresDB:
print(f"Retrying the SQL statement failed again: {type(e).__name__}: {e}")
self.conn.rollback()
def search(self, query, params=None):
if self.conn is None or self.conn.closed:
self.connect()
@@ -97,7 +98,7 @@ class UPostgresDB:
def fetchall(self):
return self.cur.fetchall()
def fetchone(self):
return self.cur.fetchone()
@@ -109,8 +110,8 @@ class UPostgresDB:
try:
if self.conn is None or self.conn.closed:
self.connect()
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
from .c_db import UPostgresDB
import json
TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
@@ -13,14 +14,15 @@ COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
"""
class CUser:
def __init__(self, db: UPostgresDB) -> None:
self.db = db
def insert(self, value):
query = f"INSERT INTO c_user(user_id, account, password) VALUES (%s,%s,%s)"
self.db.execute_args(query, ((value[0],value[1],value[2])))
self.db.execute_args(query, (value[0], value[1], value[2]))
def create_table(self):
query = TABLE_USER
self.db.execute(query)
\ No newline at end of file
self.db.execute(query)
from .c_db import UPostgresDB
import json
TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
@@ -17,6 +18,7 @@ COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
"""
class Chat:
def __init__(self, db: UPostgresDB) -> None:
self.db = db
@@ -24,9 +26,9 @@ class Chat:
# insert a row
def insert(self, value):
query = f"INSERT INTO chat(chat_id, user_id, info, deleted) VALUES (%s,%s,%s,%s)"
self.db.execute_args(query, ((value[0],value[1],value[2],value[3])))
self.db.execute_args(query, (value[0], value[1], value[2], value[3]))
# create the table
def create_table(self):
query = TABLE_CHAT
self.db.execute(query)
\ No newline at end of file
self.db.execute(query)
from .c_db import UPostgresDB
import json
TABLE_CHAT = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
@@ -21,6 +22,7 @@ COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
"""
class TurnQa:
def __init__(self, db: UPostgresDB) -> None:
self.db = db
@@ -28,9 +30,9 @@ class TurnQa:
# insert a row
def insert(self, value):
query = f"INSERT INTO turn_qa(turn_id, chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s,%s)"
self.db.execute_args(query, ((value[0],value[1],value[2],value[3],value[4],value[5])))
self.db.execute_args(query, (value[0], value[1], value[2], value[3], value[4], value[5]))
# create the table
def create_table(self):
query = TABLE_CHAT
self.db.execute(query)
\ No newline at end of file
self.db.execute(query)
import os, sys
from os import path
sys.path.append("../")
sys.path.append("../")
from abc import ABC, abstractmethod
import json
from typing import List,Any,Tuple,Dict
from typing import List, Any, Tuple, Dict
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore,str2hash_base64
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64
class DocumentCallback(ABC):
@abstractmethod #pre-store document processing for the vector store
def before_store(self,docstore:PgSqlDocstore,documents):
@abstractmethod  # pre-store document processing for the vector store
def before_store(self, docstore: PgSqlDocstore, documents):
pass
@abstractmethod #post-search document processing -- used to rebuild structure
def after_search(self,docstore:PgSqlDocstore,documents:List[Tuple[Document, float]],number:int = 1000) -> List[Tuple[Document, float]]: #post-search document processing
@abstractmethod  # post-search document processing -- used to rebuild structure
def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
List[Tuple[Document, float]]:  # post-search document processing
pass
class DefaultDocumentCallback(DocumentCallback):
def before_store(self,docstore:PgSqlDocstore,documents):
def before_store(self, docstore: PgSqlDocstore, documents):
output_doc = []
for doc in documents:
if "next_doc" in doc.metadata:
@@ -27,22 +29,24 @@ class DefaultDocumentCallback(DocumentCallback):
doc.metadata.pop("next_doc")
output_doc.append(doc)
return output_doc
def after_search(self,docstore:PgSqlDocstore,documents:List[Tuple[Document, float]],number:int = 1000) -> List[Tuple[Document, float]]: #post-search document processing
output_doc:List[Tuple[Document, float]] = []
def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
List[Tuple[Document, float]]:  # post-search document processing
output_doc: List[Tuple[Document, float]] = []
exist_hash = []
for doc,score in documents:
for doc, score in documents:
print(exist_hash)
dochash = str2hash_base64(doc.page_content)
if dochash in exist_hash:
continue
else:
exist_hash.append(dochash)
output_doc.append((doc,score))
output_doc.append((doc, score))
if len(output_doc) > number:
return output_doc
fordoc = doc
while ("next_hash" in fordoc.metadata):
if len(fordoc.metadata["next_hash"])>0:
while "next_hash" in fordoc.metadata:
if len(fordoc.metadata["next_hash"]) > 0:
if fordoc.metadata["next_hash"] in exist_hash:
break
else:
@@ -50,11 +54,11 @@ class DefaultDocumentCallback(DocumentCallback):
content = docstore.TXT_DOC.search(fordoc.metadata["next_hash"])
if content:
fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
output_doc.append((fordoc,score))
output_doc.append((fordoc, score))
if len(output_doc) > number:
return output_doc
return output_doc
else:
break
else:
break
return output_doc
\ No newline at end of file
return output_doc
import psycopg2
class PostgresDB:
'''
"""
psycopg2.connect(
dsn      # connection parameters, as keyword arguments or a DSN string
host     # hostname of the database server
@@ -17,8 +18,9 @@
sslkey   # private-key file name
sslcert  # certificate file name
)
'''
def __init__(self, host, database, user, password,port = 5432):
"""
def __init__(self, host, database, user, password, port=5432):
self.host = host
self.database = database
self.user = user
@@ -28,28 +30,29 @@ class PostgresDB:
self.cur = None
def connect(self):
self.conn = psycopg2.connect(
self.conn = psycopg2.connect(
host=self.host,
database=self.database,
user=self.user,
password=self.password,
port = self.port
port=self.port
)
self.cur = self.conn.cursor()
def execute(self, query):
try:
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
def execute_args(self, query, args):
try:
self.cur.execute(query, args)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.cur.execute(query, args)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
def search(self, query, params=None):
@@ -64,8 +67,8 @@ class PostgresDB:
def format(self, query):
try:
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.cur.execute(query)
self.conn.commit()
except Exception as e:
print(f"An error occurred: {e}")
self.conn.rollback()
import sys
from os import path
# effectively adds the current directory to PYTHONPATH
sys.path.append(path.dirname(path.abspath(__file__)))
from typing import List,Union,Dict,Optional
from typing import List, Union, Dict, Optional
from langchain.docstore.base import AddableMixin, Docstore
from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector
import json,hashlib,base64
import json, hashlib, base64
from langchain.schema import Document
def str2hash_base64(input:str) -> str:
def str2hash_base64(inp: str) -> str:
# return f"%s" % hash(input)
return base64.b64encode(hashlib.sha1(input.encode()).digest()).decode()
return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()
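For reference, a concrete value from the helper above (standard base64 of the 20-byte SHA-1 digest, so always 28 characters including padding):
assert str2hash_base64("abc") == "qZk+NkcGgWq6PiVxeFDCbJzQ2J0="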
class PgSqlDocstore(Docstore,AddableMixin):
host:str
dbname:str
username:str
password:str
port:str
class PgSqlDocstore(Docstore, AddableMixin):
host: str
dbname: str
username: str
password: str
port: str
'''
Note: __getstate__ and __setstate__ are overridden so langchain can serialize this store with pickle; the returned dict carries the pgsql connection info.
'''
def __getstate__(self):
return {"host":self.host,"dbname":self.dbname,"username":self.username,"password":self.password,"port":self.port}
return {"host": self.host, "dbname": self.dbname, "username": self.username, "password": self.password,
"port": self.port}
def __setstate__(self, info):
self.__init__(info)
def __init__(self,info:dict,reset:bool = False):
def __init__(self, info: dict, reset: bool = False):
self.host = info["host"]
self.dbname = info["dbname"]
self.username = info["username"]
self.password = info["password"]
self.port = info["port"] if "port" in info else "5432";
self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password,port=self.port)
self.port = info["port"] if "port" in info else "5432"
self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password, port=self.port)
self.TXT_DOC = TxtDoc(self.pgdb)
self.VEC_TXT = TxtVector(self.pgdb)
if reset:
@@ -48,12 +50,15 @@ class PgSqlDocstore(Docstore,AddableMixin):
self.VEC_TXT.drop_table()
self.TXT_DOC.create_table()
self.VEC_TXT.create_table()
def __sub_init__(self):
if not self.pgdb.conn:
self.pgdb.connect()
'''
Look up the paragraph text for a vector in the local store and return it wrapped in a Document.
'''
def search(self, search: str) -> Union[str, Document]:
if not self.pgdb.conn:
self.__sub_init__()
@@ -63,40 +68,44 @@ class PgSqlDocstore(Docstore,AddableMixin):
return Document(page_content=content[0], metadata=json.loads(content[1]))
else:
return Document()
'''
Delete the texts for the given vectors from the local store, in batch.
'''
def delete(self, ids: List) -> None:
if not self.pgdb.conn:
self.__sub_init__()
pids = []
for id in ids:
anwser = self.VEC_TXT.search(id)
for item in ids:
anwser = self.VEC_TXT.search(item)
pids.append(anwser[0])
self.VEC_TXT.delete(ids)
self.TXT_DOC.delete(pids)
'''
Add vector and text records to the local store:
[vector_id, Document(page_content=question, metadata=dict(paragraph=paragraph text))]
'''
def add(self, texts: Dict[str, Document]) -> None:
# for vec,doc in texts.items():
# paragraph_id = self.TXT_DOC.insert(doc.metadata["paragraph"])
# self.VEC_TXT.insert(vector_id=vec,paragraph_id=paragraph_id,text=doc.page_content)
if not self.pgdb.conn:
self.__sub_init__()
paragraph_hashs = [] #hash,text
paragraph_hashs = [] # hash,text
paragraph_txts = []
vec_inserts = []
for vec,doc in texts.items():
for vec, doc in texts.items():
txt_hash = str2hash_base64(doc.metadata["paragraph"])
print(txt_hash)
vec_inserts.append((vec,doc.page_content,txt_hash))
vec_inserts.append((vec, doc.page_content, txt_hash))
if txt_hash not in paragraph_hashs:
paragraph_hashs.append(txt_hash)
paragraph = doc.metadata["paragraph"]
doc.metadata.pop("paragraph")
paragraph_txts.append((txt_hash,paragraph,json.dumps(doc.metadata,ensure_ascii=False)))
paragraph_txts.append((txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False)))
# print(paragraph_txts)
self.TXT_DOC.insert(paragraph_txts)
self.VEC_TXT.insert(vec_inserts)
@@ -105,7 +114,7 @@ class PgSqlDocstore(Docstore,AddableMixin):
class InMemorySecondaryDocstore(Docstore, AddableMixin):
"""Simple in memory docstore in the form of a dict."""
def __init__(self, _dict: Optional[Dict[str, Document]] = None,_sec_dict: Optional[Dict[str, Document]] = None):
def __init__(self, _dict: Optional[Dict[str, Document]] = None, _sec_dict: Optional[Dict[str, Document]] = None):
"""Initialize with dict."""
self._dict = _dict if _dict is not None else {}
self._sec_dict = _sec_dict if _sec_dict is not None else {}
@@ -123,19 +132,19 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
if overlapping:
raise ValueError(f"Tried to add ids that already exist: {overlapping}")
self._dict = {**self._dict, **texts}
dict1 = {}
dict_sec = {}
for vec,doc in texts.items():
for vec, doc in texts.items():
txt_hash = str2hash_base64(doc.metadata["paragraph"])
metadata=doc.metadata
metadata = doc.metadata
paragraph = metadata.pop('paragraph')
# metadata.update({"paragraph_id":txt_hash})
metadata['paragraph_id']=txt_hash
dict_sec[txt_hash] = Document(page_content=paragraph,metadata=metadata)
dict1[vec] = Document(page_content=doc.page_content,metadata={'paragraph_id':txt_hash})
self._dict = {**self._dict, **dict1}
self._sec_dict = {**self._sec_dict, **dict_sec}
metadata['paragraph_id'] = txt_hash
dict_sec[txt_hash] = Document(page_content=paragraph, metadata=metadata)
dict1[vec] = Document(page_content=doc.page_content, metadata={'paragraph_id': txt_hash})
self._dict = {**self._dict, **dict1}
self._sec_dict = {**self._sec_dict, **dict_sec}
def delete(self, ids: List) -> None:
"""Deleting IDs from in memory dictionary."""
@@ -143,7 +152,7 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
if not overlapping:
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
for _id in ids:
self._sec_dict.pop(self._dict[id].metadata['paragraph_id'])
self._sec_dict.pop(self._dict[_id].metadata['paragraph_id'])
self._dict.pop(_id)
def search(self, search: str) -> Union[str, Document]:
@@ -159,4 +168,4 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
return f"ID {search} not found."
else:
print(self._dict[search].page_content)
return self._sec_dict[self._dict[search].metadata['paragraph_id']]
\ No newline at end of file
return self._sec_dict[self._dict[search].metadata['paragraph_id']]
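An illustrative construction of PgSqlDocstore above, passing the connection dict in the shape its __init__ expects; the credentials and vector id are placeholders.
store = PgSqlDocstore({"host": "localhost", "dbname": "lae", "username": "postgres", "password": "<password>", "port": "5432"}, reset=False)
doc = store.search("<vector-id>")  # Document rebuilt from the vec_txt/txt_doc tables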
from .k_db import PostgresDB
# paragraph_id BIGSERIAL primary key,
# paragraph_id BIGSERIAL primary key,
TABLE_TXT_DOC = """
create table txt_doc (
hash varchar(40) primary key,
@@ -11,6 +12,8 @@ TABLE_TXT_DOC_HASH_INDEX = """
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
"""
# CREATE UNIQUE INDEX idx_name ON your_table (column_name);
class TxtDoc:
def __init__(self, db: PostgresDB) -> None:
@@ -21,19 +24,20 @@ class TxtDoc:
args = []
for value in texts:
value = list(value)
query+= "(%s,%s,%s),"
query += "(%s,%s,%s),"
args.extend(value)
query = query[:len(query)-1]
query = query[:len(query) - 1]
query += f"ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
self.db.execute_args(query,args)
def delete(self,ids):
for id in ids:
query = f"delete FROM txt_doc WHERE hash = %s" % (id)
self.db.execute_args(query, args)
def delete(self, ids):
for item in ids:
query = f"delete FROM txt_doc WHERE hash = %s" % item
self.db.execute(query)
def search(self, id):
def search(self, item):
query = "SELECT text,matadate FROM txt_doc WHERE hash = %s"
self.db.execute_args(query,[id])
self.db.execute_args(query, [item])
answer = self.db.fetchall()
if len(answer) > 0:
return answer[0]
@@ -60,4 +64,3 @@ class TxtDoc:
query = "DROP TABLE txt_doc"
self.db.format(query)
print("drop table txt_doc ok")
from .k_db import PostgresDB
TABLE_VEC_TXT = """
CREATE TABLE vec_txt (
vector_id varchar(36) PRIMARY KEY,
@@ -6,7 +7,9 @@ CREATE TABLE vec_txt (
paragraph_id varchar(40) not null
)
"""
#025a9bee-2eb2-47f5-9722-525e05a0442b
# 025a9bee-2eb2-47f5-9722-525e05a0442b
class TxtVector:
def __init__(self, db: PostgresDB) -> None:
self.db = db
@@ -16,19 +19,21 @@ class TxtVector:
args = []
for value in vectors:
value = list(value)
query+= "(%s,%s,%s),"
query += "(%s,%s,%s),"
args.extend(value)
query = query[:len(query)-1]
query = query[:len(query) - 1]
query += f"ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text,paragraph_id = EXCLUDED.paragraph_id;"
# query += ";"
self.db.execute_args(query,args)
def delete(self,ids):
for id in ids:
query = f"delete FROM vec_txt WHERE vector_id = '%s'" % (id,)
self.db.execute_args(query, args)
def delete(self, ids):
for item in ids:
query = f"delete FROM vec_txt WHERE vector_id = '%s'" % (item,)
self.db.execute(query)
def search(self, search: str):
query = f"SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
self.db.execute_args(query,[search])
self.db.execute_args(query, [search])
answer = self.db.fetchall()
print(answer)
return answer[0]
@@ -48,4 +53,4 @@ class TxtVector:
if exists:
query = "DROP TABLE vec_txt"
self.db.format(query)
print("drop table vec_txt ok")
\ No newline at end of file
print("drop table vec_txt ok")
import sys
sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa
"""Test the connection to the chat-related database."""
c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
chat = Chat(db=c_db)
c_user = CUser(db=c_db)
turn_qa = TurnQa(db=c_db)
chat.create_table()
c_user.create_table()
turn_qa.create_table()
# chat_id, user_id, info, deleted
chat.insert(["3333", "1111", "没有info", 0])
def test():
c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
chat = Chat(db=c_db)
c_user = CUser(db=c_db)
turn_qa = TurnQa(db=c_db)
chat.create_table()
c_user.create_table()
turn_qa.create_table()
# chat_id, user_id, info, deleted
chat.insert(["3333", "1111", "没有info", 0])
# user_id, account, password
c_user.insert(["111", "zhangsan", "111111"])
# turn_id, chat_id, question, answer, turn_number, is_last
turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])
# user_id, account, password
c_user.insert(["111", "zhangsan", "111111"])
# turn_id, chat_id, question, answer, turn_number, is_last
turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])
\ No newline at end of file
if __name__ == "__main__":
test()
import sys
sys.path.append("../")
import time
from src.loader.load import loads_path,loads
sys.path.append("../")
from src.loader.load import loads_path
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
VEC_DB_DBNAME,
@@ -18,24 +18,27 @@ from src.config.consts import (
from src.loader.callback import BaseCallback
# When the returned text contains "思考题" (study questions), it is ignored by default.
class localCallback(BaseCallback):
def filter(self,title:str,content:str) -> bool:
if len(title+content) == 0:
def filter(self, title: str, content: str) -> bool:
if len(title + content) == 0:
return True
return (len(title+content) / (len(title.splitlines())+len(content.splitlines())) < 20) or "思考题" in title
return (len(title + content) / (len(title.splitlines()) + len(content.splitlines())) < 20) or "思考题" in title
"""Test ingesting source material into storage (pgsql and faiss)"""
def test_faiss_from_dir():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port":VEC_DB_PORT,"host":VEC_DB_HOST,"dbname":VEC_DB_DBNAME,"username":VEC_DB_USER,"password":VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=True)
docs = loads_path(KNOWLEDGE_PATH,mode="paged",sentence_size=512,callbacks=[localCallback()])
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=True)
docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[localCallback()])
print(len(docs))
last_doc = None
docs1 = []
@@ -45,7 +48,8 @@ def test_faiss_from_dir():
continue
if "font-size" not in doc.metadata or "page_number" not in doc.metadata:
continue
if doc.metadata["font-size"] == last_doc.metadata["font-size"] and doc.metadata["page_number"] == last_doc.metadata["page_number"] and len(doc.page_content)+len(last_doc.page_content) < 512/4*3:
if doc.metadata["font-size"] == last_doc.metadata["font-size"] and doc.metadata["page_number"] == \
last_doc.metadata["page_number"] and len(doc.page_content) + len(last_doc.page_content) < 512 / 4 * 3:
last_doc.page_content += doc.page_content
else:
docs1.append(last_doc)
@@ -56,22 +60,26 @@ def test_faiss_from_dir():
print(len(docs))
print(vecstore_faiss._faiss.index.ntotal)
for i in range(0, len(docs), 300):
vecstore_faiss._add_documents(docs[i:i+300 if i+300<len(docs) else len(docs)],need_split=True)
vecstore_faiss._add_documents(docs[i:i + 300 if i + 300 < len(docs) else len(docs)], need_split=True)
print(vecstore_faiss._faiss.index.ntotal)
vecstore_faiss._save_local()
"""Test query results from the faiss vector store"""
def test_faiss_load():
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port":VEC_DB_PORT,"host":VEC_DB_HOST,"dbname":VEC_DB_DBNAME,"username":VEC_DB_USER,"password":VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("征信业务有什么情况")))
if __name__ == "__main__":
# test_faiss_from_dir()
test_faiss_load()
\ No newline at end of file
test_faiss_load()