import sys
from os import path

# This effectively adds the directory containing this file to the Python path.
sys.path.append(path.dirname(path.abspath(__file__)))

import base64
import hashlib
import json
from typing import Dict, List, Optional, Union

from langchain_community.docstore.base import AddableMixin, Docstore
from langchain_core.documents import Document

from k_db import PostgresDB
from .txt_doc_table import TxtDoc
from .vec_txt_table import TxtVector


def str2hash_base64(inp: str) -> str:
    """Return a compact, deterministic key for a string: base64 of its SHA-1 digest."""
    return base64.b64encode(hashlib.sha1(inp.encode()).digest()).decode()
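
# Illustrative sketch (not part of the API): SHA-1 digests are 20 bytes, so every
# key is a fixed 28-character base64 string, and identical paragraphs always map
# to the same key; this is what lets the stores below de-duplicate paragraph rows.
#   str2hash_base64("hello")                               # -> 'qvTGHdzF6KLavt4PO0gs2a6pQ00='
#   str2hash_base64("hello") == str2hash_base64("hello")   # -> True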


class PgSqlDocstore(Docstore, AddableMixin):
    """PostgreSQL-backed docstore.

    __getstate__ and __setstate__ are overridden so the store works with
    LangChain's pickle-based serialization: only the connection info
    (host, dbname, username, password, port) is pickled, and the database
    connection is rebuilt on load.
    """
    host: str
    dbname: str
    username: str
    password: str
    port: str

    def __getstate__(self):
        return {"host": self.host, "dbname": self.dbname, "username": self.username, "password": self.password,
                "port": self.port}

    def __setstate__(self, info):
        self.__init__(info)
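
    # Illustrative pickle round trip (a sketch; the connection values are
    # placeholders and a reachable database is assumed):
    #   import pickle
    #   store = PgSqlDocstore({"host": "localhost", "dbname": "kdb",
    #                          "username": "postgres", "password": "secret"})
    #   restored = pickle.loads(pickle.dumps(store))  # __setstate__ reconnects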

    def __init__(self, info: dict, reset: bool = False):
        self.host = info["host"]
        self.dbname = info["dbname"]
        self.username = info["username"]
        self.password = info["password"]
        self.port = info.get("port", "5432")
        self.pgdb = PostgresDB(self.host, self.dbname, self.username, self.password, port=self.port)
        self.TXT_DOC = TxtDoc(self.pgdb)
        self.VEC_TXT = TxtVector(self.pgdb)
        if reset:
            self.__sub_init__()
            self.TXT_DOC.drop_table()
            self.VEC_TXT.drop_table()
            self.TXT_DOC.create_table()
            self.VEC_TXT.create_table()

    def __sub_init__(self):
        """(Re)open the database connection if it is not already established."""
        if not self.pgdb.conn:
            self.pgdb.connect()

    def search(self, search: str) -> Union[str, Document]:
        """Look up the paragraph behind a vector id in the local store and
        return it wrapped in a Document."""
        if not self.pgdb.conn:
            self.__sub_init__()
        answer = self.VEC_TXT.search(search)
        content = self.TXT_DOC.search(answer[0])
        if content:
            return Document(page_content=content[0], metadata=json.loads(content[1]))
        else:
            # langchain_core's Document requires page_content, so return an
            # empty one rather than a bare Document() when nothing is found.
            return Document(page_content="")

    def delete(self, ids: List) -> None:
        """Batch-delete vectors and their backing paragraphs from the local store."""
        if not self.pgdb.conn:
            self.__sub_init__()
        pids = []
        for item in ids:
            answer = self.VEC_TXT.search(item)
            pids.append(answer[0])
        self.VEC_TXT.delete(ids)
        self.TXT_DOC.delete(pids)

    def add(self, texts: Dict[str, Document]) -> None:
        """Add vectors and their texts to the local store.

        Expects a mapping of
        vector_id -> Document(page_content=question, metadata=dict(paragraph=paragraph_text)).
        """
        if not self.pgdb.conn:
            self.__sub_init__()
        paragraph_hashes = []  # hashes already queued, to de-duplicate paragraphs
        paragraph_txts = []
        vec_inserts = []
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            vec_inserts.append((vec, doc.page_content, txt_hash))
            if txt_hash not in paragraph_hashes:
                paragraph_hashes.append(txt_hash)
                paragraph = doc.metadata.pop("paragraph")
                paragraph_txts.append((txt_hash, paragraph, json.dumps(doc.metadata, ensure_ascii=False)))
        self.TXT_DOC.insert(paragraph_txts)
        self.VEC_TXT.insert(vec_inserts)
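
# A minimal end-to-end sketch of PgSqlDocstore (illustrative only; the
# connection values and the "vec-1" id are placeholders, and the behavior of
# PostgresDB, TxtDoc and TxtVector is assumed from their use above):
#   store = PgSqlDocstore({"host": "localhost", "dbname": "kdb",
#                          "username": "postgres", "password": "secret"},
#                         reset=True)
#   store.add({"vec-1": Document(page_content="What is X?",
#                                metadata={"paragraph": "X is ..."})})
#   doc = store.search("vec-1")   # -> Document(page_content="X is ...", metadata={})
#   store.delete(["vec-1"])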


class InMemorySecondaryDocstore(Docstore, AddableMixin):
    """In-memory docstore backed by two dicts: a primary map of vector id ->
    question Document and a secondary map of paragraph hash -> paragraph
    Document."""

    def __init__(self, _dict: Optional[Dict[str, Document]] = None, _sec_dict: Optional[Dict[str, Document]] = None):
        """Initialize with optional primary and secondary dicts."""
        self._dict = _dict if _dict is not None else {}
        self._sec_dict = _sec_dict if _sec_dict is not None else {}

    def add(self, texts: Dict[str, Document]) -> None:
        """Add texts to the in-memory dictionaries.

        Args:
            texts: dictionary of vector id -> Document whose metadata carries
                the source paragraph under the "paragraph" key.

        Returns:
            None
        """
        overlapping = set(texts).intersection(self._dict)
        if overlapping:
            raise ValueError(f"Tried to add ids that already exist: {overlapping}")

        dict_prim = {}
        dict_sec = {}
        for vec, doc in texts.items():
            txt_hash = str2hash_base64(doc.metadata["paragraph"])
            metadata = doc.metadata
            paragraph = metadata.pop("paragraph")
            metadata["paragraph_id"] = txt_hash
            # The secondary dict stores the full paragraph under its hash; the
            # primary dict stores the question text plus a pointer to it.
            dict_sec[txt_hash] = Document(page_content=paragraph, metadata=metadata)
            dict_prim[vec] = Document(page_content=doc.page_content, metadata={"paragraph_id": txt_hash})
        self._dict = {**self._dict, **dict_prim}
        self._sec_dict = {**self._sec_dict, **dict_sec}

    def delete(self, ids: List) -> None:
        """Delete IDs from the in-memory dictionaries."""
        missing = set(ids).difference(self._dict)
        if missing:
            raise ValueError(f"Tried to delete ids that do not exist: {missing}")
        for _id in ids:
            # Note: the paragraph entry is removed even if other vector ids
            # still point to it.
            self._sec_dict.pop(self._dict[_id].metadata["paragraph_id"])
            self._dict.pop(_id)

    def search(self, search: str) -> Union[str, Document]:
        """Search via direct lookup.

        Args:
            search: vector id of a document to search for.

        Returns:
            The paragraph Document the id points to if found, else an error
            message string.
        """
        if search not in self._dict:
            return f"ID {search} not found."
        return self._sec_dict[self._dict[search].metadata["paragraph_id"]]
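

# A small self-check of the in-memory store (a sketch; the "vec-1" id and texts
# are made up, and it assumes the module is importable as part of its package,
# e.g. via `python -m`, since the relative imports above fail when the file is
# executed directly):
if __name__ == "__main__":
    store = InMemorySecondaryDocstore()
    store.add({"vec-1": Document(page_content="What is LangChain?",
                                 metadata={"paragraph": "LangChain is a framework ..."})})
    result = store.search("vec-1")
    assert isinstance(result, Document) and result.page_content.startswith("LangChain")
    store.delete(["vec-1"])
    assert store.search("vec-1") == "ID vec-1 not found."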