import sys
from os import path

# Make the directory containing this file importable so the sibling
# modules (k_db, txt_doc_table, vec_txt_table) resolve when this file is
# loaded outside its package root.
sys.path.append(path.dirname(path.abspath(__file__)))

from typing import Dict, List, Optional, Union

from langchain_community.docstore.base import AddableMixin, Docstore
from langchain_core.documents import Document

from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector

import base64
import hashlib
import json


def str2hash_base64(inp: str) -> str:
    """Return a base64-encoded SHA-1 digest of *inp*.

    Used as a stable deduplication id for paragraph texts; SHA-1 is an
    identifier here, not a security primitive.
    """
    return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()


class PgSqlDocstore(Docstore, AddableMixin):
    """Docstore persisted in PostgreSQL.

    Two tables back the store: TxtDoc holds paragraph texts keyed by a
    content hash, TxtVector maps vector ids to the question text and the
    paragraph hash.

    ``__getstate__``/``__setstate__`` are overridden so langchain's
    pickle-based serialization stores only the connection info, never the
    live connection object.
    """

    # PostgreSQL connection parameters (all kept as strings).
    host: str
    dbname: str
    username: str
    password: str
    port: str

    def __getstate__(self):
        """Pickle only the connection parameters."""
        return {
            "host": self.host,
            "dbname": self.dbname,
            "username": self.username,
            "password": self.password,
            "port": self.port,
        }

    def __setstate__(self, info):
        """Rebuild the store (and its table helpers) from pickled info."""
        self.__init__(info)

    def __init__(self, info: dict, reset: bool = False):
        """Create the store from a connection-info dict.

        Args:
            info: dict with keys "host", "dbname", "username", "password"
                and optionally "port" (defaults to "5432").
            reset: when True, drop and recreate both backing tables.
        """
        self.host = info["host"]
        self.dbname = info["dbname"]
        self.username = info["username"]
        self.password = info["password"]
        self.port = info.get("port", "5432")
        self.pgdb = PostgresDB(
            self.host, self.dbname, self.username, self.password, port=self.port
        )
        self.TXT_DOC = TxtDoc(self.pgdb)
        self.VEC_TXT = TxtVector(self.pgdb)
        if reset:
            self.__sub_init__()
            self.TXT_DOC.drop_table()
            self.VEC_TXT.drop_table()
            self.TXT_DOC.create_table()
            self.VEC_TXT.create_table()

    def __sub_init__(self):
        # Lazily open the DB connection on first use.
        if not self.pgdb.conn:
            self.pgdb.connect()

    def search(self, search: str) -> Union[str, Document]:
        """Look up the paragraph behind vector id *search*.

        Returns the paragraph wrapped in a Document (metadata restored
        from its stored JSON), or an empty Document when nothing matches.
        """
        if not self.pgdb.conn:
            self.__sub_init__()
        answer = self.VEC_TXT.search(search)
        # NOTE(review): original code indexed answer[0] unconditionally;
        # guard in case the vector id is unknown — confirm VEC_TXT.search
        # returns a falsy value on miss.
        if not answer:
            return Document()
        content = self.TXT_DOC.search(answer[0])
        if content:
            return Document(page_content=content[0], metadata=json.loads(content[1]))
        return Document()

    def delete(self, ids: List) -> None:
        """Delete the given vector ids and their referenced paragraphs."""
        if not self.pgdb.conn:
            self.__sub_init__()
        pids = []
        for item in ids:
            answer = self.VEC_TXT.search(item)
            pids.append(answer[0])
        self.VEC_TXT.delete(ids)
        self.TXT_DOC.delete(pids)

    def add(self, texts: Dict[str, Document]) -> None:
        """Insert vectors and their paragraph texts.

        Args:
            texts: mapping of vector_id -> Document(page_content=question,
                metadata={"paragraph": paragraph_text, ...}). The
                "paragraph" entry is stored in TxtDoc keyed by its content
                hash; the remaining metadata is stored as JSON alongside it.
        """
        if not self.pgdb.conn:
            self.__sub_init__()
        seen_hashes = set()          # paragraph hashes already queued
        paragraph_txts = []          # (hash, paragraph, metadata_json)
        vec_inserts = []             # (vector_id, question, paragraph_hash)
        for vec, doc in texts.items():
            # Copy so the caller's Document metadata is not mutated.
            metadata = dict(doc.metadata)
            paragraph = metadata.pop("paragraph")
            txt_hash = str2hash_base64(paragraph)
            vec_inserts.append((vec, doc.page_content, txt_hash))
            # Deduplicate identical paragraphs shared by several vectors.
            if txt_hash not in seen_hashes:
                seen_hashes.add(txt_hash)
                paragraph_txts.append(
                    (txt_hash, paragraph, json.dumps(metadata, ensure_ascii=False))
                )
        self.TXT_DOC.insert(paragraph_txts)
        self.VEC_TXT.insert(vec_inserts)


class InMemorySecondaryDocstore(Docstore, AddableMixin):
    """Simple in-memory docstore in the form of two dicts.

    _dict:     vector_id -> Document(question, {"paragraph_id": hash})
    _sec_dict: paragraph hash -> Document(paragraph, metadata)
    """

    def __init__(
        self,
        _dict: Optional[Dict[str, Document]] = None,
        _sec_dict: Optional[Dict[str, Document]] = None,
    ):
        """Initialize with optional pre-populated dicts."""
        self._dict = _dict if _dict is not None else {}
        self._sec_dict = _sec_dict if _sec_dict is not None else {}

    def add(self, texts: Dict[str, Document]) -> None:
        """Add texts to the in-memory dictionaries.

        Args:
            texts: mapping of vector_id -> Document whose metadata carries
                the source paragraph under the "paragraph" key.

        Raises:
            ValueError: if any id is already present.
        """
        overlapping = set(texts).intersection(self._dict)
        if overlapping:
            raise ValueError(f"Tried to add ids that already exist: {overlapping}")
        primary = {}
        secondary = {}
        for vec, doc in texts.items():
            # Copy so the caller's Document metadata is not mutated.
            metadata = dict(doc.metadata)
            paragraph = metadata.pop("paragraph")
            txt_hash = str2hash_base64(paragraph)
            metadata["paragraph_id"] = txt_hash
            secondary[txt_hash] = Document(page_content=paragraph, metadata=metadata)
            primary[vec] = Document(
                page_content=doc.page_content, metadata={"paragraph_id": txt_hash}
            )
        # The original merged *texts* into _dict and then overwrote those
        # same keys with the rebuilt documents; the first merge was
        # redundant and is omitted.
        self._dict = {**self._dict, **primary}
        self._sec_dict = {**self._sec_dict, **secondary}

    def delete(self, ids: List) -> None:
        """Delete ids (and their paragraph entries) from the dictionaries.

        NOTE(review): only raises when *none* of the ids exist; a partial
        miss will surface later as a KeyError — original behavior kept.
        """
        overlapping = set(ids).intersection(self._dict)
        if not overlapping:
            raise ValueError(f"Tried to delete ids that does not exist: {ids}")
        for _id in ids:
            self._sec_dict.pop(self._dict[_id].metadata["paragraph_id"])
            self._dict.pop(_id)

    def search(self, search: str) -> Union[str, Document]:
        """Search via direct lookup.

        Args:
            search: vector id of a document to search for.

        Returns:
            The paragraph Document if found, else an error message string.
        """
        if search not in self._dict:
            return f"ID {search} not found."
        return self._sec_dict[self._dict[search].metadata["paragraph_id"]]