import copy
import math
import os
import re
import sys
from collections import OrderedDict
from os import path
from typing import Any, Dict, List, Optional, Tuple

sys.path.append("../")  # make the project root importable before the src imports below

import faiss
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.vectorstores.faiss import FAISS

from src.loader import load
from src.pgdb.knowledge.callback import DocumentCallback, DefaultDocumentCallback
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore, PgSqlDocstore


def singleton(cls):
    """Class decorator that caches and reuses a single instance per class.

    Note: the cache is keyed by the class itself, so the arguments of the
    first call win; later calls with different arguments still return the
    original instance.
    """
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class EmbeddingFactory:
    """Loads the HuggingFace embedding model once and shares it process-wide."""

    def __init__(self, _path: str):
        self.path = _path
        self.embedding = HuggingFaceEmbeddings(model_name=_path)

    def get_embedding(self) -> Embeddings:
        return self.embedding


def get_embedding(_path: str) -> Embeddings:
    return EmbeddingFactory(_path).get_embedding()


# Backward-compatible alias for the original, misspelled helper name.
get_embding = get_embedding
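
# Illustrative usage (the model path below is a placeholder, not a project
# default):
#
#   emb = get_embedding("/models/bge-large-zh")
#   vec = emb.embed_query("你好")
#
# Because EmbeddingFactory is a singleton, repeated calls reuse the model that
# is already in memory instead of loading it again.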


class RE_FAISS(FAISS):
    @staticmethod
    def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        """Deduplicate hits by page_content, keeping the first hit's metadata and score."""
        deduplicated_dict = OrderedDict()
        for doc, score in tuple_input:
            if doc.page_content not in deduplicated_dict:
                deduplicated_dict[doc.page_content] = (doc.metadata, score)
        return [
            (Document(page_content=content, metadata=metadata), score)
            for content, (metadata, score) in deduplicated_dict.items()
        ]
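
    # Example (illustrative): two hits with identical page_content collapse to
    # the first, higher-ranked one:
    #
    #   _tuple_deduplication([(Document(page_content="a", metadata={}), 0.1),
    #                         (Document(page_content="a", metadata={}), 0.9)])
    #   -> [(Document(page_content="a", metadata={}), 0.1)]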

    # def similarity_search_with_score_by_vector(
    #         self,
    #         embedding: List[float],
    #         k: int = 4,
    #         filter: Optional[Dict[str, Any]] = None,
    #         fetch_k: int = 20,
    #         **kwargs: Any,
    # ) -> List[Tuple[Document, float]]:
    #     faiss = dependable_faiss_import()
    #     vector = np.array([embedding], dtype=np.float32)
    #     if self._normalize_L2:
    #         faiss.normalize_L2(vector)
    #     scores, indices = self.index.search(vector, k if filter is None else fetch_k)
    #     docs = []
    #     for j, i in enumerate(indices[0]):
    #         if i == -1:
    #             # This happens when not enough docs are returned.
    #             continue
    #         _id = self.index_to_docstore_id[i]
    #         doc = self.docstore.search(_id)
    #         if not isinstance(doc, Document):
    #             raise ValueError(f"Could not find document for id {_id}, got {doc}")
    #         if filter is not None:
    #             filter = {
    #                 key: [value] if not isinstance(value, list) else value
    #                 for key, value in filter.items()
    #             }
    #             if all(doc.metadata.get(key) in value for key, value in filter.items()):
    #                 docs.append((doc, scores[0][j]))
    #         else:
    #             docs.append((doc, scores[0][j]))
    #     docs = self._tuple_deduplication(docs)
    #     score_threshold = kwargs.get("score_threshold")
    #     if score_threshold is not None:
    #         cmp = (
    #             operator.ge
    #             if self.distance_strategy
    #                in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
    #             else operator.le
    #         )
    #         docs = [
    #             (doc, similarity)
    #             for doc, similarity in docs
    #             if cmp(similarity, score_threshold)
    #         ]
    #
    #     if "doc_callback" in kwargs:
    #         if hasattr(kwargs["doc_callback"], 'after_search'):
    #             docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
    #     return docs[:k]

    def max_marginal_relevance_search_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            fetch_k: int = 20,
            lambda_mult: float = 0.5,
            filter: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
        )
        docs_and_scores = self._tuple_deduplication(docs_and_scores)
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
        return [doc for doc, _ in docs_and_scores]


def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
             is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
    embeddings = get_embedding(_path=embedding_model_name)
    docstore = None
    if is_pgsql:
        if info and all(key in info for key in ("host", "dbname", "username", "password")):
            docstore = PgSqlDocstore(info, reset=reset)
        else:
            raise ValueError("is_pgsql=True requires info containing host, dbname, username and password")
    else:
        docstore = InMemorySecondaryDocstore()
    if not store_path:
        raise ValueError("store_path must be a non-empty directory path")
    if not path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
    if reset or not path.exists(path.join(store_path, index_name + ".faiss")):
        print("create new faiss")
        # The index dimension is taken from the embedding model's output size.
        index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))
        return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore,
                        index_to_docstore_id={})
    else:
        print("load_local faiss")
        _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings,
                                     allow_dangerous_deserialization=True)
        if is_pgsql and docstore:  # the caller supplied fresh connection info, so swap in the new docstore
            _faiss.docstore = docstore
        return _faiss
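
# Illustrative usage (paths and connection details are placeholders):
#
#   store = getFAISS(
#       "/models/bge-large-zh",
#       "./faiss_store",
#       info={"host": "localhost", "dbname": "kb", "username": "kb", "password": "..."},
#       is_pgsql=True,
#   )
#   store.add_documents([Document(page_content="你好", metadata={})])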


class VectorStore_FAISS(FAISS):
    def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index",
                 info: Optional[dict] = None, is_pgsql: bool = True, show_number: int = 5,
                 threshold: float = 0.8, reset: bool = False,
                 doc_callback: DocumentCallback = DefaultDocumentCallback()):
        self.info = info
        self.embedding_model_name = embedding_model_name
        self.store_path = path.join(store_path, index_name)
        if not path.exists(self.store_path):
            os.makedirs(self.store_path, exist_ok=True)
        self.index_name = index_name
        self.show_number = show_number
        # Over-fetch so that deduplication still leaves enough hits to display.
        self.search_number = self.show_number * 3
        self.threshold = threshold
        self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
                               is_pgsql=is_pgsql, reset=reset)
        self.doc_callback = doc_callback

    def get_text_similarity_with_score(self, text: str, **kwargs):
        # Map the similarity threshold onto an L2 distance cutoff; for unit
        # vectors, distance d and cosine similarity s satisfy d^2 = 2 * (1 - s).
        score_threshold = (1 - self.threshold) * math.sqrt(2)
        docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
                                                        score_threshold=score_threshold,
                                                        doc_callback=self.doc_callback, **kwargs)
        return [doc for doc, _ in docs][:self.show_number]
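
    # Worked example for the cutoff above: with threshold = 0.8 it passes
    # (1 - 0.8) * sqrt(2) ≈ 0.283, tighter than the exact unit-vector bound
    # sqrt(2 * (1 - 0.8)) ≈ 0.632, so the linear form errs on the strict side.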

    def get_text_similarity(self, text: str, **kwargs):
        docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback, **kwargs)
        return docs[:self.show_number]


    @staticmethod
    def join_document(docs: List[Document]) -> str:
        return "".join([doc.page_content for doc in docs])

    @staticmethod
    def get_local_doc(docs: List[Document]):
        ans = []
        for doc in docs:
            ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
                        "filename": doc.metadata["filename"]})
        return ans

    # def _join_document_location(self, docs:List[Document]) -> str:

    # Persist the index to the local store path.
    def _save_local(self):
        self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)

    # Add documents. Each Document carries:
    #   page_content: the paragraph text
    #   metadata:
    #     page: page number
    def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
        list_of_documents: List[Document] = []
        if self.doc_callback:
            new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
        if need_split:
            for doc in new_docs:
                # Split the paragraph into sentences and drop duplicates.
                words_list = set(re.split(pattern, doc.page_content))
                for words in words_list:
                    if words.strip():
                        metadata = copy.deepcopy(doc.metadata)
                        metadata["paragraph"] = doc.page_content  # keep the full paragraph for context
                        list_of_documents.append(Document(page_content=words, metadata=metadata))
        else:
            list_of_documents = new_docs
        self._faiss.add_documents(list_of_documents)
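
    # Illustrative split with the default pattern r'[?。;\n]': a Document whose
    # page_content is "第一句。第二句;第三句" is indexed as three sentence-level
    # Documents, each keeping the full paragraph in metadata["paragraph"].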

    def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
        if load_kwargs is None:
            load_kwargs = {"mode": "paged"}
        if filepaths is None:
            filepaths = []
        self._add_documents(load.loads(filepaths, **load_kwargs))

    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """
        Return VectorStoreRetriever initialized from this VectorStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "similarity" (default), "mmr", or
                "similarity_score_threshold".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Amount of documents to return (Default: 4)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
                    filter: Filter by document metadata

        Returns:
            VectorStoreRetriever: Retriever class for VectorStore.

        Examples:

        .. code-block:: python

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

            # Use a filter to only retrieve documents from a specific paper
            docsearch.as_retriever(
                search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
            )
    """
        if not kwargs or "similarity_score_threshold" != kwargs["search_type"]:
            default_kwargs = {'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        elif "similarity_score_threshold" == kwargs["search_type"]:
            default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._faiss._get_retriever_tags())
        print(kwargs)
        return VectorStoreRetriever_FAISS(vectorstore=self._faiss, **kwargs, tags=tags)


class VectorStoreRetriever_FAISS(VectorStoreRetriever):
    # Declared as a pydantic field so it can be assigned in __init__.
    search_k: int = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "k" in self.search_kwargs:
            # Fetch twice as many documents as requested so deduplication
            # still leaves search_k results to return.
            self.search_k = self.search_kwargs["k"]
            self.search_kwargs["k"] = self.search_k * 2

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]

    async def _aget_relevant_documents(
            self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        # The parent coroutine must be awaited before slicing the results.
        docs = await super()._aget_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]
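
# Minimal end-to-end sketch (the model path is a placeholder; is_pgsql=False
# keeps the docstore in memory, so no PostgreSQL connection is needed):
#
#   vs = VectorStore_FAISS("/models/bge-large-zh", "./store", is_pgsql=False)
#   vs._add_documents([Document(page_content="第一句。第二句", metadata={})])
#   vs._save_local()
#   retriever = vs.as_retriever(search_kwargs={"k": 3})
#   print(vs.get_text_similarity("第一句"))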