import json
import time
from typing import List

from langchain_core.documents import Document
from langchain_core.output_parsers.list import ListOutputParser
from langchain_core.prompts import PromptTemplate

from src.config.prompts import PROMPT_QUERY_EXTEND, PROMPT_QA_EXTEND_QUESTION
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.rerank import BgeRerank, reciprocal_rank_fusion


class GetSimilarity:
    """Retrieve documents similar to a single question and optionally rerank them.

    On construction, runs one similarity search against the FAISS store and
    caches both the matched documents and their joined text form.
    """

    def __init__(self, _question, _faiss_db: VectorStore_FAISS):
        self.question = _question
        self.faiss_db = _faiss_db
        # Raw similarity hits for the question, fetched once up front.
        self.similarity_docs = self.faiss_db.get_text_similarity(self.question)
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Filled by get_rerank(); a list (not ""), consistent with
        # GetSimilarityWithExt and with what get_rerank() actually stores.
        self.rerank_docs = []

    def get_rerank(self, reranker: BgeRerank, top_k=5):
        """Rerank the similarity hits and return the top_k docs joined as text.

        Side effect: caches the top_k reranked documents in self.rerank_docs.
        """
        reranked = reranker.compress_documents(self.similarity_docs, self.question)
        # One slice replaces the old element-by-element copy loop.
        self.rerank_docs = reranked[:top_k]
        return self.faiss_db.join_document(self.rerank_docs)

    def get_similarity_doc(self):
        """Return the joined text of the raw similarity hits."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Return the raw similarity hit documents."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Return the documents cached by the last get_rerank() call."""
        return self.rerank_docs


class GetSimilarityWithExt:
    """Retrieve documents similar to a list of (extended) question variants,
    with optional location-substring filtering, then rerank the merged hits.

    _question is an iterable of query strings (original question plus
    LLM-generated extensions); _location, when given, is a list of substrings
    used to keep only documents whose content or filename matches one of them.
    """

    def __init__(self, _question, _faiss_db: VectorStore_FAISS, _location=None):
        self.question = _question
        self.faiss_db = _faiss_db
        self.location = _location
        # De-duplicated similarity hits merged across all question variants.
        self.similarity_docs = self.get_text_similarity_with_ext()
        self.similarity_doc_txt = self.faiss_db.join_document(self.similarity_docs)
        # Filled by get_rerank() / get_rerank_with_doc().
        self.rerank_docs = []

    def get_rerank(self, reranker: BgeRerank, top_k=5):
        """Rerank the merged hits against the joined question variants and
        return the top_k documents joined as text.

        Side effect: caches the top_k reranked documents in self.rerank_docs.
        """
        question = '\n'.join(self.question)
        print(question)
        reranked = reranker.compress_documents(self.similarity_docs, question)
        self.rerank_docs = reranked[:top_k]
        return self.faiss_db.join_document(self.rerank_docs)

    @staticmethod
    def _doc_to_dict(doc: Document) -> dict:
        """Map one Document to the serializable record used by join_document()."""
        return {
            "page_content": doc.page_content,
            "from_file": doc.metadata["filename"],
            "hash": doc.metadata.get("hash", ""),
            # page_number may be absent or non-string; always emit a string.
            "page_number": str(doc.metadata.get("page_number", "0")),
        }

    def join_document(self, docs: List[Document]) -> str:
        """Serialize docs to a JSON array string.

        Uses json.dumps so quotes, backslashes and newlines inside page
        content are escaped correctly — the previous hand-built string
        concatenation produced invalid JSON for such content.
        """
        return json.dumps([self._doc_to_dict(d) for d in docs],
                          ensure_ascii=False)

    def get_rerank_with_doc(self, reranker: BgeRerank, split_docs_list: list):
        """Rerank each per-file document group plus the merged similarity hits,
        keep the top 3 of each, and return them serialized as JSON text.

        Side effect: caches the collected documents in self.rerank_docs.
        """
        question = '\n'.join(self.question)
        print(question)
        result = []
        for split_doc in split_docs_list:
            start = time.time()
            result.extend(reranker.compress_documents(split_doc, question)[:3])
            print('重排1 time: %s Seconds' % (time.time() - start))
        start = time.time()
        result.extend(
            reranker.compress_documents(self.similarity_docs, question)[:3])
        print('重排2 time: %s Seconds' % (time.time() - start))
        self.rerank_docs = result
        return self.join_document(result)

    def get_similarity_doc(self):
        """Return the joined text of the merged similarity hits."""
        return self.similarity_doc_txt

    def get_similarity_docs(self):
        """Return the merged similarity hit documents."""
        return self.similarity_docs

    def get_rerank_docs(self):
        """Return the documents cached by the last rerank call."""
        return self.rerank_docs

    def get_text_similarity_with_ext(self):
        """Search the store for every question variant and merge the hits.

        Bug fix: previously the dedup/append logic was nested inside the
        location check, so with the default _location=None every document was
        skipped and this always returned []. Now the location filter applies
        only when one is provided; duplicates (by page_content hash) are
        removed in every case.
        """
        hits = []
        for q in self.question:
            hits.extend(self.faiss_db.get_text_similarity(q))
        seen = set()
        unique_documents = []
        for doc in hits:
            if self.location is not None and not (
                any(s in doc.page_content for s in self.location)
                or any(s in doc.metadata["filename"] for s in self.location)
            ):
                continue  # location filter active and nothing matched
            key = hash(doc.page_content)
            if key not in seen:
                seen.add(key)
                unique_documents.append(doc)
        print(len(unique_documents))
        return unique_documents

    def get_doc_nums(self, num: int) -> int:
        """Clamp 3*num into [5, 30] (heuristic rerank candidate count)."""
        return max(5, min(30, num * 3))

class QAExt:
    """Expand a user query into extended queries via an LLM prompt chain."""

    # Kept for backward compatibility with any class-level access.
    llm = None

    def __init__(self, llm) -> None:
        self.llm = llm
        prompt = PromptTemplate.from_template(PROMPT_QUERY_EXTEND)
        self.query_extend = prompt | llm

    def extend_query(self, question, messages=None):
        """Invoke the extension chain with the chat history folded into a
        "Q:/A:" transcript string.

        question: str
        messages: list of (question, answer) tuples, e.g.
            [("Q1", "A1"), ("Q2", "A2"), ...]

        Bug fix: the formatted `history` string was previously built and then
        discarded — the raw `messages` list was passed to the chain instead.
        """
        if not messages:
            messages = []
        history = ""
        for msg in messages:
            history += f"Q: {msg[0]}\nA: {msg[1]}\n"
        return self.query_extend.invoke(input={"histories": history, "query": question})

    def extend_query_with_str(self, question, messages):
        """Invoke the extension chain with an already-formatted history string.

        question: str
        messages: pre-formatted history text passed through unchanged
        """
        return self.query_extend.invoke(input={"histories": messages, "query": question})

class ChatExtend:
    """Generate follow-up questions from a chat history via an LLM prompt chain."""

    def __init__(self, llm) -> None:
        self.llm = llm
        self.query_extend = PromptTemplate.from_template(PROMPT_QA_EXTEND_QUESTION) | llm

    def new_questions(self, messages):
        """Render messages (pairs of question/answer) as a "Q:/A:" transcript
        and invoke the extension chain with it.
        """
        transcript = "".join(f"Q: {m[0]}\nA: {m[1]}\n" for m in messages)
        return self.query_extend.invoke(input={"histories": transcript})