# -*- coding: utf-8 -*-
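"""Retrieval-augmented QA chat service.

Wraps an ERNIE LLM (via qianfan) in a LangChain LLMChain, retrieves context
from a FAISS vector store (optionally reranked with a BGE reranker), filters
blocked keywords, and persists chat history to PostgreSQL through the CRUD layer.
"""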
import sys
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.server.get_similarity import GetSimilarity
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    RERANK_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER
)
from .rerank import BgeRerank

sys.path.append("../..")
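# RAG prompt (Chinese): answer the question below based on the retrieved context above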
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""

PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)

# Preset safe response returned when a blocked keyword is detected
SAFE_RESPONSE = "您好,我不具备人类属性,因此没有名字。我可以协助您完成范围广泛的任务并提供有关各种主题的信息,比如回答问题,提供定义和解释以及建议。如果您有任何问题,请随时向我提问。"
BLOCKED_KEYWORDS = ["文心一言", "百度", "模型"]


class QA:
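    """A QA session bound to one chat_id.

    Retrieves similar documents for each question, formats them into the
    prompt, runs the LLM chain, filters blocked keywords, and records every
    question/answer turn in the chat database.
    """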
    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db, rerank: bool = False):
        self.prompt = _prompt
        self.base_llm = _base_llm
        self.llm_kwargs = _llm_kwargs
        self.prompt_kwargs = _prompt_kwargs
        self.db = _db
        self.chat_id = None
        self.faiss_db = _faiss_db
        self.crud = CRUD(self.db)
        self.history = None
        self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
        self.cur_answer = ""
        self.cur_question = ""
        self.cur_similarity = ""
        self.cur_oquestion = ""
        self.rerank = rerank

        if rerank:
            self.rerank_model = BgeRerank(RERANK_MODEL_PATH)

    # Check whether the text contains any blocked keyword
    def contains_blocked_keywords(self, text):
        return any(keyword in text for keyword in BLOCKED_KEYWORDS)

    # Return the concatenated similarity text for the given question
    def get_similarity(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()
    
    # Return the GetSimilarity helper itself so the caller can pick plain or reranked retrieval
    def get_similarity_origin(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db)

    # Return the complete answer in a single (non-streaming) call
    def chat(self, _question, with_similarity=False):
        if not _question:
            return ""
        self.cur_oquestion = _question
        # Initialize up front so the with_similarity branch below never sees undefined names
        similarity_docs = rerank_docs = None
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
            self.cur_answer = self.llm.run(context=self.cur_similarity, question=self.cur_oquestion)
            # Filter the model output as well, in case it mentions a blocked keyword
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            if self.rerank:
                return self.cur_answer, rerank_docs
            else:
                return self.cur_answer, similarity_docs
        return self.cur_answer

    # Same as chat(), but also feeds the conversation history into the prompt
    def chat_with_history(self, _question, history, with_similarity=False):
        if not _question:
            return ""
        self.cur_oquestion = _question
        similarity_docs = rerank_docs = None
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
        else:
            similarity = self.get_similarity_origin(_aquestion=self.cur_oquestion)
            if self.rerank:
                self.cur_similarity = similarity.get_rerank(self.rerank_model)
            else:
                self.cur_similarity = similarity.get_similarity_doc()
            similarity_docs = similarity.get_similarity_docs()
            rerank_docs = similarity.get_rerank_docs()
            # Note: this requires a prompt template with a "history" input variable
            self.cur_question = self.prompt.format(history=history, context=self.cur_similarity, question=self.cur_oquestion)
            self.cur_answer = self.llm.run(history=history, context=self.cur_similarity, question=self.cur_oquestion)
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE

        self.update_history()
        if with_similarity:
            if self.rerank:
                return self.cur_answer, rerank_docs
            else:
                return self.cur_answer, similarity_docs
        return self.cur_answer

    # Async streaming: yield the growing chat history as tokens arrive
    async def async_chat(self, _question):
        self.cur_oquestion = _question
        history = self.get_history()

        # Short-circuit questions that contain blocked keywords
        if self.contains_blocked_keywords(_question):
            self.cur_answer = SAFE_RESPONSE
            history.append((self.cur_oquestion, self.cur_answer))
            yield history
            return

        self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=self.cur_oquestion)
        callback = AsyncIteratorCallbackHandler()

        # Await the chain in the background and signal the event when it finishes
        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=self.cur_oquestion, callbacks=[callback]),
                      callback.done))

        self.cur_answer = ""
        history.append((self.cur_oquestion, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            # Censor the streamed answer if it starts to mention a blocked keyword
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE
                history[-1] = (self.cur_oquestion, self.cur_answer)
                yield history
                task.cancel()
                return
            history[-1] = (self.cur_oquestion, self.cur_answer)
            yield history
        await task

    # Async streaming variant: the incoming history carries the new question as its last entry
    async def async_chat2(self, history):
        _question = history[-1][0]
        history = history[:-1]

        # Short-circuit questions that contain blocked keywords
        if self.contains_blocked_keywords(_question):
            history.append((_question, SAFE_RESPONSE))
            yield history
            return

        self.cur_similarity = self.get_similarity(_aquestion=_question)
        self.cur_question = self.prompt.format(context=self.cur_similarity, question=_question)
        callback = AsyncIteratorCallbackHandler()

        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(context=self.cur_similarity, question=_question, callbacks=[callback]),
                      callback.done))

        self.cur_answer = ""
        history.append((_question, self.cur_answer))
        async for token in callback.aiter():
            self.cur_answer += token
            if self.contains_blocked_keywords(self.cur_answer):
                self.cur_answer = SAFE_RESPONSE
                history[-1] = (_question, self.cur_answer)
                yield history
                task.cancel()
                return
            history[-1] = (_question, self.cur_answer)
            yield history
        await task

    def get_history(self):
        self.history = self.crud.get_history(self.chat_id)
        return self.history

    def update_history(self):
        if self.cur_oquestion and self.cur_answer:
            if not self.history:
                self.history = []
            # Avoid duplicate entries for the same question
            self.history = [(q, a) for q, a in self.history if q != self.cur_oquestion or a != ""]
            self.history.append((self.cur_oquestion, self.cur_answer))
            self.crud.update_last(chat_id=self.chat_id)
            self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer,
                                     turn_number=len(self.history), is_last=1)

    def get_users(self):
        return self.crud.get_users()

    def get_chats(self, user_account):
        return self.crud.get_chats(user_account)

    def set_chat_id(self, chat_id):
        self.chat_id = chat_id
        self.history = self.crud.get_history(self.chat_id)

    def create_chat(self, user_account):
        user_id = self.crud.get_uersid_from_account(user_account)
        self.chat_id = self.crud.create_chat(user_id, '\t\t', '0')

    def set_info(self, question):
        info = self.crud.get_chat_info(self.chat_id)
        if info == '\t\t':
            if len(question) <= 10:
                n_info = question
                self.crud.set_info(self.chat_id, n_info)
            else:
                info_prompt = """'''
{question}
'''
请你用十个字之内总结上述问题,你的输出不得大于10个字。"""
                info_prompt_t = PromptTemplate(input_variables=["question"], template=info_prompt)
                info_llm = LLMChain(llm=self.base_llm, prompt=info_prompt_t, llm_kwargs=self.llm_kwargs)
                n_info = info_llm.run(question=question)
                self.crud.set_info(self.chat_id, n_info)


if __name__ == "__main__":
    # Chat-history database connection
    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                       port=CHAT_DB_PORT, )
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss)
    my_chat.set_chat_id('1')
    print(my_chat.get_history())
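
    # Minimal usage sketch (assumes the DB, FAISS store, and ERNIE service above
    # are reachable and that chat id '1' exists); uncomment to try it:
    #
    # answer = my_chat.chat("这份资料主要讲了什么?")
    # print(answer)
    #
    # Streaming variant via async_chat:
    # async def demo():
    #     async for hist in my_chat.async_chat("这份资料主要讲了什么?"):
    #         print(hist[-1][1])
    # asyncio.run(demo())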