# -*- coding: utf-8 -*-
import sys
import time

# Make the project root importable before the src.* imports below
sys.path.append("../..")

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.server.get_similarity import GetSimilarity
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER
)

# Retrieval-augmented prompt: the Chinese instruction asks the model to answer
# {question} using only the reference material supplied in {context}
prompt1 = """'''
{context}
'''
请你根据上述已知资料回答下面的问题,问题如下:
{question}"""
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
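
# Illustrative only: how the template renders with placeholder values (the
# context string below is an assumed example, not real retrieved data):
#   PROMPT1.format(context="<retrieved passages>", question="什么是低空经济")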


class QA:
    def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id, _faiss_db):
        self.prompt = _prompt
        self.base_llm = _base_llm
        self.llm_kwargs = _llm_kwargs
        self.prompt_kwargs = _prompt_kwargs
        self.db = _db
        self.chat_id = _chat_id
        self.faiss_db = _faiss_db
        self.crud = CRUD(self.db)
        self.history = self.crud.get_history(self.chat_id)  # prior (question, answer) turns
        self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
        self.cur_answer = ""      # latest model answer
        self.cur_question = ""    # latest fully rendered prompt
        self.cur_similarity = ""  # latest retrieved reference text
        self.cur_oquestion = ""   # latest original user question

    # Return the most similar reference text for the given question
    def get_similarity(self, _aquestion):
        return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()

    # Synchronous chat: return the complete answer in one call
    def chat(self, _question):
        if not _question:
            return ""
        self.cur_oquestion = _question
        self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
        prompt_inputs = dict(zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion)))
        self.cur_question = self.prompt.format(**prompt_inputs)
        self.cur_answer = self.llm.run(prompt_inputs)
        return self.cur_answer

    # Asynchronous chat: stream the answer as it is generated
    async def async_chat(self, _question):
        self.cur_oquestion = _question
        self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
        prompt_inputs = dict(zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion)))
        self.cur_question = self.prompt.format(**prompt_inputs)
        callback = AsyncIteratorCallbackHandler()

        # Await the LLM coroutine and always signal completion, even if it raised
        async def wrap_done(fn: Awaitable, event: asyncio.Event):
            try:
                await fn
            except Exception as e:
                import traceback
                traceback.print_exc()
                print(f"Caught exception: {e}")
            finally:
                event.set()

        task = asyncio.create_task(
            wrap_done(self.llm.arun(prompt_inputs, callbacks=[callback]),
                      callback.done))
        self.cur_answer = ""
        async for token in callback.aiter():
            self.cur_answer = self.cur_answer + token
            yield f"{self.cur_answer}"
        await task

    def get_history(self):
        return self.history

    def update_history(self):
        # Record the finished turn in memory and persist it as the latest Q&A
        if self.cur_oquestion == '' and self.cur_answer == '':
            return
        self.history.append((self.cur_oquestion, self.cur_answer))
        self.crud.update_last(chat_id=self.chat_id)
        self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer,
                                 turn_number=len(self.history), is_last=1)


if __name__ == "__main__":
    # Chat-history database connection
    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                       port=CHAT_DB_PORT, )
    # ERNIE chat model served through the Qianfan SDK (credentials hardcoded here)
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
    # FAISS vector store used for similarity retrieval over the knowledge base
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', _faiss_db=vecstore_faiss)
    print(my_chat.chat("什么是低空经济"))  # "What is the low-altitude economy?"
    my_chat.update_history()
    time.sleep(20)
    print(my_chat.cur_answer)  # same answer as above; chat() is synchronous
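
    # Hedged usage sketch for the streaming path: a minimal illustration (kept
    # commented out) of how async_chat could be consumed; the question text and
    # the printing style are illustrative assumptions, not part of the original.
    #
    # async def _stream_demo():
    #     async for partial_answer in my_chat.async_chat("什么是低空经济"):
    #         print(partial_answer)  # each yield is the accumulated answer so far
    #     my_chat.update_history()
    #
    # asyncio.run(_stream_demo())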