import gradio as gr
from flask import Flask, request, jsonify
import torch
from qa.generator import QAGenerator
from llm.chatglm import ChatGLMLocLLM
from llm.ernie import ErnieLLM
from loader.load import load_file

from flask.cli import load_dotenv
load_dotenv()

# Load the model used for question/answer generation.
# The default is the remote Ernie model; a local ChatGLM checkpoint can be
# substituted via the commented line below (or at runtime via the UI's
# model-type radio, which rebuilds both globals in change_llm_type).
# llm = ChatGLMLocLLM(model_name="../../models/chatglm2-6b")
llm = ErnieLLM()
qa = QAGenerator(llm=llm)

# Define the Gradio interface
def qa_interface(file, step, question_numbers):
    """Stream question/answer pairs extracted from an uploaded file.

    Args:
        file: Gradio file object; ``file.name`` is the temp path on disk.
        step: number of loaded document chunks to merge into one
            knowledge context per generation round.
        question_numbers: how many questions to request per round.

    Yields:
        For each round, a string containing the generated questions and
        answers, entries separated by "----" lines.
    """
    print(file, step, question_numbers)

    # Load the file into a list of document chunks.
    docs = load_file(file.name)
    if docs is None:
        # BUG FIX: this function is a generator, so a plain
        # `return "Error: ..."` is swallowed by StopIteration and the
        # user never sees the message — the error must be yielded.
        yield "Error: could not load file"
        return

    # Gradio sliders may deliver floats; range() requires ints.
    step = int(step)
    for start in range(0, len(docs), step):
        print(start)
        # Concatenate this window of chunks into one knowledge context.
        knowledge = "\n".join(d.page_content for d in docs[start:start + step])
        questions = qa.generate_questions(
            knowledge=knowledge, question_number=question_numbers
        )
        answers = []
        for question in questions:
            answers.append(question)
            answer = qa.generate_answer(knowledge=knowledge, question=question)
            answers.append(answer)
            answers.append("----")
        # Streamed incrementally to the output textbox (demo.queue()).
        yield "\n".join(answers)

def change_llm_type(llm_type):
    """Swap the global LLM (and its QAGenerator) for the requested type.

    Args:
        llm_type: one of "ChatGLM", "ChatGLM2", "Ernie"; any other value
            falls back to Ernie (matching the original behavior).

    Returns:
        llm_type unchanged, so the Radio widget keeps its selection.
    """
    print("change_llm_type", llm_type)
    global llm, qa

    # Free the old model *before* constructing the new one so its GPU
    # memory is available for the replacement.
    del llm
    # Keep the global bound: if a constructor below raises, a later call
    # would otherwise hit NameError on `del llm`.
    llm = None
    torch.cuda.empty_cache()

    if llm_type == "ChatGLM":
        llm = ChatGLMLocLLM(model_name="../../models/chatglm-6b")
    elif llm_type == "ChatGLM2":
        llm = ChatGLMLocLLM(model_name="../../models/chatglm2-6b")
    else:
        # "Ernie" and any unknown value both use the Ernie model; the
        # original code had two identical branches for this.
        llm = ErnieLLM()

    # Every branch assigns llm, so the generator can be rebuilt
    # unconditionally (the old `if llm is not None` was always true).
    qa = QAGenerator(llm=llm)
    return llm_type


def reset():
    # NOTE(review): dead code — the reset button that would call this is
    # commented out in the UI below.  Also, assigning to .value on Gradio
    # components after the UI is built does not update the rendered
    # widgets; a Gradio click handler should *return* new values via
    # `outputs=` instead — confirm before re-enabling.
    output.value=""
    file.value=None

# Build the Gradio UI: file upload, extract button and streaming output on
# the left; extraction parameters and model selector on the right.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">邮储新一代知识库-QA提取</h1>""")

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Row():
                # Upload widget; qa_interface reads file.name from it.
                # NOTE(review): type="file" was deprecated in newer Gradio
                # releases (type="filepath") — confirm installed version.
                file = gr.File(label="上传文件",type="file")
            with gr.Row():
                submit_btn=gr.Button("开始提取", type="submit")
                # reset_btn=gr.Button("重置", type="reset")
                # reset_btn.click(reset)   

            with gr.Row():
                # Output textbox fed incrementally by the qa_interface
                # generator (streaming enabled by demo.queue() below).
                output=gr.Textbox(label="提取结果", type="text", lines=10, readonly=True)

        with gr.Column(scale=1):
            with gr.Row():
                # Sentences of knowledge per extraction round, and number
                # of questions to generate per round.
                step = gr.Slider(10, 1000, value=100, step=10, label="单次提取使用的知识句子数量", interactive=True)
                question_numbers = gr.Slider(0, 10, value=5, step=1, label="单次提取问题数量", interactive=True)
            with gr.Row():
                # Model selector; change_llm_type rebuilds the globals and
                # echoes the selection back to the widget.
                llm_type = gr.Radio(["ChatGLM","ChatGLM2","Ernie"], label="语言模型类型", value="Ernie", interactive=True)
                llm_type.change(change_llm_type, inputs=[llm_type],outputs=[llm_type])

        submit_btn.click(qa_interface,inputs=[file,step,question_numbers],outputs=output)
            


# queue() is required so generator outputs stream to the textbox;
# share=True exposes a public tunnel URL.
demo.queue().launch(share=True)