Commit 493cdd59 by 陈正乐

Code formatting (代码格式化)

parent 9d8ee0af
@@ -2,19 +2,19 @@
 # Document storage database configuration
 # =============================
 VEC_DB_HOST = 'localhost'
-VEC_DB_DBNAME='lae'
-VEC_DB_USER='postgres'
-VEC_DB_PASSWORD='chenzl'
-VEC_DB_PORT='5432'
+VEC_DB_DBNAME = 'lae'
+VEC_DB_USER = 'postgres'
+VEC_DB_PASSWORD = 'chenzl'
+VEC_DB_PORT = '5432'
 # =============================
 # Chat database configuration
 # =============================
 CHAT_DB_HOST = 'localhost'
-CHAT_DB_DBNAME='laechat'
-CHAT_DB_USER='postgres'
-CHAT_DB_PASSWORD='chenzl'
-CHAT_DB_PORT='5432'
+CHAT_DB_DBNAME = 'laechat'
+CHAT_DB_USER = 'postgres'
+CHAT_DB_PASSWORD = 'chenzl'
+CHAT_DB_PORT = '5432'
 # =============================
 # Embedding model path configuration
...
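These settings are read as plain module constants; below is a hedged sketch of a typical consumer (the src.config import path is an assumption):

    # Hypothetical consumer; the import path is an assumption.
    import psycopg2
    from src.config import (VEC_DB_HOST, VEC_DB_DBNAME, VEC_DB_USER,
                            VEC_DB_PASSWORD, VEC_DB_PORT)

    conn = psycopg2.connect(host=VEC_DB_HOST, dbname=VEC_DB_DBNAME, user=VEC_DB_USER,
                            password=VEC_DB_PASSWORD, port=VEC_DB_PORT)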
 import os
-from typing import Dict, Optional,List
-from langchain.llms.base import BaseLLM,LLM
+from typing import Dict, Optional, List
+from langchain.llms.base import BaseLLM, LLM
 from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
 import torch
-from transformers import AutoTokenizer, AutoModel,AutoConfig,AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForCausalLM
 from transformers.generation.utils import GenerationConfig
 from pydantic import root_validator


 class BaichuanLLM(LLM):
     model_name: str = "baichuan-inc/Baichuan-13B-Chat"
     quantization_bit: Optional[int] = None
@@ -19,17 +17,16 @@ class BaichuanLLM(LLM):
     tokenizer: AutoTokenizer = None
     model: AutoModel = None

     def _llm_type(self) -> str:
         return "chatglm_local"

     @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(self, values: Dict) -> Dict:
         if not values["model_name"]:
             raise ValueError("No model name provided.")
         model_name = values["model_name"]
-        tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=False,trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch.float16,
@@ -43,7 +40,7 @@ class BaichuanLLM(LLM):
             print(f"Quantized to {values['quantization_bit']} bit")
             model = model.quantize(values["quantization_bit"]).cuda()
         else:
-            model=model.half().cuda()
+            model = model.half().cuda()
         model = model.eval()
@@ -51,14 +48,9 @@ class BaichuanLLM(LLM):
         values["model"] = model
         return values

-    def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
-        message = []
-        message.append({"role": "user", "content": prompt})
-        resp = self.model.chat(self.tokenizer,message)
+    def _call(self, prompt: str, stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
+        message = [{"role": "user", "content": prompt}]
+        resp = self.model.chat(self.tokenizer, message)
         # print(f"prompt:{prompt}\nresponse:{resp}\n")
         return resp
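A minimal usage sketch of BaichuanLLM follows (the import path is an assumption). One caveat on the hunk above: pydantic v1's root_validator is class-level, so the first parameter renamed to self still receives the class; cls would be the clearer name.

    from src.llm.baichuan import BaichuanLLM   # hypothetical import path

    llm = BaichuanLLM(model_name="baichuan-inc/Baichuan-13B-Chat", quantization_bit=8)
    print(llm("你好"))    # LLM.__call__ dispatches to _call above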
 import os
 import requests
-from typing import Dict, Optional,List,Any,Mapping,Iterator
+from typing import Dict, Optional, List, Any, Mapping, Iterator
 from pydantic import root_validator
 import torch
-from transformers import AutoTokenizer, AutoModel,AutoConfig
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 import langchain
-from langchain.llms.base import BaseLLM,LLM
+from langchain.llms.base import BaseLLM, LLM
 from langchain.cache import InMemoryCache
 from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
 import aiohttp
 import asyncio

 # Enable the LLM cache
 # langchain.llm_cache = InMemoryCache()
@@ -26,17 +27,17 @@ class ChatGLMLocLLM(LLM):
     tokenizer: AutoTokenizer = None
     model: AutoModel = None

     def _llm_type(self) -> str:
         return "chatglm_local"

     # @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    @staticmethod
+    def validate_environment(values: Dict) -> Dict:
         if not values["model_name"]:
             raise ValueError("No model name provided.")
         model_name = values["model_name"]
-        tokenizer = AutoTokenizer.from_pretrained(model_name ,trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         # model = AutoModel.from_pretrained(model_name, config=config, trust_remote_code=True)
         if values["pre_seq_len"]:
@@ -72,18 +73,14 @@ class ChatGLMLocLLM(LLM):
         values["model"] = model
         return values

-    def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
-        resp,his = self.model.chat(self.tokenizer,prompt)
+    def _call(self, prompt: str, stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
+        resp, his = self.model.chat(self.tokenizer, prompt)
         # print(f"prompt:{prompt}\nresponse:{resp}\n")
         return resp

 class ChatGLMSerLLM(LLM):
     # Model service URL
     url: str = "http://127.0.0.1:8000"
     chat_history: dict = []
@@ -95,7 +92,7 @@ class ChatGLMSerLLM(LLM):
         return "chatglm3-6b"

     def get_num_tokens(self, text: str) -> int:
-        resp = self._post(url=self.url+"/tokens",query=self._construct_query(text))
+        resp = self._post(url=self.url + "/tokens", query=self._construct_query(text))
         if resp.status_code == 200:
             resp_json = resp.json()
             predictions = resp_json['response']
@@ -104,27 +101,28 @@ class ChatGLMSerLLM(LLM):
         else:
             return len(text)

-    def convert_data(self,data):
+    @staticmethod
+    def convert_data(data):
         result = []
         for item in data:
             result.append({'q': item[0], 'a': item[1]})
         return result

-    def _construct_query(self, prompt: str,temperature = 0.95) -> Dict:
+    def _construct_query(self, prompt: str, temperature=0.95) -> Dict:
         """Build the request body
         """
         # self.chat_history.append({"role": "user", "content": prompt})
         query = {
             "prompt": prompt,
-            "history":self.chat_history,
+            "history": self.chat_history,
             "max_length": 4096,
             "top_p": 0.7,
             "temperature": temperature
         }
         return query

     @classmethod
-    def _post(self, url: str,
+    def _post(cls, url: str,
               query: Dict) -> Any:
         """POST request
         """
@@ -135,46 +133,50 @@ class ChatGLMSerLLM(LLM):
                          headers=_headers,
                          timeout=300)
         return resp

-    async def _post_stream(self, url: str,
+    @staticmethod
+    async def _post_stream(url: str,
                            query: Dict,
-                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,stream=False) -> Any:
+                           run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, stream=False) -> Any:
         """POST request
         """
         _headers = {"Content_Type": "application/json"}
         async with aiohttp.ClientSession() as sess:
-            async with sess.post(url, json=query,headers=_headers,timeout=300) as response:
+            async with sess.post(url, json=query, headers=_headers, timeout=300) as response:
                 if response.status == 200:
                     if stream and not run_manager:
                         print('not callable')
                     if run_manager:
-                        for callable in run_manager.get_sync().handlers:
-                            await callable.on_llm_start(None,None)
+                        for _callable in run_manager.get_sync().handlers:
+                            await _callable.on_llm_start(None, None)
                     async for chunk in response.content.iter_any():
                         # Process each chunk of data
                         if chunk and run_manager:
-                            for callable in run_manager.get_sync().handlers:
+                            for _callable in run_manager.get_sync().handlers:
                                 # print(chunk.decode("utf-8"),end="")
-                                await callable.on_llm_new_token(chunk.decode("utf-8"))
+                                await _callable.on_llm_new_token(chunk.decode("utf-8"))
                     if run_manager:
-                        for callable in run_manager.get_sync().handlers:
-                            await callable.on_llm_end(None)
+                        for _callable in run_manager.get_sync().handlers:
+                            await _callable.on_llm_end(None)
                 else:
                     raise ValueError(f'glm 请求异常,http code:{response.status}')

     def _call(self, prompt: str,
               stop: Optional[List[str]] = None,
               run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-              stream = False,
+              stream=False,
               **kwargs: Any) -> str:
-        query = self._construct_query(prompt=prompt,temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
+        query = self._construct_query(prompt=prompt,
+                                      temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
         # display("==============================")
         # display(query)
         # post
         if stream or self.out_stream:
             async def _post_stream():
-                await self._post_stream(url=self.url+"/stream",
-                                        query=query,run_manager=run_manager,stream=stream or self.out_stream)
+                await self._post_stream(url=self.url + "/stream",
+                                        query=query, run_manager=run_manager, stream=stream or self.out_stream)
             asyncio.run(_post_stream())
             return ''
         else:
@@ -197,9 +199,10 @@ class ChatGLMSerLLM(LLM):
             run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
             **kwargs: Any,
     ) -> str:
-        query = self._construct_query(prompt=prompt,temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
-        await self._post_stream(url=self.url+"/stream",
-                                query=query,run_manager=run_manager,stream=self.out_stream)
+        query = self._construct_query(prompt=prompt,
+                                      temperature=kwargs["temperature"] if "temperature" in kwargs else 0.95)
+        await self._post_stream(url=self.url + "/stream",
+                                query=query, run_manager=run_manager, stream=self.out_stream)
         return ''

     @property
...
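One pre-existing detail in the hunks above: _headers uses the key Content_Type (underscore) rather than Content-Type; requests still work because aiohttp's json= argument sets the correct header itself. A hedged sketch of exercising both branches, calling _call directly to show the mechanics:

    llm = ChatGLMSerLLM(url="http://127.0.0.1:8000")
    print(llm._call("你好"))          # plain POST branch, returns the full completion
    llm._call("你好", stream=True)    # streaming branch; tokens go to run_manager handlers when LangChain supplies one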
 import os
 import requests
-from typing import Dict, Optional,List,Any,Mapping,Iterator
+from typing import Dict, Optional, List, Any, Mapping, Iterator
 from pydantic import root_validator
 import torch
-from transformers import AutoTokenizer, AutoModel,AutoConfig
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 import langchain
-from langchain.llms.base import BaseLLM,LLM
+from langchain.llms.base import BaseLLM, LLM
 from langchain_openai import OpenAI
 from langchain.cache import InMemoryCache
 from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun


 class ChatGLMSerLLM(OpenAI):
     def get_token_ids(self, text: str) -> List[int]:
@@ -20,9 +21,9 @@ class ChatGLMSerLLM(OpenAI):
         ## Make an HTTP request to fetch the token ids
         url = f"{self.openai_api_base}/num_tokens"
-        query = {"prompt": text,"model": self.model_name}
-        _headers = {"Content_Type": "application/json","Authorization": "chatglm "+self.openai_api_key}
-        resp = self._post(url=url,query=query,headers= _headers)
+        query = {"prompt": text, "model": self.model_name}
+        _headers = {"Content_Type": "application/json", "Authorization": "chatglm " + self.openai_api_key}
+        resp = self._post(url=url, query=query, headers=_headers)
         if resp.status_code == 200:
             resp_json = resp.json()
             print(resp_json)
@@ -32,8 +33,8 @@ class ChatGLMSerLLM(OpenAI):
             return [len(text)]

     @classmethod
-    def _post(self, url: str,
-              query: Dict,headers: Dict) -> Any:
+    def _post(cls, url: str,
+              query: Dict, headers: Dict) -> Any:
         """POST request
         """
         _headers = {"Content_Type": "application/json"}
...
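get_token_ids is overridden here so token counting is delegated to the model service; the wire contract inferred from the code above is POST {openai_api_base}/num_tokens with a {"prompt", "model"} body, answering {"response": [token ids]}. A hedged sketch:

    llm = ChatGLMSerLLM(openai_api_base="http://127.0.0.1:8000/v1",
                        openai_api_key="EMPTY", model_name="chatglm3-6b")
    print(llm.get_num_tokens("你好"))   # len(get_token_ids(...)) via the /num_tokens endpoint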
@@ -2,7 +2,7 @@ import logging
 import os
 from typing import Any, Dict, List, Mapping, Optional

-from langchain.llms.base import BaseLLM,LLM
+from langchain.llms.base import BaseLLM, LLM
 from langchain.schema import LLMResult
 from langchain.utils import get_from_dict_or_env
 from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
@@ -15,6 +15,7 @@ from .ernie_sdk import CompletionRequest, ErnieBot, Message, bot_message, user_m
 logger = logging.getLogger(__name__)


 class ModelType(Enum):
     ERNIE = "ernie"
     ERNIE_LITE = "ernie-lite"
@@ -25,7 +26,7 @@ class ModelType(Enum):
     LLAMA2_13B = "llama2-13b"
     LLAMA2_70B = "llama2-70b"
     QFCN_LLAMA2_7B = "qfcn-llama2-7b"
-    BLOOMZ_7B="bloomz-7b"
+    BLOOMZ_7B = "bloomz-7b"

 MODEL_SERVICE_BASE_URL = "https://aip.baidubce.com/rpc/2.0/"
@@ -43,6 +44,7 @@ MODEL_SERVICE_Suffix = {
     ModelType.BLOOMZ_7B: "ai_custom/v1/wenxinworkshop/chat/bloomz_7b1",
 }


 class ErnieLLM(LLM):
     """
     ErnieLLM is a LLM that uses Ernie to generate text.
@@ -52,7 +54,7 @@ class ErnieLLM(LLM):
     access_token: Optional[str] = ""

     @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(self, values: Dict) -> Dict:
         """Validate the environment."""
         # print(values)
         model_name = ModelType(get_from_dict_or_env(values, "model_name", "model_name", str(ModelType.ERNIE)))
@@ -65,14 +67,10 @@ class ErnieLLM(LLM):
         values["access_token"] = access_token
         return values

-    def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
-        request = CompletionRequest(messages=[Message("user",prompt)])
+    def _call(self, prompt: str, stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
+        request = CompletionRequest(messages=[Message("user", prompt)])
         bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token or "", request)
         try:
             # Your code here
@@ -81,10 +79,9 @@ class ErnieLLM(LLM):
             return response
         except Exception as e:
             # Handle the exception
-            print("exception:",e)
+            print("exception:", e)
             return e.__str__()

     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
@@ -95,9 +92,10 @@ class ErnieLLM(LLM):
     #     "name": "ernie",
     # }


 def _get_model_service_url(model_name) -> str:
     # print("_get_model_service_url model_name: ",model_name)
-    return MODEL_SERVICE_BASE_URL+MODEL_SERVICE_Suffix[model_name]
+    return MODEL_SERVICE_BASE_URL + MODEL_SERVICE_Suffix[model_name]


 class ErnieChat(LLM):
@@ -106,16 +104,12 @@ class ErnieChat(LLM):
     prefix_messages: List = Field(default_factory=list)
     id: str = ""

-    def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
+    def _call(self, prompt: str, stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
         msg = user_message(prompt)
-        request = CompletionRequest(messages=self.prefix_messages+[msg])
-        bot = ErnieBot(_get_model_service_url(self.model_name),self.access_token,request)
+        request = CompletionRequest(messages=self.prefix_messages + [msg])
+        bot = ErnieBot(_get_model_service_url(self.model_name), self.access_token, request)
         try:
             # Your code here
             response = bot.get_response().result
...
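A hedged usage sketch of ErnieLLM; the access token comes from Baidu AIP, and the exact field names read by validate_environment sit in a collapsed hunk, so treat the keywords as assumptions:

    llm = ErnieLLM(model_name="ernie", access_token="<access-token>")  # token acquisition not shown
    print(llm("介绍一下文心一言"))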
 from dataclasses import asdict, dataclass
 from typing import List
@@ -7,27 +5,32 @@ from pydantic import BaseModel, Field
 from enum import Enum


 class MessageRole(str, Enum):
     USER = "user"
     BOT = "assistant"


 @dataclass
 class Message:
     role: str
     content: str


 @dataclass
 class CompletionRequest:
     messages: List[Message]
     stream: bool = False
     user: str = ""


 @dataclass
 class Usage:
     prompt_tokens: int
     completion_tokens: int
     total_tokens: int


 @dataclass
 class CompletionResponse:
     id: str
@@ -42,12 +45,14 @@ class CompletionResponse:
     is_safe: bool = False
     is_truncated: bool = False


 class ErrorResponse(BaseModel):
     error_code: int = Field(...)
     error_msg: str = Field(...)
     id: str = Field(...)


-class ErnieBot():
+class ErnieBot:
     url: str
     access_token: str
     request: CompletionRequest
@@ -65,7 +70,7 @@ class ErnieBot():
         headers = {'Content-Type': 'application/json'}
         params = {'access_token': self.access_token}
         request_dict = asdict(self.request)
-        response = requests.post(self.url, params=params,data=json.dumps(request_dict), headers=headers)
+        response = requests.post(self.url, params=params, data=json.dumps(request_dict), headers=headers)
         # print(response.json())
         try:
             return CompletionResponse(**response.json())
@@ -73,8 +78,10 @@ class ErnieBot():
             print(e)
             raise Exception(response.json())


 def user_message(prompt: str) -> Message:
     return Message(MessageRole.USER, prompt)


 def bot_message(prompt: str) -> Message:
     return Message(MessageRole.BOT, prompt)
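These dataclasses mirror the wenxinworkshop chat payload; a short sketch of composing a multi-turn request with the helpers above:

    from dataclasses import asdict

    req = CompletionRequest(messages=[
        user_message("你好"),
        bot_message("你好,请问有什么可以帮您?"),
        user_message("介绍一下文心一言"),
    ])
    print(asdict(req))   # the JSON body ErnieBot.get_response posts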
 import os
 import requests
-from typing import Dict, Optional,List,Any,Mapping,Iterator
+from typing import Dict, Optional, List, Any, Mapping, Iterator
 from pydantic import root_validator
-from langchain.llms.base import BaseLLM,LLM
+from langchain.llms.base import BaseLLM, LLM
 from langchain.cache import InMemoryCache
 from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
 import qianfan
 from qianfan import ChatCompletion

 # Enable the LLM cache
 # langchain.llm_cache = InMemoryCache()


 class ChatERNIESerLLM(LLM):
     # Model service URL
-    chat_completion:ChatCompletion = None
+    chat_completion: ChatCompletion = None
     # url: str = "http://127.0.0.1:8000"
     chat_history: dict = []
     out_stream: bool = False
     cache: bool = False
-    model_name:str = "ERNIE-Bot"
+    model_name: str = "ERNIE-Bot"

     # def __init__(self):
     #     self.chat_completion = qianfan.ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")
@@ -32,20 +33,19 @@ class ChatERNIESerLLM(LLM):
     def get_num_tokens(self, text: str) -> int:
         return len(text)

-    def convert_data(self,data):
+    @staticmethod
+    def convert_data(data):
         result = []
         for item in data:
             result.append({'q': item[0], 'a': item[1]})
         return result

     def _call(self, prompt: str,
               stop: Optional[List[str]] = None,
               run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-              stream = False,
+              stream=False,
               **kwargs: Any) -> str:
-        resp = self.chat_completion.do(model=self.model_name,messages=[{
+        resp = self.chat_completion.do(model=self.model_name, messages=[{
             "role": "user",
             "content": prompt
         }])
@@ -59,11 +59,12 @@ class ChatERNIESerLLM(LLM):
                            stream=False) -> Any:
         """POST request
         """
-        async for r in await self.chat_completion.ado(model=self.model_name,messages=[query], stream=stream):
+        async for r in await self.chat_completion.ado(model=self.model_name, messages=[query], stream=stream):
             assert r.code == 200
             if run_manager:
-                for callable in run_manager.get_sync().handlers:
-                    await callable.on_llm_new_token(r.body["result"])
+                for _callable in run_manager.get_sync().handlers:
+                    await _callable.on_llm_new_token(r.body["result"])

     async def _acall(
             self,
             prompt: str,
@@ -74,6 +75,5 @@ class ChatERNIESerLLM(LLM):
         await self._post_stream(query={
             "role": "user",
             "content": prompt
-        },stream=True,run_manager=run_manager)
+        }, stream=True, run_manager=run_manager)
         return ''
\ No newline at end of file
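A hedged sketch of wiring ChatERNIESerLLM to the qianfan SDK; the commented-out __init__ above shows the intended shape, with your own ak/sk:

    import qianfan

    llm = ChatERNIESerLLM(chat_completion=qianfan.ChatCompletion(ak="<ak>", sk="<sk>"))
    print(llm("你好"))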
@@ -4,6 +4,7 @@ import torch
 from transformers import AutoModel, AutoTokenizer, AutoConfig, DataCollatorForSeq2Seq
 from peft import PeftModel


 class ModelLoader:
     def __init__(self, model_name_or_path, pre_seq_len=0, prefix_projection=False):
         self.config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
@@ -27,14 +28,15 @@ class ModelLoader:
     def collator(self):
         return DataCollatorForSeq2Seq(tokenizer=self.tokenizer, model=self.model)

-    def load_lora(self,ckpt_path,name="default"):
+    def load_lora(self, ckpt_path, name="default"):
         # Save GPU memory during training
-        peft_loaded = PeftModel.from_pretrained(self.base_model,ckpt_path,adapter_name=name)
-        self.model = peft_loaded.merge_and_unload()
+        _peft_loaded = PeftModel.from_pretrained(self.base_model, ckpt_path, adapter_name=name)
+        self.model = _peft_loaded.merge_and_unload()
         print(f"Load LoRA model successfully!")

-    def load_loras(self,ckpt_paths,name="default"):
-        if len(ckpt_paths)==0:
+    def load_loras(self, ckpt_paths, name="default"):
+        global peft_loaded
+        if len(ckpt_paths) == 0:
             return
         first = True
         for name, path in ckpt_paths.items():
@@ -43,11 +45,11 @@ class ModelLoader:
                 peft_loaded = PeftModel.from_pretrained(self.base_model, path, adapter_name=name)
                 first = False
             else:
-                peft_loaded.load_adapter(path,adapter_name=name)
+                peft_loaded.load_adapter(path, adapter_name=name)
         peft_loaded.set_adapter(name)
         self.model = peft_loaded

-    def load_prefix(self,ckpt_path):
+    def load_prefix(self, ckpt_path):
         prefix_state_dict = torch.load(os.path.join(ckpt_path, "pytorch_model.bin"))
         new_prefix_state_dict = {}
         for k, v in prefix_state_dict.items():
@@ -56,4 +58,3 @@ class ModelLoader:
         self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
         self.model.transformer.prefix_encoder.float()
         print(f"Load prefix model successfully!")
@@ -2,7 +2,7 @@ import logging
 import os
 from typing import Any, Dict, List, Mapping, Optional

-from langchain.llms.base import BaseLLM,LLM
+from langchain.llms.base import BaseLLM, LLM
 from langchain.schema import LLMResult
 from langchain.utils import get_from_dict_or_env
 from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks
@@ -16,18 +16,17 @@ from .xinghuo.ws import SparkAPI
 logger = logging.getLogger(__name__)

-text =[]
+text = []
 # length = 0

-def getText(role,content):
-    jsoncon = {}
-    jsoncon["role"] = role
-    jsoncon["content"] = content
+def getText(role, content):
+    jsoncon = {"role": role, "content": content}
     text.append(jsoncon)
     return text

 def getlength(text):
     length = 0
     for content in text:
@@ -36,10 +35,11 @@ def getlength(text):
         length += leng
     return length

-def checklen(text):
-    while (getlength(text) > 8000):
-        del text[0]
-    return text
+def checklen(_text):
+    while getlength(_text) > 8000:
+        del _text[0]
+    return _text


 class SparkLLM(LLM):
@@ -68,7 +68,7 @@ class SparkLLM(LLM):
     )

     @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(self, values: Dict) -> Dict:
         """Validate the environment."""
         # print(values)
@@ -89,18 +89,14 @@ class SparkLLM(LLM):
         values["api_key"] = api_key
         values["api_secret"] = api_secret
-        api=SparkAPI(appid,api_key,api_secret,version)
-        values["api"]=api
+        api = SparkAPI(appid, api_key, api_secret, version)
+        values["api"] = api
         return values

-    def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
-        question = self.getText("user",prompt)
+    def _call(self, prompt: str, stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
+        question = self.getText("user", prompt)
         try:
             # Your code here
             # SparkApi.main(self.appid,self.api_key,self.api_secret,self.Spark_url,self.domain,question)
@@ -109,10 +105,10 @@ class SparkLLM(LLM):
             return response
         except Exception as e:
             # Handle the exception
-            print("exception:",e)
+            print("exception:", e)
             raise e

-    def getText(self,role,content):
+    def getText(self, role, content):
         text = []
         jsoncon = {}
         jsoncon["role"] = role
@@ -124,5 +120,3 @@ class SparkLLM(LLM):
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "xinghuo"
\ No newline at end of file
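A hedged usage sketch of SparkLLM; credential plumbing happens in the collapsed part of validate_environment, so the keyword names below are assumptions based on the values[...] keys above:

    llm = SparkLLM(appid="<appid>", api_key="<api_key>", api_secret="<api_secret>")
    print(llm("用一句话介绍讯飞星火"))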
 from langchain.llms.base import LLM
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from pydantic import root_validator
 from typing import Dict, List, Optional
 from transformers import PreTrainedModel, PreTrainedTokenizer


 class WrapperLLM(LLM):
     tokenizer: PreTrainedTokenizer = None
     model: PreTrainedModel = None

     @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
+    def validate_environment(self, values: Dict) -> Dict:
         """Validate the environment."""
         # print(values)
         if values.get("model") is None:
@@ -19,13 +19,9 @@ class WrapperLLM(LLM):
             raise ValueError("No tokenizer provided.")
         return values

-    def _call(
-            self,
-            prompt: str,
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
-        resp,his = self.model.chat(self.tokenizer,prompt)
+    def _call(self, prompt: str, stop: Optional[List[str]] = None,
+              run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs) -> str:
+        resp, his = self.model.chat(self.tokenizer, prompt)
         return resp

     @property
...
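WrapperLLM pairs naturally with ModelLoader from earlier in this commit; a hedged sketch (it requires a model exposing .chat(tokenizer, prompt), e.g. ChatGLM):

    loader = ModelLoader("THUDM/chatglm3-6b")
    llm = WrapperLLM(model=loader.model, tokenizer=loader.tokenizer)
    print(llm("你好"))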
 from abc import ABC, abstractmethod


 class BaseCallback(ABC):
     @abstractmethod
-    def filter(self,title:str,content:str) -> bool:  # Return True to drop the current paragraph
+    def filter(self, title: str, content: str) -> bool:  # Return True to drop the current paragraph
         pass
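A minimal concrete callback of the kind consumed by append_document in the PDF loader below; the filtering rule itself is only an illustration:

    import re

    class DropTocCallback(BaseCallback):
        def filter(self, title: str, content: str) -> bool:
            # Illustrative rule: drop table-of-contents lines (dot leaders ending in a page number).
            return bool(re.search(r'\.{5,}\s*\d+\s*$', title))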
@@ -56,6 +56,6 @@ class ChineseTextSplitter(CharacterTextSplitter):
                     ele_id = ele1_ls.index(ele_ele1)
                     ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]
-        id = ls.index(ele)
-        ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
+        _id = ls.index(ele)
+        ls = ls[:_id] + [i for i in ele1_ls if i] + ls[_id + 1:]
         return ls
-import os,copy
-from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader,UnstructuredPDFLoader,UnstructuredWordDocumentLoader,PDFMinerPDFasHTMLLoader
+import os, copy
+from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader, UnstructuredPDFLoader, \
+    UnstructuredWordDocumentLoader, PDFMinerPDFasHTMLLoader
 from .config import SENTENCE_SIZE, ZH_TITLE_ENHANCE
 from .chinese_text_splitter import ChineseTextSplitter
 from .zh_title_enhance import zh_title_enhance
 from langchain.schema import Document
-from typing import List,Dict,Optional
+from typing import List, Dict, Optional
 from src.loader.callback import BaseCallback
 import re
 from bs4 import BeautifulSoup


-def load(filepath,mode:str = None,sentence_size:int = 0,metadata = None,callbacks = None,**kwargs):
+def load(filepath, mode: str = None, sentence_size: int = 0, metadata=None, callbacks=None, **kwargs):
     r"""
     Load a document. Parameter notes:
     mode: document splitting mode, "single", "elements", "paged"
@@ -19,37 +21,44 @@ def load(filepath, mode, sentence_size, metadata, callbacks, **kwargs):
     kwargs
     """
     if filepath.lower().endswith(".md"):
-        loader = UnstructuredFileLoader(filepath, mode=mode or "elements",**kwargs)
+        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
     elif filepath.lower().endswith(".txt"):
-        loader = TextLoader(filepath, autodetect_encoding=True,**kwargs)
+        loader = TextLoader(filepath, autodetect_encoding=True, **kwargs)
     elif filepath.lower().endswith(".csv"):
-        loader = CSVLoader(filepath,**kwargs)
+        loader = CSVLoader(filepath, **kwargs)
     elif filepath.lower().endswith(".pdf"):
         # loader = UnstructuredPDFLoader(filepath, mode=mode or "elements",**kwargs)
         # Use the custom PDF loader
-        return __pdf_loader(filepath,sentence_size=sentence_size,metadata=metadata,callbacks=callbacks)
+        return __pdf_loader(filepath, sentence_size=sentence_size, metadata=metadata, callbacks=callbacks)
     elif filepath.lower().endswith(".docx") or filepath.lower().endswith(".doc"):
         loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements", **kwargs)
     else:
-        loader = UnstructuredFileLoader(filepath, mode=mode or "elements",**kwargs)
+        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
     if sentence_size > 0:
-        return split(loader.load(),sentence_size)
+        return split(loader.load(), sentence_size)
     return loader.load()


-def loads_path(path:str,**kwargs):
-    return loads(get_files_in_directory(path),**kwargs)
+def loads_path(path: str, **kwargs):
+    return loads(get_files_in_directory(path), **kwargs)


-def loads(filepaths,**kwargs):
-    default_kwargs={"mode":"paged"}
+def loads(filepaths, **kwargs):
+    default_kwargs = {"mode": "paged"}
     default_kwargs.update(**kwargs)
     documents = [load(filepath=file, **default_kwargs) for file in filepaths]
     return [item for sublist in documents for item in sublist]

-def append(documents:List[Document] = [],sentence_size:int = SENTENCE_SIZE):  # Preserve document structure; mind the hash handling
+def append(documents=None, sentence_size: int = SENTENCE_SIZE):  # Preserve document structure; mind the hash handling
+    if documents is None:
+        documents = []
     effect_documents = []
     last_doc = documents[0]
     for doc in documents[1:]:
         last_hash = "" if "next_hash" not in last_doc.metadata else last_doc.metadata["next_hash"]
         doc_hash = "" if "next_hash" not in doc.metadata else doc.metadata["next_hash"]
-        if len(last_doc.page_content)+len(doc.page_content) <= sentence_size and last_hash == doc_hash:
+        if len(last_doc.page_content) + len(doc.page_content) <= sentence_size and last_hash == doc_hash:
             last_doc.page_content = last_doc.page_content + doc.page_content
             continue
         else:
@@ -58,28 +67,31 @@
     effect_documents.append(last_doc)
     return effect_documents


-def split(documents:List[Document] = [],sentence_size:int = SENTENCE_SIZE):  # Preserve document structure; mind the hash handling
+def split(documents=None, sentence_size: int = SENTENCE_SIZE):  # Preserve document structure; mind the hash handling
+    if documents is None:
+        documents = []
     effect_documents = []
     for doc in documents:
         if len(doc.page_content) > sentence_size:
-            words_list = re.split(r'·-·', doc.page_content.replace("。","。·-·").replace("\n","\n·-·"))  # Insert separators, then split
-            document = Document(page_content="",metadata=copy.deepcopy(doc.metadata))
+            words_list = re.split(r'·-·', doc.page_content.replace("。", "。·-·").replace("\n", "\n·-·"))  # Insert separators, then split
+            document = Document(page_content="", metadata=copy.deepcopy(doc.metadata))
             first = True
             for word in words_list:
                 if len(document.page_content) + len(word) < sentence_size:
                     document.page_content += word
                 else:
-                    if len(document.page_content.replace(" ","").replace("\n",""))>0:
+                    if len(document.page_content.replace(" ", "").replace("\n", "")) > 0:
                         if first:
-                            first=False
+                            first = False
                         else:
                             effect_documents[-1].metadata["next_doc"] = document.page_content
                         effect_documents.append(document)
-                    document = Document(page_content=word,metadata=copy.deepcopy(doc.metadata))
-            if len(document.page_content.replace(" ","").replace("\n",""))>0:
+                    document = Document(page_content=word, metadata=copy.deepcopy(doc.metadata))
+            if len(document.page_content.replace(" ", "").replace("\n", "")) > 0:
                 if first:
-                    first=False
+                    pass
                 else:
                     effect_documents[-1].metadata["next_doc"] = document.page_content
                 effect_documents.append(document)
@@ -87,10 +99,12 @@
         effect_documents.append(doc)
     return effect_documents
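A brief sketch of what split does to an oversized Document: sentences are re-packed up to sentence_size characters, and each piece records its successor in metadata["next_doc"] (sizes are illustrative; run inside this module):

    doc = Document(page_content="第一句。第二句。第三句。", metadata={"source": "demo.pdf"})
    for piece in split([doc], sentence_size=8):
        print(piece.page_content, piece.metadata.get("next_doc"))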

-def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE,mode:str = None,**kwargs):
+def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE, mode: str = None,
+              **kwargs):
     print("load_file", filepath)
     if filepath.lower().endswith(".md"):
-        loader = UnstructuredFileLoader(filepath, mode=mode or "elements",**kwargs)
+        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
         docs = loader.load()
     elif filepath.lower().endswith(".txt"):
         loader = TextLoader(filepath, autodetect_encoding=True, **kwargs)
@@ -100,15 +114,15 @@
         loader = CSVLoader(filepath, **kwargs)
         docs = loader.load()
     elif filepath.lower().endswith(".pdf"):
-        loader = UnstructuredPDFLoader(filepath, mode=mode or "elements",**kwargs)
+        loader = UnstructuredPDFLoader(filepath, mode=mode or "elements", **kwargs)
         textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
         docs = loader.load_and_split(textsplitter)
     elif filepath.lower().endswith(".docx"):
-        loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements",**kwargs)
+        loader = UnstructuredWordDocumentLoader(filepath, mode=mode or "elements", **kwargs)
         textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
         docs = loader.load_and_split(textsplitter)
     else:
-        loader = UnstructuredFileLoader(filepath, mode=mode or "elements",**kwargs)
+        loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
         textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
         docs = loader.load_and_split(text_splitter=textsplitter)
     if using_zh_title_enhance:
@@ -116,6 +130,7 @@
         write_check_file(filepath, docs)
     return docs


 def write_check_file(filepath, docs):
     folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
     if not os.path.exists(folder_path):
@@ -129,6 +144,7 @@
         fout.write('\n')
     fout.close()


 def get_files_in_directory(directory):
     file_paths = []
     for root, dirs, files in os.walk(directory):
@@ -137,21 +153,29 @@
             file_paths.append(file_path)
     return file_paths


 # Custom PDF loading section
-def __checkV(strings:str):
+def __checkV(strings: str):
     lines = len(strings.splitlines())
-    if (lines > 3 and len(strings.replace(" ", ""))/lines < 15):
+    if lines > 3 and len(strings.replace(" ", "")) / lines < 15:
         return False
     return True

-def __isTitle(strings:str):
-    return len(strings.splitlines())==1 and len(strings)>0 and strings.endswith("\n")
+def __isTitle(strings: str):
+    return len(strings.splitlines()) == 1 and len(strings) > 0 and strings.endswith("\n")


-def __appendPara(strings:str):
-    return strings.replace(".\n","^_^").replace("。\n","^-^").replace("?\n","?^-^").replace("?\n","?^-^").replace("\n","").replace("^_^",".\n").replace("^-^","。\n").replace("?^-^","?\n").replace("?^-^","?\n")
+def __appendPara(strings: str):
+    return strings.replace(".\n", "^_^").replace("。\n", "^-^").replace("?\n", "?^-^").replace("?\n", "?^-^").replace(
+        "\n", "").replace("^_^", ".\n").replace("^-^", "。\n").replace("?^-^", "?\n").replace("?^-^", "?\n")


-def __check_fs_ff(line_ff_fs_s,fs,ff):  # If the current line has text in the same font and size as the previous line, return those; otherwise default to the longest run's font and size
+def __check_fs_ff(line_ff_fs_s, fs, ff):  # If the current line has text in the same font and size as the previous line, return those; otherwise default to the longest run's font and size
     re_fs = line_ff_fs_s[-1][0][-1]
     re_ff = line_ff_fs_s[-1][1][-1] if line_ff_fs_s[-1][1] else None
     max_len = 0
     for ff_fs in line_ff_fs_s:  # Find the font and size of the longest text run
         c_max = max(list(map(int, ff_fs[0])))
         if max_len < ff_fs[2] or (max_len == ff_fs[2] and c_max > int(re_fs)):
             max_len = ff_fs[2]
@@ -163,122 +187,132 @@
             re_fs = fs
             re_ff = ff
             break
-    return int(re_fs),re_ff
+    return int(re_fs), re_ff

-def append_document(snippets1:List[Document],title:str,content:str,callbacks,font_size,page_num,metadate,need_append:bool = False):
+def append_document(snippets1: List[Document], title: str, content: str, callbacks, font_size, page_num, metadate,
+                    need_append: bool = False):
     if callbacks:
         for cb in callbacks:
-            if isinstance(cb,BaseCallback):
-                if cb.filter(title,content):
+            if isinstance(cb, BaseCallback):
+                if cb.filter(title, content):
                     return
-    if need_append and len(snippets1)>0:
+    if need_append and len(snippets1) > 0:
         ps = snippets1.pop()
-        snippets1.append(Document(page_content=ps.page_content+title, metadata=ps.metadata))
+        snippets1.append(Document(page_content=ps.page_content + title, metadata=ps.metadata))
     else:
-        doc_metadata = {"font-size": font_size,"page_number":page_num}
+        doc_metadata = {"font-size": font_size, "page_number": page_num}
         doc_metadata.update(metadate)
-        snippets1.append(Document(page_content=title+content, metadata=doc_metadata))
+        snippets1.append(Document(page_content=title + content, metadata=doc_metadata))


 '''
 Extract a PDF, splitting on titles and content; a section's page number is the page its title appears on.
 Sections longer than sentence_size are split again; each piece inherits the parent section's page number.
 '''


-def __pdf_loader(filepath:str,sentence_size:int = 0,metadata = None,callbacks = None):
+def __pdf_loader(filepath: str, sentence_size: int = 0, metadata=None, callbacks=None):
     if not filepath.lower().endswith(".pdf"):
         raise ValueError("file is not pdf document")
     loader = PDFMinerPDFasHTMLLoader(filepath)
     documents = loader.load()
-    soup = BeautifulSoup(documents[0].page_content,'html.parser')
+    soup = BeautifulSoup(documents[0].page_content, 'html.parser')
     content = soup.find_all('div')
     cur_fs = None  # font-size of the current text
     last_fs = None  # font-size of the previous text
     cur_ff = None  # font family/style of the current text
     cur_text = ''
     fs_increasing = False  # the next line's font grows: treat it as a title and split here
     last_text = ''
     last_page_num = 1  # previous page number; page_split decides which page the current text belongs to
     page_num = 1  # initial page number
     page_change = False  # the page switched
     page_split = False  # whether text was split on this page
     last_is_title = False  # whether the previous text is a title
-    snippets:List[Document] = []
+    snippets: List[Document] = []
     filename = os.path.basename(filepath)
     if metadata:
-        metadata.update({'source':filepath,'filename':filename,'filetype': 'application/pdf'})
+        metadata.update({'source': filepath, 'filename': filename, 'filetype': 'application/pdf'})
     else:
-        metadata = {'source':filepath,'filename':filename,'filetype': 'application/pdf'}
+        metadata = {'source': filepath, 'filename': filename, 'filetype': 'application/pdf'}
     for c in content:
         divs = c.get('style')
-        if re.match(r"^(Page|page)",c.text):  # Detect the current page number
-            match = re.match(r"^(page|Page)\s+(\d+)",c.text)
+        if re.match(r"^(Page|page)", c.text):  # Detect the current page number
+            match = re.match(r"^(page|Page)\s+(\d+)", c.text)
             if match:
                 if page_split:  # If text was split, advance the page; otherwise keep the current text's starting page
                     last_page_num = page_num
                 page_num = match.group(2)
-                if len(last_text)+len(cur_text) == 0:  # Page turned with no accumulated text: take the new page number
+                if len(last_text) + len(cur_text) == 0:  # Page turned with no accumulated text: take the new page number
                     last_page_num = page_num
                 page_change = True
                 page_split = False
                 continue
-        if re.findall('writing-mode:(.*?);',divs) == ['False'] or re.match(r'^[0-9\s\n]+$',c.text) or re.match(r"^第\s+\d+\s+页$",c.text):  # Hidden text or pure digits
+        if re.findall('writing-mode:(.*?);', divs) == ['False'] or re.match(r'^[0-9\s\n]+$', c.text) or re.match(
+                r"^第\s+\d+\s+页$", c.text):  # Hidden text or pure digits
             continue
-        if len(c.text.replace("\n","").replace(" ","")) <= 1:  # Drop lines with at most one effective character
+        if len(c.text.replace("\n", "").replace(" ", "")) <= 1:  # Drop lines with at most one effective character
             continue
         sps = c.find_all('span')
         if not sps:
             continue
         line_ff_fs_s = []  # spans with more than one effective character
         line_ff_fs_s2 = []  # spans with exactly one effective character
         for sp in sps:  # A line may contain several differently styled spans
-            sp_len = len(sp.text.replace("\n","").replace(" ",""))
+            sp_len = len(sp.text.replace("\n", "").replace(" ", ""))
             if sp_len > 0:
                 st = sp.get('style')
                 if st:
-                    ff_fs = (re.findall('font-size:(\d+)px',st),re.findall('font-family:(.*?);',st),len(sp.text.replace("\n","").replace(" ","")))
+                    ff_fs = (re.findall('font-size:(\d+)px', st), re.findall('font-family:(.*?);', st),
+                             len(sp.text.replace("\n", "").replace(" ", "")))
                     if sp_len == 1:  # Collect single-character spans separately
                         line_ff_fs_s2.append(ff_fs)
                     else:
                         line_ff_fs_s.append(ff_fs)
-        if len(line_ff_fs_s)==0:  # If empty, fall back to the single-character spans
-            if len(line_ff_fs_s2)>0:
+        if len(line_ff_fs_s) == 0:  # If empty, fall back to the single-character spans
+            if len(line_ff_fs_s2) > 0:
                 line_ff_fs_s = line_ff_fs_s2
             else:
-                if len(c.text)>0:
+                if len(c.text) > 0:
                     page_change = False
                 continue
-        fs,ff = __check_fs_ff(line_ff_fs_s,cur_fs,cur_ff)
+        fs, ff = __check_fs_ff(line_ff_fs_s, cur_fs, cur_ff)
         if not cur_ff:
             cur_ff = ff
         if not cur_fs:
             cur_fs = fs
-        if (abs(fs - cur_fs) <= 1 and ff == cur_ff):  # Neither font size nor family changed
+        if abs(fs - cur_fs) <= 1 and ff == cur_ff:  # Neither font size nor family changed
            cur_text += c.text
             cur_fs = fs
             page_change = False
             if len(cur_text.splitlines()) > 3:  # After several consecutive lines fs_increasing no longer applies
                 fs_increasing = False
         else:
-            if page_change and cur_fs > fs+1:  # Page turned and the font shrank: most likely a page header, skip c.text. ----- may cut off a line of text
+            if page_change and cur_fs > fs + 1:  # Page turned and the font shrank: most likely a page header, skip c.text. ----- may cut off a line of text
                 page_change = False
                 continue
             if last_is_title:  # The previous text was a title
                 if __isTitle(cur_text) or fs_increasing:  # Consecutive titles, or the font-growing flag is set
                     last_text = last_text + cur_text
                     last_is_title = True
                     fs_increasing = False
                 else:
-                    append_document(snippets,last_text,__appendPara(cur_text),callbacks,cur_fs,page_num if page_split else last_page_num,metadata)
+                    append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
+                                    page_num if page_split else last_page_num, metadata)
                     page_split = True
                     last_text = ''
                     last_is_title = False
                     fs_increasing = int(fs) > int(cur_fs)  # The font grew
             else:
-                if len(last_text)>0 and __checkV(last_text):  # Filter some text
+                if len(last_text) > 0 and __checkV(last_text):  # Filter some text
                     # Merge passages that span pages or have only a few lines
-                    append_document(snippets,__appendPara(last_text),"",callbacks,last_fs,page_num if page_split else last_page_num,metadata,need_append=len(last_text.splitlines()) <= 2 or page_change)
+                    append_document(snippets, __appendPara(last_text), "", callbacks, last_fs,
+                                    page_num if page_split else last_page_num, metadata,
+                                    need_append=len(last_text.splitlines()) <= 2 or page_change)
                     page_split = True
                 last_text = cur_text
                 last_is_title = __isTitle(last_text) or fs_increasing
@@ -290,7 +324,8 @@
             cur_ff = ff
             cur_text = c.text
             page_change = False
-    append_document(snippets,last_text,__appendPara(cur_text),callbacks,cur_fs,page_num if page_split else last_page_num,metadata)
+    append_document(snippets, last_text, __appendPara(cur_text), callbacks, cur_fs,
+                    page_num if page_split else last_page_num, metadata)
     if sentence_size > 0:
-        return split(snippets,sentence_size)
+        return split(snippets, sentence_size)
     return snippets
@@ -33,7 +33,7 @@ def is_possible_title(
        title_max_word_length: int = 20,
        non_alpha_threshold: float = 0.5,
) -> bool:
    """Checks to see if the text passes all the checks for a valid title.

    Parameters
    ----------
...
import psycopg2
from psycopg2 import OperationalError, InterfaceError


class UPostgresDB:
    """
    psycopg2.connect(
        dsn       # connection parameters, given as keywords or as a DSN string
        host      # host name of the database server
@@ -18,8 +19,9 @@ class UPostgresDB:
        sslkey    # private key file
        sslcert   # public certificate file
    )
    """

    def __init__(self, host, database, user, password, port=5432):
        self.host = host
        self.database = database
        self.user = user
@@ -35,7 +37,7 @@ class UPostgresDB:
                database=self.database,
                user=self.user,
                password=self.password,
                port=self.port
            )
            self.cur = self.conn.cursor()
        except Exception as e:
@@ -89,7 +91,6 @@ class UPostgresDB:
            print(f"Retrying the SQL statement failed again: {type(e).__name__}: {e}")
            self.conn.rollback()

    def search(self, query, params=None):
        if self.conn is None or self.conn.closed:
            self.connect()
...
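A minimal usage sketch of UPostgresDB as reformatted above; the constructor arguments mirror the test script later in this commit, and connect()/search() are the methods visible in this diff:

# usage sketch (placeholder credentials, as in the repo's own test code)
from src.pgdb.chat.c_db import UPostgresDB

db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
db.connect()            # opens self.conn and self.cur
db.search("SELECT 1")   # reconnects first if the connection is closed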
from .c_db import UPostgresDB
import json

TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
@@ -13,13 +14,14 @@ COMMENT ON COLUMN "c_user"."password" IS 'user password';
COMMENT ON TABLE "c_user" IS 'user table';
"""


class CUser:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db

    def insert(self, value):
        query = "INSERT INTO c_user(user_id, account, password) VALUES (%s,%s,%s)"
        self.db.execute_args(query, (value[0], value[1], value[2]))

    def create_table(self):
        query = TABLE_USER
...
from .c_db import UPostgresDB
import json

TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
@@ -17,6 +18,7 @@ COMMENT ON COLUMN "chat"."deleted" IS 'deleted: 0 = no, 1 = yes';
COMMENT ON TABLE "chat" IS 'chat session table';
"""


class Chat:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db
@@ -24,7 +26,7 @@ class Chat:
    # insert a row
    def insert(self, value):
        query = "INSERT INTO chat(chat_id, user_id, info, deleted) VALUES (%s,%s,%s,%s)"
        self.db.execute_args(query, (value[0], value[1], value[2], value[3]))

    # create the table
    def create_table(self):
...
from .c_db import UPostgresDB
import json

TABLE_CHAT = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
@@ -21,6 +22,7 @@ COMMENT ON COLUMN "turn_qa"."is_last" IS 'is this the last turn: 0 = no,
COMMENT ON TABLE "turn_qa" IS 'chat turn table';
"""


class TurnQa:
    def __init__(self, db: UPostgresDB) -> None:
        self.db = db
@@ -28,7 +30,7 @@ class TurnQa:
    # insert a row
    def insert(self, value):
        query = "INSERT INTO turn_qa(turn_id, chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s,%s)"
        self.db.execute_args(query, (value[0], value[1], value[2], value[3], value[4], value[5]))

    # create the table
    def create_table(self):
...
@@ -4,22 +4,24 @@ from os import path
sys.path.append("../")
from abc import ABC, abstractmethod
import json
from typing import List, Any, Tuple, Dict
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64


class DocumentCallback(ABC):
    @abstractmethod  # document post-processing before storing into the vector store
    def before_store(self, docstore: PgSqlDocstore, documents):
        pass

    @abstractmethod  # document post-processing after a vector-store query -- used to rebuild structure
    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:
        pass


class DefaultDocumentCallback(DocumentCallback):
    def before_store(self, docstore: PgSqlDocstore, documents):
        output_doc = []
        for doc in documents:
            if "next_doc" in doc.metadata:
@@ -27,22 +29,24 @@ class DefaultDocumentCallback(DocumentCallback):
                doc.metadata.pop("next_doc")
            output_doc.append(doc)
        return output_doc

    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:  # document post-processing after a vector-store query
        output_doc: List[Tuple[Document, float]] = []
        exist_hash = []
        for doc, score in documents:
            print(exist_hash)
            dochash = str2hash_base64(doc.page_content)
            if dochash in exist_hash:
                continue
            else:
                exist_hash.append(dochash)
                output_doc.append((doc, score))
                if len(output_doc) > number:
                    return output_doc
                fordoc = doc
                while "next_hash" in fordoc.metadata:
                    if len(fordoc.metadata["next_hash"]) > 0:
                        if fordoc.metadata["next_hash"] in exist_hash:
                            break
                        else:
@@ -50,7 +54,7 @@ class DefaultDocumentCallback(DocumentCallback):
                            content = docstore.TXT_DOC.search(fordoc.metadata["next_hash"])
                            if content:
                                fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
                                output_doc.append((fordoc, score))
                                if len(output_doc) > number:
                                    return output_doc
                            else:
...
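The after_search hook above deduplicates hits by paragraph hash and then follows each document's next_hash metadata to pull in the chunks that follow it. A minimal sketch of that chaining, with a plain dict standing in for docstore.TXT_DOC (hypothetical data, not the project's storage):

# follow next_hash links until the chain ends or repeats
store = {
    "h1": ("chunk one", {"next_hash": "h2"}),
    "h2": ("chunk two", {}),
}

def follow_chain(start_hash):
    chunks, h, seen = [], start_hash, set()
    while h in store and h not in seen:
        seen.add(h)
        text, meta = store[h]
        chunks.append(text)
        h = meta.get("next_hash", "")
    return chunks

print(follow_chain("h1"))  # ['chunk one', 'chunk two']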
import psycopg2


class PostgresDB:
    """
    psycopg2.connect(
        dsn       # connection parameters, given as keywords or as a DSN string
        host      # host name of the database server
@@ -17,8 +18,9 @@ class PostgresDB:
        sslkey    # private key file
        sslcert   # public certificate file
    )
    """

    def __init__(self, host, database, user, password, port=5432):
        self.host = host
        self.database = database
        self.user = user
@@ -33,7 +35,7 @@ class PostgresDB:
                database=self.database,
                user=self.user,
                password=self.password,
                port=self.port
            )
            self.cur = self.conn.cursor()
@@ -44,6 +46,7 @@ class PostgresDB:
        except Exception as e:
            print(f"An error occurred: {e}")
            self.conn.rollback()

    def execute_args(self, query, args):
        try:
            self.cur.execute(query, args)
...
import sys
from os import path

# add the current directory to PYTHONPATH
sys.path.append(path.dirname(path.abspath(__file__)))

from typing import List, Union, Dict, Optional
from langchain.docstore.base import AddableMixin, Docstore
from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector
import json, hashlib, base64
from langchain.schema import Document


def str2hash_base64(inp: str) -> str:
    # return f"%s" % hash(inp)
    return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()
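str2hash_base64 keys paragraphs by the base64 form of their SHA-1 digest: 20 digest bytes encode to 28 characters, which fits the varchar(40) hash column defined for txt_doc later in this commit. A quick check:

import base64, hashlib

digest = hashlib.sha1("段落文本".encode()).digest()
print(len(digest))                              # 20 bytes
print(len(base64.b64encode(digest).decode()))   # 28 characters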
class PgSqlDocstore(Docstore, AddableMixin):
    host: str
    dbname: str
    username: str
    password: str
    port: str
    '''
    Note: __getstate__/__setstate__ are overridden for langchain's pickle-based serialization;
    they reduce the store to a dict of pgsql connection info.
    '''

    def __getstate__(self):
        return {"host": self.host, "dbname": self.dbname, "username": self.username, "password": self.password,
                "port": self.port}

    def __setstate__(self, info):
        self.__init__(info)

    def __init__(self, info: dict, reset: bool = False):
        self.host = info["host"]
        self.dbname = info["dbname"]
        self.username = info["username"]
        self.password = info["password"]
        self.port = info["port"] if "port" in info else "5432"
        self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password, port=self.port)
        self.TXT_DOC = TxtDoc(self.pgdb)
        self.VEC_TXT = TxtVector(self.pgdb)
        if reset:
@@ -48,12 +50,15 @@ class PgSqlDocstore(Docstore, AddableMixin):
            self.VEC_TXT.drop_table()
            self.TXT_DOC.create_table()
            self.VEC_TXT.create_table()

    def __sub_init__(self):
        if not self.pgdb.conn:
            self.pgdb.connect()
    '''
    Look up the paragraph behind a vector in the local store and return it wrapped in a Document.
    '''

    def search(self, search: str) -> Union[str, Document]:
        if not self.pgdb.conn:
            self.__sub_init__()
@@ -63,40 +68,44 @@ class PgSqlDocstore(Docstore, AddableMixin):
            return Document(page_content=content[0], metadata=json.loads(content[1]))
        else:
            return Document()

    '''
    Delete the texts behind the given vector ids from the local store, in batch.
    '''

    def delete(self, ids: List) -> None:
        if not self.pgdb.conn:
            self.__sub_init__()
        pids = []
        for item in ids:
            answer = self.VEC_TXT.search(item)
            pids.append(answer[0])
        self.VEC_TXT.delete(ids)
        self.TXT_DOC.delete(pids)

    '''
    Add vector and text info to the local store:
    [vector_id, Document(page_content=question, metadata=dict(paragraph=paragraph text))]
    '''

    def add(self, texts: Dict[str, Document]) -> None:
        # for vec, doc in texts.items():
        #     paragraph_id = self.TXT_DOC.insert(doc.metadata["paragraph"])
        #     self.VEC_TXT.insert(vector_id=vec, paragraph_id=paragraph_id, text=doc.page_content)
        if not self.pgdb.conn:
            self.__sub_init__()
        paragraph_hashs = []  # hash, text
        paragraph_txts = []
        vec_inserts = []
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            print(txt_hash)
            vec_inserts.append((vec, doc.page_content, txt_hash))
            if txt_hash not in paragraph_hashs:
                paragraph_hashs.append(txt_hash)
                paragraph = doc.metadata["paragraph"]
                doc.metadata.pop("paragraph")
                paragraph_txts.append((txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False)))
        # print(paragraph_txts)
        self.TXT_DOC.insert(paragraph_txts)
        self.VEC_TXT.insert(vec_inserts)
@@ -105,7 +114,7 @@ class PgSqlDocstore(Docstore, AddableMixin):


class InMemorySecondaryDocstore(Docstore, AddableMixin):
    """Simple in-memory docstore in the form of a dict."""

    def __init__(self, _dict: Optional[Dict[str, Document]] = None, _sec_dict: Optional[Dict[str, Document]] = None):
        """Initialize with dict."""
        self._dict = _dict if _dict is not None else {}
        self._sec_dict = _sec_dict if _sec_dict is not None else {}
@@ -126,14 +135,14 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
        dict1 = {}
        dict_sec = {}
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            metadata = doc.metadata
            paragraph = metadata.pop('paragraph')
            # metadata.update({"paragraph_id": txt_hash})
            metadata['paragraph_id'] = txt_hash
            dict_sec[txt_hash] = Document(page_content=paragraph, metadata=metadata)
            dict1[vec] = Document(page_content=doc.page_content, metadata={'paragraph_id': txt_hash})
        self._dict = {**self._dict, **dict1}
        self._sec_dict = {**self._sec_dict, **dict_sec}
@@ -143,7 +152,7 @@ class InMemorySecondaryDocstore(Docstore, AddableMixin):
        if not overlapping:
            raise ValueError(f"Tried to delete ids that do not exist: {ids}")
        for _id in ids:
            self._sec_dict.pop(self._dict[_id].metadata['paragraph_id'])
            self._dict.pop(_id)

    def search(self, search: str) -> Union[str, Document]:
...
import os, sys
import re, time
from os import path

sys.path.append("../")
import copy
from typing import List, OrderedDict, Any, Optional, Tuple, Dict
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore
from langchain.vectorstores.faiss import FAISS, dependable_faiss_import
from langchain.schema import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore
from langchain.embeddings.huggingface import (
@@ -22,43 +22,54 @@ from langchain.callbacks.manager import (
)
from src.loader import load
from langchain.embeddings.base import Embeddings
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
@singleton
class EmbeddingFactory:
    def __init__(self, path: str):
        self.path = path
        self.embedding = HuggingFaceEmbeddings(model_name=path)

    def get_embedding(self):
        return self.embedding


def GetEmbding(_path: str) -> Embeddings:
    # return HuggingFaceEmbeddings(model_name=_path)
    return EmbeddingFactory(_path).get_embedding()
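Because @singleton caches one instance per decorated class, every GetEmbding call reuses the embedding model loaded by the first call; note that a different _path on a later call is silently ignored, since the cache key is the class, not its arguments. A small demonstration with a stand-in class (Dummy is hypothetical, not part of this repo):

@singleton
class Dummy:
    def __init__(self, path: str):
        self.path = path

a = Dummy("model-a")
b = Dummy("model-b")
print(a is b)   # True: one shared instance per decorated class
print(b.path)   # 'model-a': the first call's arguments win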
import operator
from langchain.vectorstores.utils import DistanceStrategy
import numpy as np


class RE_FAISS(FAISS):
    # deduplicate while keeping metadata
    @staticmethod
    def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        deduplicated_dict = OrderedDict()
        for doc, scores in tuple_input:
            page_content = doc.page_content
            metadata = doc.metadata
            if page_content not in deduplicated_dict:
                deduplicated_dict[page_content] = (metadata, scores)
        deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in
                                  deduplicated_dict.items()]
        return deduplicated_documents
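_tuple_deduplication keeps the first hit per page_content, and since FAISS returns results ordered by score that is the best-scoring duplicate; the OrderedDict preserves the ranking. A quick check in this module's context:

from langchain.schema import Document

hits = [(Document(page_content="A"), 0.10),
        (Document(page_content="A"), 0.40),   # duplicate text, worse score
        (Document(page_content="B"), 0.20)]
print([(d.page_content, s) for d, s in RE_FAISS._tuple_deduplication(hits)])
# [('A', 0.1), ('B', 0.2)]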
    def similarity_search_with_score_by_vector(
            self,
            embedding: List[float],
@@ -107,8 +118,9 @@ class RE_FAISS(FAISS):
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
        return docs[:k]

    def max_marginal_relevance_search_by_vector(
            self,
            embedding: List[float],
@@ -141,50 +153,61 @@ class RE_FAISS(FAISS):
        docs_and_scores = self._tuple_deduplication(docs_and_scores)
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
        return [doc for doc, _ in docs_and_scores]

def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
             is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
    embeddings = GetEmbding(_path=embedding_model_name)
    docstore1: PgSqlDocstore = None
    if is_pgsql:
        if info and "host" in info and "dbname" in info and "username" in info and "password" in info:
            docstore1 = PgSqlDocstore(info, reset=reset)
    else:
        docstore1 = InMemorySecondaryDocstore()
    if not path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
    if store_path is None or len(store_path) <= 0 or not path.exists(
            path.join(store_path, index_name + ".faiss")) or reset:
        print("create new faiss")
        index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))  # sized to the embedding dimension
        return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore1,
                        index_to_docstore_id={})
    else:
        print("load_local faiss")
        _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings)
        if docstore1 and is_pgsql:  # if external connection parameters changed, refresh the docstore
            _faiss.docstore = docstore1
        return _faiss

class VectorStore_FAISS(FAISS):
    def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None,
                 is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False,
                 doc_callback: DocumentCallback = DefaultDocumentCallback()):
        self.info = info
        self.embedding_model_name = embedding_model_name
        self.store_path = path.join(store_path, index_name)
        if not path.exists(self.store_path):
            os.makedirs(self.store_path, exist_ok=True)
        self.index_name = index_name
        self.show_number = show_number
        self.search_number = self.show_number * 3
        self.threshold = threshold
        self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
                               is_pgsql=is_pgsql, reset=reset)
        self.doc_callback = doc_callback

    def get_text_similarity_with_score(self, text: str, **kwargs):
        score_threshold = (1 - self.threshold) * math.sqrt(2)
        docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
                                                        score_threshold=score_threshold,
                                                        doc_callback=self.doc_callback, **kwargs)
        return [doc for doc, similarity in docs][:self.show_number]
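On the threshold conversion above: for unit-normalized embeddings the exact relation between L2 distance and cosine similarity is d = sqrt(2 * (1 - cos)), so the code's (1 - threshold) * sqrt(2) is a stricter heuristic cutoff rather than the exact mapping (it is smaller for every threshold below 1). A quick comparison:

import math

for t in (0.6, 0.8, 0.9):
    exact = math.sqrt(2 * (1 - t))   # distance equivalent to cosine >= t
    used = (1 - t) * math.sqrt(2)    # cutoff the code computes
    print(t, round(exact, 3), round(used, 3))
# 0.6 0.894 0.566
# 0.8 0.632 0.283
# 0.9 0.447 0.141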
    def get_text_similarity(self, text: str, **kwargs):
        docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback,
                                             **kwargs)
        return docs[:self.show_number]
    # # deduplicate while keeping metadata
@@ -199,22 +222,25 @@ class VectorStore_FAISS(FAISS):
    #     deduplicated_documents = [Document(page_content=key, metadata=value) for key, value in deduplicated_dict.items()]
    #     return deduplicated_documents

    @staticmethod
    def _join_document(docs: List[Document]) -> str:
        print(docs)
        return "".join([doc.page_content for doc in docs])

    @staticmethod
    def get_local_doc(docs: List[Document]):
        ans = []
        for doc in docs:
            ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
                        "filename": doc.metadata["filename"]})
        return ans
    # def _join_document_location(self, docs: List[Document]) -> str:

    # persist the index to local disk
    def _save_local(self):
        self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)

    # add documents
    # Document {
    #     page_content  paragraph text
@@ -222,10 +248,10 @@ class VectorStore_FAISS(FAISS):
    #         page  page number
    #     }
    # }
    def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
        list_of_documents: List[Document] = []
        if self.doc_callback:
            new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
        if need_split:
            for doc in new_docs:
                words_list = re.split(pattern, doc.page_content)
@@ -240,8 +266,14 @@ class VectorStore_FAISS(FAISS):
        else:
            list_of_documents = new_docs
        self._faiss.add_documents(list_of_documents)

    def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
        if load_kwargs is None:
            load_kwargs = {"mode": "paged"}
        if filepaths is None:
            filepaths = []
        self._add_documents(load.loads(filepaths, **load_kwargs))
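The rewritten _add_documents_from_dir replaces mutable default arguments with None sentinels: Python evaluates defaults once at definition time, so a literal list or dict default is shared across calls. A minimal demonstration of the pitfall (hypothetical functions, not from this repo):

def bad(acc=[]):       # one list, created when the function is defined
    acc.append(1)
    return acc

def good(acc=None):    # a fresh list per call
    acc = [] if acc is None else acc
    acc.append(1)
    return acc

print(bad())   # [1]
print(bad())   # [1, 1] -- state leaked from the first call
print(good())  # [1]
print(good())  # [1]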
    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """
        Return VectorStoreRetriever initialized from this VectorStore.
@@ -303,11 +335,11 @@ class VectorStore_FAISS(FAISS):
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        elif "similarity_score_threshold" == kwargs["search_type"]:
            default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._faiss._get_retriever_tags())
        print(kwargs)
@@ -316,20 +348,21 @@ class VectorStore_FAISS(FAISS):

class VectorStoreRetriever_FAISS(VectorStoreRetriever):
    search_k = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "k" in self.search_kwargs:
            self.search_k = self.search_kwargs["k"]
        self.search_kwargs["k"] = self.search_k * 2

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]

    async def _aget_relevant_documents(
            self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        # await is required here: the base coroutine must be resolved before slicing
        docs = await super()._aget_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]
\ No newline at end of file
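VectorStoreRetriever_FAISS asks the base retriever for twice search_k and trims the result back, so the dedup and next_hash expansion stages still have enough candidates to fill a page. A toy illustration of the over-fetch-then-trim pattern (stand-in data only):

k = 5
raw = ["a", "b", "a", "c", "d", "b", "e", "f", "g", "h"]  # 2*k raw hits
unique = list(dict.fromkeys(raw))                         # order-preserving dedup
print(unique[:k])                                         # ['a', 'b', 'c', 'd', 'e']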
from .k_db import PostgresDB

# paragraph_id BIGSERIAL primary key,
TABLE_TXT_DOC = """
create table txt_doc (
    hash varchar(40) primary key,
@@ -11,6 +12,8 @@ TABLE_TXT_DOC_HASH_INDEX = """
CREATE UNIQUE INDEX hash_index ON txt_doc (hash);
"""


# CREATE UNIQUE INDEX idx_name ON your_table (column_name);


class TxtDoc:
    def __init__(self, db: PostgresDB) -> None:
@@ -21,19 +24,20 @@ class TxtDoc:
        args = []
        for value in texts:
            value = list(value)
            query += "(%s,%s,%s),"
            args.extend(value)
        query = query[:len(query) - 1]
        query += " ON conflict(hash) DO UPDATE SET text = EXCLUDED.text;"
        self.db.execute_args(query, args)

    def delete(self, ids):
        for item in ids:
            query = "DELETE FROM txt_doc WHERE hash = %s"
            self.db.execute_args(query, [item])  # parameterized: the driver quotes the hash

    def search(self, item):
        query = "SELECT text,matadate FROM txt_doc WHERE hash = %s"
        self.db.execute_args(query, [item])
        answer = self.db.fetchall()
        if len(answer) > 0:
            return answer[0]
@@ -60,4 +64,3 @@ class TxtDoc:
        query = "DROP TABLE txt_doc"
        self.db.format(query)
        print("drop table txt_doc ok")
from .k_db import PostgresDB

TABLE_VEC_TXT = """
CREATE TABLE vec_txt (
    vector_id varchar(36) PRIMARY KEY,
@@ -6,7 +7,9 @@ CREATE TABLE vec_txt (
    paragraph_id varchar(40) not null
)
"""


# example vector_id: 025a9bee-2eb2-47f5-9722-525e05a0442b


class TxtVector:
    def __init__(self, db: PostgresDB) -> None:
        self.db = db
@@ -16,19 +19,21 @@ class TxtVector:
        args = []
        for value in vectors:
            value = list(value)
            query += "(%s,%s,%s),"
            args.extend(value)
        query = query[:len(query) - 1]
        query += " ON conflict(vector_id) DO UPDATE SET text = EXCLUDED.text, paragraph_id = EXCLUDED.paragraph_id;"
        # query += ";"
        self.db.execute_args(query, args)

    def delete(self, ids):
        for item in ids:
            query = "DELETE FROM vec_txt WHERE vector_id = %s"
            self.db.execute_args(query, [item])  # parameterized instead of %-interpolation

    def search(self, search: str):
        query = "SELECT paragraph_id,text FROM vec_txt WHERE vector_id = %s"
        self.db.execute_args(query, [search])
        answer = self.db.fetchall()
        print(answer)
        return answer[0]
...
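Both delete() methods now pass the id through execute_args instead of %-interpolating it into the SQL string; psycopg2 quotes the value itself, so string ids neither break the statement nor open an injection hole. The difference in one line, assuming an open psycopg2 cursor cur and a hypothetical items table:

cur.execute("DELETE FROM items WHERE id = %s", ["abc"])    # parameterized: safe
# cur.execute("DELETE FROM items WHERE id = %s" % "abc")   # interpolated: unquoted, unsafe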
import sys

sys.path.append("../")
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa

"""Test the connection to the chat-related database."""


def test():
    c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
    chat = Chat(db=c_db)
    c_user = CUser(db=c_db)
    turn_qa = TurnQa(db=c_db)
    chat.create_table()
    c_user.create_table()
    turn_qa.create_table()
    # chat_id, user_id, info, deleted
    chat.insert(["3333", "1111", "no info", 0])
    # user_id, account, password
    c_user.insert(["111", "zhangsan", "111111"])
    # turn_id, chat_id, question, answer, turn_number, is_last
    turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])


if __name__ == "__main__":
    test()
\ No newline at end of file
import sys

sys.path.append("../")
from src.loader.load import loads_path
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    VEC_DB_DBNAME,
@@ -18,24 +18,27 @@ from src.config.consts import (
from src.loader.callback import BaseCallback


# When the returned text contains "思考题" (exercise questions), ignore it by default.
class localCallback(BaseCallback):
    def filter(self, title: str, content: str) -> bool:
        if len(title + content) == 0:
            return True
        return (len(title + content) / (len(title.splitlines()) + len(content.splitlines())) < 20) or "思考题" in title


"""Test ingesting source material into storage (pgsql and faiss)."""


def test_faiss_from_dir():
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=True)
    docs = loads_path(KNOWLEDGE_PATH, mode="paged", sentence_size=512, callbacks=[localCallback()])
    print(len(docs))
    last_doc = None
    docs1 = []
@@ -45,7 +48,8 @@ def test_faiss_from_dir():
            continue
        if "font-size" not in doc.metadata or "page_number" not in doc.metadata:
            continue
        if doc.metadata["font-size"] == last_doc.metadata["font-size"] and doc.metadata["page_number"] == \
                last_doc.metadata["page_number"] and len(doc.page_content) + len(last_doc.page_content) < 512 / 4 * 3:
            last_doc.page_content += doc.page_content
        else:
            docs1.append(last_doc)
@@ -56,17 +60,21 @@ def test_faiss_from_dir():
    print(len(docs))
    print(vecstore_faiss._faiss.index.ntotal)
    for i in range(0, len(docs), 300):
        vecstore_faiss._add_documents(docs[i:i + 300 if i + 300 < len(docs) else len(docs)], need_split=True)
    print(vecstore_faiss._faiss.index.ntotal)
    vecstore_faiss._save_local()


"""Test query results against the faiss vector store."""


def test_faiss_load():
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    print(vecstore_faiss._join_document(vecstore_faiss.get_text_similarity("征信业务有什么情况")))
...