import copy
import math
import operator
import os
import re
import sys
from collections import OrderedDict
from os import path
from typing import Any, Dict, List, Optional, Tuple

import faiss
import numpy as np
from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.embeddings.base import Embeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.vectorstores.base import VectorStoreRetriever
from langchain.vectorstores.faiss import FAISS, dependable_faiss_import
from langchain.vectorstores.utils import DistanceStrategy

# Make the project root importable before pulling in src modules.
sys.path.append("../")

from src.loader import load
from src.pgdb.knowledge.callback import DefaultDocumentCallback, DocumentCallback
from src.pgdb.knowledge.pgsqldocstore import InMemorySecondaryDocstore, PgSqlDocstore


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance


@singleton
class EmbeddingFactory:
    def __init__(self, _path: str):
        self.path = _path
        self.embedding = HuggingFaceEmbeddings(model_name=_path)

    def get_embedding(self):
        return self.embedding


def get_embedding(_path: str) -> Embeddings:
    # EmbeddingFactory is a process-wide singleton keyed by class, so the model
    # loaded on the first call is reused even if a different path is passed later.
    return EmbeddingFactory(_path).get_embedding()
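
# Minimal usage sketch (the model path below is a hypothetical placeholder for
# a local sentence-transformers model). Because of the singleton, both calls
# share one loaded model:
#   emb1 = get_embedding("/models/bge-large-zh")
#   emb2 = get_embedding("/models/bge-large-zh")
#   assert emb1 is emb2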


class RE_FAISS(FAISS):
    @staticmethod
    def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
        """Collapse (Document, score) pairs that share page_content, keeping the
        metadata and score of the first occurrence and preserving input order."""
        deduplicated_dict = OrderedDict()
        for doc, score in tuple_input:
            if doc.page_content not in deduplicated_dict:
                deduplicated_dict[doc.page_content] = (doc.metadata, score)
        return [
            (Document(page_content=content, metadata=metadata), score)
            for content, (metadata, score) in deduplicated_dict.items()
        ]
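
    # Illustrative behavior sketch (hypothetical values): two hits sharing
    # page_content collapse into one, keeping the first metadata and score:
    #   _tuple_deduplication([(Document(page_content="a", metadata={"p": 1}), 0.12),
    #                         (Document(page_content="a", metadata={"p": 2}), 0.30)])
    #   -> [(Document(page_content="a", metadata={"p": 1}), 0.12)]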

    def similarity_search_with_score_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            filter: Optional[Dict[str, Any]] = None,
            fetch_k: int = 20,
            **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        faiss = dependable_faiss_import()
        vector = np.array([embedding], dtype=np.float32)
        if self._normalize_L2:
            faiss.normalize_L2(vector)
        scores, indices = self.index.search(vector, k if filter is None else fetch_k)
        docs = []
        if filter is not None:
            # Normalize scalar filter values to single-item lists once, up front.
            filter = {
                key: [value] if not isinstance(value, list) else value
                for key, value in filter.items()
            }
        for j, i in enumerate(indices[0]):
            if i == -1:
                # This happens when not enough docs are returned.
                continue
            _id = self.index_to_docstore_id[i]
            doc = self.docstore.search(_id)
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            if filter is None or all(
                doc.metadata.get(key) in value for key, value in filter.items()
            ):
                docs.append((doc, scores[0][j]))
        docs = self._tuple_deduplication(docs)
        score_threshold = kwargs.get("score_threshold")
        if score_threshold is not None:
            cmp = (
                operator.ge
                if self.distance_strategy
                   in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
                else operator.le
            )
            docs = [
                (doc, similarity)
                for doc, similarity in docs
                if cmp(similarity, score_threshold)
            ]

        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs = kwargs["doc_callback"].after_search(self.docstore, docs, number=k)
        return docs[:k]
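
    # Note on score semantics (illustrative call, hypothetical values): with the
    # default L2 distance strategy, smaller scores mean more similar, so
    # score_threshold acts as an upper bound on distance here:
    #   faiss_store.similarity_search_with_score_by_vector(vec, k=4, score_threshold=0.3)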

    def max_marginal_relevance_search_by_vector(
            self,
            embedding: List[float],
            k: int = 4,
            fetch_k: int = 20,
            lambda_mult: float = 0.5,
            filter: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch before filtering to
                     pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                        of diversity among the results with 0 corresponding
                        to maximum diversity and 1 to minimum diversity.
                        Defaults to 0.5.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
            embedding, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, filter=filter
        )
        docs_and_scores = self._tuple_deduplication(docs_and_scores)
        if "doc_callback" in kwargs:
            if hasattr(kwargs["doc_callback"], 'after_search'):
                docs_and_scores = kwargs["doc_callback"].after_search(self.docstore, docs_and_scores, number=k)
        return [doc for doc, _ in docs_and_scores]


def getFAISS(embedding_model_name: str, store_path: str, info: dict = None, index_name: str = "index",
             is_pgsql: bool = True, reset: bool = False) -> RE_FAISS:
    embeddings = get_embedding(_path=embedding_model_name)
    if is_pgsql:
        if not (info and all(key in info for key in ("host", "dbname", "username", "password"))):
            raise ValueError("is_pgsql=True requires info with host, dbname, username and password")
        docstore = PgSqlDocstore(info, reset=reset)
    else:
        docstore = InMemorySecondaryDocstore()
    if not store_path:
        raise ValueError("store_path must be a non-empty directory path")
    if not path.exists(store_path):
        os.makedirs(store_path, exist_ok=True)
    if reset or not path.exists(path.join(store_path, index_name + ".faiss")):
        print("create new faiss")
        # The index dimension must match the embedding size, so embed one probe
        # string to discover it.
        index = faiss.IndexFlatL2(len(embeddings.embed_documents(["a"])[0]))
        return RE_FAISS(embedding_function=embeddings.client.encode, index=index, docstore=docstore,
                        index_to_docstore_id={})
    else:
        print("load_local faiss")
        _faiss = RE_FAISS.load_local(folder_path=store_path, index_name=index_name, embeddings=embeddings)
        if is_pgsql:  # refresh the docstore when connection parameters changed
            _faiss.docstore = docstore
        return _faiss
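
# Minimal usage sketch for getFAISS (all connection values and paths are
# placeholders, not defaults shipped with this module):
#   store = getFAISS(
#       embedding_model_name="/models/bge-large-zh",  # hypothetical model path
#       store_path="./faiss_store",
#       info={"host": "...", "dbname": "...", "username": "...", "password": "..."},
#       is_pgsql=True,
#   )
#   store.add_documents([Document(page_content="hello", metadata={})])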


class VectorStore_FAISS(FAISS):
    def __init__(self, embedding_model_name: str, store_path: str, index_name: str = "index", info: dict = None,
                 is_pgsql: bool = True, show_number=5, threshold=0.8, reset: bool = False,
                 doc_callback: DocumentCallback = DefaultDocumentCallback()):
        self.info = info
        self.embedding_model_name = embedding_model_name
        self.store_path = path.join(store_path, index_name)
        if not path.exists(self.store_path):
            os.makedirs(self.store_path, exist_ok=True)
        self.index_name = index_name
        self.show_number = show_number
        self.search_number = self.show_number * 3  # over-fetch so dedup/thresholding still leaves enough to show
        self.threshold = threshold
        self._faiss = getFAISS(self.embedding_model_name, self.store_path, info=info, index_name=self.index_name,
                               is_pgsql=is_pgsql, reset=reset)
        self.doc_callback = doc_callback

    def get_text_similarity_with_score(self, text: str, **kwargs):
        # Map the configured similarity threshold onto an L2-distance bound
        # (FAISS returns L2 distances here, where smaller means more similar).
        score_threshold = (1 - self.threshold) * math.sqrt(2)
        docs = self._faiss.similarity_search_with_score(query=text, k=self.search_number,
                                                        score_threshold=score_threshold,
                                                        doc_callback=self.doc_callback, **kwargs)
        return [doc for doc, _ in docs][:self.show_number]

    def get_text_similarity(self, text: str, **kwargs):
        docs = self._faiss.similarity_search(query=text, k=self.search_number, doc_callback=self.doc_callback, **kwargs)
        return docs[:self.show_number]

    @staticmethod
    def _join_document(docs: List[Document]) -> str:
        return "".join([doc.page_content for doc in docs])

    @staticmethod
    def get_local_doc(docs: List[Document]):
        ans = []
        for doc in docs:
            ans.append({"page_content": doc.page_content, "page_number": doc.metadata["page_number"],
                        "filename": doc.metadata["filename"]})
        return ans

    # Persist the FAISS index and mapping to local disk.
    def _save_local(self):
        self._faiss.save_local(folder_path=self.store_path, index_name=self.index_name)

    # Add documents. Each incoming Document is expected to carry:
    #   page_content: the paragraph text
    #   metadata:
    #     page: page number
    def _add_documents(self, new_docs: List[Document], need_split: bool = True, pattern: str = r'[?。;\n]'):
        list_of_documents: List[Document] = []
        if self.doc_callback:
            new_docs = self.doc_callback.before_store(self._faiss.docstore, new_docs)
        if need_split:
            for doc in new_docs:
                # Split the paragraph into sentences and drop duplicate fragments.
                words_list = set(re.split(pattern, doc.page_content))
                for words in words_list:
                    if words.strip():
                        # Each sentence becomes its own Document; keep the full
                        # paragraph in metadata so it can be reassembled later.
                        metadata = copy.deepcopy(doc.metadata)
                        metadata["paragraph"] = doc.page_content
                        list_of_documents.append(Document(page_content=words, metadata=metadata))
        else:
            list_of_documents = new_docs
        self._faiss.add_documents(list_of_documents)
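
    # Splitting sketch (illustrative): with the default pattern, a paragraph
    # like "句一。句二;句三" yields three sentence Documents ("句一", "句二",
    # "句三"), each carrying metadata["paragraph"] set to the full paragraph.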

    def _add_documents_from_dir(self, filepaths=None, load_kwargs=None):
        if load_kwargs is None:
            load_kwargs = {"mode": "paged"}
        if filepaths is None:
            filepaths = []
        self._add_documents(load.loads(filepaths, **load_kwargs))

    def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
        """
        Return VectorStoreRetriever initialized from this VectorStore.

        Args:
            search_type (Optional[str]): Defines the type of search that
                the Retriever should perform.
                Can be "similarity" (default), "mmr", or
                "similarity_score_threshold".
            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
                search function. Can include things like:
                    k: Amount of documents to return (Default: 4)
                    score_threshold: Minimum relevance threshold
                        for similarity_score_threshold
                    fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
                    lambda_mult: Diversity of results returned by MMR;
                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
                    filter: Filter by document metadata

        Returns:
            VectorStoreRetriever: Retriever class for VectorStore.

        Examples:

        .. code-block:: python

            # Retrieve more documents with higher diversity
            # Useful if your dataset has many similar documents
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 6, 'lambda_mult': 0.25}
            )

            # Fetch more documents for the MMR algorithm to consider
            # But only return the top 5
            docsearch.as_retriever(
                search_type="mmr",
                search_kwargs={'k': 5, 'fetch_k': 50}
            )

            # Only retrieve documents that have a relevance score
            # Above a certain threshold
            docsearch.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={'score_threshold': 0.8}
            )

            # Only get the single most similar document from the dataset
            docsearch.as_retriever(search_kwargs={'k': 1})

            # Use a filter to only retrieve documents from a specific paper
            docsearch.as_retriever(
                search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
            )
    """
        if not kwargs or "similarity_score_threshold" != kwargs["search_type"]:
            default_kwargs = {'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        elif "similarity_score_threshold" == kwargs["search_type"]:
            default_kwargs = {'score_threshold': self.threshold, 'k': self.show_number}
            if "search_kwargs" in kwargs:
                default_kwargs.update(kwargs["search_kwargs"])
            kwargs["search_kwargs"] = default_kwargs
        kwargs["search_kwargs"]["doc_callback"] = self.doc_callback
        tags = kwargs.pop("tags", None) or []
        tags.extend(self._faiss._get_retriever_tags())
        print(kwargs)
        return VectorStoreRetriever_FAISS(vectorstore=self._faiss, **kwargs, tags=tags)


class VectorStoreRetriever_FAISS(VectorStoreRetriever):
    # Declared as a pydantic field so instances may overwrite it in __init__.
    search_k: int = 5

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if "k" in self.search_kwargs:
            # Over-fetch: ask the store for 2*k hits so deduplication and
            # callbacks still leave enough documents, then trim back to k.
            self.search_k = self.search_kwargs["k"]
            self.search_kwargs["k"] = self.search_k * 2

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        docs = super()._get_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]

    async def _aget_relevant_documents(
            self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        # The parent coroutine must be awaited before slicing.
        docs = await super()._aget_relevant_documents(query=query, run_manager=run_manager)
        return docs[:self.search_k]
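

if __name__ == "__main__":
    # Minimal end-to-end sketch, assuming a local embedding model and using the
    # in-memory docstore so no PostgreSQL connection is needed. The model and
    # store paths are placeholders, not defaults of this module.
    store = VectorStore_FAISS(
        embedding_model_name="/models/bge-large-zh",  # hypothetical local model
        store_path="./faiss_store",
        index_name="demo",
        is_pgsql=False,
        reset=True,
    )
    store._add_documents([
        Document(page_content="FAISS 是向量检索库。它支持相似度搜索",
                 metadata={"filename": "demo.txt", "page_number": 1}),
    ])
    store._save_local()
    print(store.get_text_similarity("向量检索"))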