import sys

from abc import ABC, abstractmethod
import json
from typing import List, Tuple
from langchain_core.documents import Document
from src.pgdb.knowledge.pgsqldocstore import PgSqlDocstore, str2hash_base64

# Add the parent directory (relative to the current working directory) to the
# module search path.
# NOTE(review): this runs after the `src.pgdb...` import above, so it cannot
# affect that import -- confirm whether it is still needed or should move up.
sys.path.append("../")


class DocumentCallback(ABC):
    """Interface for hooks applied to documents around vector-store operations."""

    @abstractmethod
    def before_store(self, docstore: PgSqlDocstore, documents):
        """Process documents before they are stored in the vector store."""

    @abstractmethod
    def after_search(
            self,
            docstore: PgSqlDocstore,
            documents: List[Tuple[Document, float]],
            number: int = 1000,
    ) -> List[Tuple[Document, float]]:
        """Process documents returned by a vector-store query.

        Per the original note, this hook is used for building structure from
        the (document, score) pairs that the search produced.
        """


class DefaultDocumentCallback(DocumentCallback):
    """Default callback that chains documents together through hashes.

    Before storing, a ``next_doc`` metadata entry is replaced by a
    ``next_hash`` entry computed with ``str2hash_base64``.  After searching,
    every hit is expanded by following ``next_hash`` links, looking each one
    up through ``docstore.TXT_DOC.search``.
    """

    def before_store(self, docstore: PgSqlDocstore, documents):
        """Replace each document's ``next_doc`` metadata with ``next_hash``.

        The documents are modified in place.

        Args:
            docstore: Not used by this implementation; part of the interface.
            documents: Iterable of objects exposing a ``metadata`` mapping.

        Returns:
            A new list holding the same document objects in input order.
        """
        output_doc = []
        for doc in documents:
            if "next_doc" in doc.metadata:
                # Swap the ``next_doc`` value for its hash in a single step.
                doc.metadata["next_hash"] = str2hash_base64(doc.metadata.pop("next_doc"))
            output_doc.append(doc)
        return output_doc

    def after_search(self, docstore: PgSqlDocstore, documents: List[Tuple[Document, float]], number: int = 1000) -> \
            List[Tuple[Document, float]]:
        """De-duplicate search hits and append the documents they link to.

        Each hit whose content hash has not been seen yet is kept, then its
        ``next_hash`` chain is followed; every linked document found in
        ``docstore.TXT_DOC`` is appended with the score of the originating hit.
        A chain stops at an absent/empty ``next_hash``, at a hash that was
        already seen (which also guards against cycles), or when the lookup
        returns nothing.

        Args:
            docstore: Store whose ``TXT_DOC.search(hash)`` returns a sequence
                with the page content at index 0 and JSON-encoded metadata at
                index 1, or a falsy value when the hash is unknown.
            documents: ``(document, score)`` pairs from the vector search.
            number: Result cap; collection stops as soon as this many entries
                have been gathered.  The check runs after an append, so the
                first hit is still returned when ``number`` is below 1.

        Returns:
            The expanded list of ``(document, score)`` pairs.
        """
        output_doc: List[Tuple[Document, float]] = []
        # A set gives O(1) membership tests; only membership is ever needed.
        seen_hashes = set()
        for doc, score in documents:
            dochash = str2hash_base64(doc.page_content)
            if dochash in seen_hashes:
                continue
            seen_hashes.add(dochash)
            output_doc.append((doc, score))
            # ``>=`` so that no more than ``number`` entries are returned.
            if len(output_doc) >= number:
                return output_doc

            fordoc = doc
            while True:
                next_hash = fordoc.metadata.get("next_hash")
                if not next_hash or next_hash in seen_hashes:
                    break
                seen_hashes.add(next_hash)
                content = docstore.TXT_DOC.search(next_hash)
                if not content:
                    break
                fordoc = Document(page_content=content[0], metadata=json.loads(content[1]))
                output_doc.append((fordoc, score))
                if len(output_doc) >= number:
                    return output_doc
        return output_doc