import sys
from os import path

# Add the current directory to the Python path
sys.path.append(path.dirname(path.abspath(__file__)))

from typing import List, Union, Dict, Optional
from langchain_community.docstore.base import AddableMixin, Docstore
from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector
import base64
import hashlib
import json
from langchain_core.documents import Document


def str2hash_base64(inp: str) -> str:
    """Return the base64-encoded SHA-1 digest of a string."""
    return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()
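
# Example: SHA-1 is deterministic, so identical paragraphs always map to the
# same key:
#   str2hash_base64("hello")  # -> 'qvTGHdzF6KLavt4PO0gs2a6pQ00='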


class PgSqlDocstore(Docstore, AddableMixin):
    """
    __getstate__ and __setstate__ are overridden so that LangChain's
    pickle-based serialization works: only the PostgreSQL connection info
    is pickled, and the connection is rebuilt on load.
    """
    host: str
    dbname: str
    username: str
    password: str
    port: str

    def __getstate__(self):
        return {"host": self.host, "dbname": self.dbname, "username": self.username, "password": self.password,
                "port": self.port}

    def __setstate__(self, info):
        self.__init__(info)
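
    # Pickle round-trip sketch (why __getstate__/__setstate__ are overridden):
    #   import pickle
    #   restored = pickle.loads(pickle.dumps(store))
    # `restored` reconnects in __init__ from the pickled connection info.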

    def __init__(self, info: dict, reset: bool = False):
        self.host = info["host"]
        self.dbname = info["dbname"]
        self.username = info["username"]
        self.password = info["password"]
        self.port = info.get("port", "5432")
        self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password, port=self.port)
        self.TXT_DOC = TxtDoc(self.pgdb)
        self.VEC_TXT = TxtVector(self.pgdb)
        if reset:
            self.__sub_init__()
            self.TXT_DOC.drop_table()
            self.VEC_TXT.drop_table()
            self.TXT_DOC.create_table()
            self.VEC_TXT.create_table()

    def __sub_init__(self):
        """Open the database connection if it is not already open."""
        if not self.pgdb.conn:
            self.pgdb.connect()

    def search(self, search: str) -> Union[str, Document]:
        """
        Look up the paragraph text for a vector id in the local store and
        return it wrapped in a Document.
        """
        if not self.pgdb.conn:
            self.__sub_init__()
        answer = self.VEC_TXT.search(search)
        content = self.TXT_DOC.search(answer[0])
        if content:
            return Document(page_content=content[0], metadata=json.loads(content[1]))
        else:
            return Document(page_content="")

    def delete(self, ids: List) -> None:
        """
        Batch-delete vectors and their backing paragraphs from the local
        store.
        """
        if not self.pgdb.conn:
            self.__sub_init__()
        pids = []
        for item in ids:
            answer = self.VEC_TXT.search(item)
            pids.append(answer[0])
        self.VEC_TXT.delete(ids)
        self.TXT_DOC.delete(pids)

    def add(self, texts: Dict[str, Document]) -> None:
        """
        Add vectors and their text to the local store. ``texts`` maps
        vector_id -> Document(page_content=question,
                              metadata=dict(paragraph=paragraph text)).
        """
        if not self.pgdb.conn:
            self.__sub_init__()
        paragraph_hashes = []  # hashes of paragraphs already queued for insert
        paragraph_txts = []
        vec_inserts = []
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            vec_inserts.append((vec, doc.page_content, txt_hash))
            if txt_hash not in paragraph_hashes:
                paragraph_hashes.append(txt_hash)
                paragraph = doc.metadata.pop("paragraph")
                paragraph_txts.append((txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False)))
        self.TXT_DOC.insert(paragraph_txts)
        self.VEC_TXT.insert(vec_inserts)
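
# Usage sketch for PgSqlDocstore (connection values here are hypothetical;
# the real table layout lives in TxtDoc / TxtVector):
#   store = PgSqlDocstore({"host": "localhost", "dbname": "kb",
#                          "username": "postgres", "password": "secret"})
#   store.add({"vec-001": Document(page_content="What is X?",
#                                  metadata={"paragraph": "X is ..."})})
#   doc = store.search("vec-001")  # Document whose page_content is the paragraph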


class InMemorySecondaryDocstore(Docstore, AddableMixin):
    """Simple in memory docstore in the form of a dict."""

    def __init__(self, _dict: Optional[Dict[str, Document]] = None, _sec_dict: Optional[Dict[str, Document]] = None):
        """Initialize with dict."""
        self._dict = _dict if _dict is not None else {}
        self._sec_dict = _sec_dict if _sec_dict is not None else {}

    def add(self, texts: Dict[str, Document]) -> None:
        """Add texts to in memory dictionary.

        Args:
            texts: dictionary of id -> document.

        Returns:
            None
        """
        overlapping = set(texts).intersection(self._dict)
        if overlapping:
            raise ValueError(f"Tried to add ids that already exist: {overlapping}")
        # entries are merged below, after being split into question/paragraph

        dict1 = {}
        dict_sec = {}
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            metadata = doc.metadata
            paragraph = metadata.pop('paragraph')
            metadata['paragraph_id'] = txt_hash
            dict_sec[txt_hash] = Document(page_content=paragraph, metadata=metadata)
            dict1[vec] = Document(page_content=doc.page_content, metadata={'paragraph_id': txt_hash})
        self._dict = {**self._dict, **dict1}
        self._sec_dict = {**self._sec_dict, **dict_sec}

    def delete(self, ids: List) -> None:
        """Deleting IDs from in memory dictionary."""
        overlapping = set(ids).intersection(self._dict)
        if not overlapping:
            raise ValueError(f"Tried to delete ids that does not  exist: {ids}")
        for _id in ids:
            self._sec_dict.pop(self._dict[_id].metadata['paragraph_id'])
            self._dict.pop(_id)

    def search(self, search: str) -> Union[str, Document]:
        """Search via direct lookup.

        Args:
            search: id of a document to search for.

        Returns:
            Document if found, else error message.
        """
        if search not in self._dict:
            return f"ID {search} not found."
        else:
            return self._sec_dict[self._dict[search].metadata['paragraph_id']]
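
# Usage sketch for InMemorySecondaryDocstore: add() splits each entry into a
# question Document (keyed by vector id, in _dict) and a paragraph Document
# (keyed by the paragraph's hash, in _sec_dict), so several questions can
# share one paragraph:
#   mem = InMemorySecondaryDocstore()
#   mem.add({"vec-001": Document(page_content="What is X?",
#                                metadata={"paragraph": "X is ..."})})
#   mem.search("vec-001")  # returns the paragraph Document from _sec_dict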