# web_contract.py
# Gradio demo for extracting key elements from commercial contracts.
# Last committed by 陈正乐.
import gradio as gr
from flask import Flask, request, jsonify
import torch

from contract.extraction import ElementsExtractor
from llm.chatglm import ChatGLMLocLLM
from llm.ernie import ErnieLLM
from llm.baichuan import BaichuanLLM
from loader.load import load_file,load

# Module-level setup: runs once at import time, before the UI is built.
from flask.cli import load_dotenv
# Invoke Flask's dotenv loader before any model object is constructed
# (presumably so settings from a .env file are visible to the LLM wrappers — confirm).
load_dotenv()


# Load the model
# `llms` is the list of choices offered by the model-type radio button.
# NOTE(review): change_llm_type() also handles "baichuan-13b", which is not
# listed here and therefore cannot be selected from the UI.
llms = ["ChatGLM","ChatGLM2","Ernie"]
# Default backend: a ChatGLM2-6B checkpoint at a path relative to the working directory.
llm = ChatGLMLocLLM(model_name="../../models/chatglm2-6b")

# Alternative default backend, currently disabled:
# llm = ErnieLLM()

# The extractor wraps the active LLM; change_llm_type() rebuilds it when the model changes.
extractor=ElementsExtractor(llm=llm)
# Default element names to extract. This list is used both as the checkbox
# choices and as the initially checked values, and add_element() appends to it
# in place at runtime.
elements = ["合同号","买方","卖方","合同金额","合同签订日期","装运标记","甲方","乙方","甲方地址","乙方地址"]
# Leftover default; the chunk size now comes from the slider in the UI:
# max_length=1000
# Define the Gradio interface
# Define the Gradio interface
def _merge_extracted(lines, elements, values):
    """Merge raw ``key:value`` lines returned by the extractor into ``values``.

    Lines that cannot be split on the first ``:`` are logged and skipped.
    Empty values, the literal "未知" ("unknown"), and keys that were not
    requested are ignored. When a key already has a value, the new value is
    prepended, separated by a comma.
    """
    for line in lines or []:
        try:
            key, value = line.split(":", maxsplit=1)
        except (ValueError, AttributeError) as exp:
            # Best effort: report the malformed line and carry on.
            print(exp)
            print(line)
            continue
        key = key.strip()
        value = value.strip()
        if not value or value == "未知" or key not in elements:
            continue
        values[key] = value + "," + values[key] if key in values else value


def _extract_chunk(chunk, elements, values):
    """Run the extractor over one chunk of text segments and merge the result."""
    if not chunk:
        # Nothing accumulated yet; do not spend an LLM call on empty text.
        return
    doc = "\n".join(chunk)
    _merge_extracted(extractor.extract(doc, elements), elements, values)
    # Progress log: the merged values after each processed chunk.
    print("\n".join([f"{k}:{v}" for k,v in values.items()]))


def contract(file,elements,max_length):
    """Extract the requested elements from an uploaded contract file.

    The loaded document segments are grouped into chunks whose combined
    ``page_content`` length stays within ``max_length`` (a single segment
    longer than ``max_length`` forms a chunk on its own). Each chunk is
    passed to the module-level ``extractor`` and the results are merged.

    Args:
        file: Uploaded file object; its ``name`` attribute is passed to ``load``.
        elements: Names of the elements to extract.
        max_length: Maximum number of characters sent to the extractor per call.

    Returns:
        One ``key:value`` line per extracted element, joined by newlines, or
        an error message when the file is missing or cannot be loaded.
    """
    print(llm.model_name)
    print(file,elements)
    if file is None:
        return "Error: could not load file"
    docs = load(file.name)
    if docs is None:
        return "Error: could not load file"
    print(len(docs))

    values = {}
    chunk = []
    chunk_len = 0
    for d in docs:
        text = d.page_content
        if chunk and chunk_len + len(text) > max_length:
            # Adding this segment would overflow the chunk: process what we
            # have and start a new chunk with the current segment.
            _extract_chunk(chunk, elements, values)
            chunk = []
            chunk_len = 0
        chunk.append(text)
        chunk_len += len(text)

    # Process the trailing chunk. Without this step, documents shorter than
    # max_length would never reach the extractor and the tail of longer
    # documents would be dropped.
    _extract_chunk(chunk, elements, values)

    return "\n".join([f"{k}:{v}" for k,v in values.items()])

    
def change_llm_type(llm_type):
    """Switch the active language model and rebuild the extractor around it.

    The module-level ``llm`` and ``extractor`` are first pointed at an
    ``ErnieLLM`` placeholder so that nothing references the previous model
    while GPU memory is released. The requested model is then loaded. If
    loading raises, the globals stay on the consistent Ernie pair.

    Args:
        llm_type: One of "ChatGLM", "ChatGLM2", "Ernie" or "baichuan-13b".
            Any other value falls back to ``ErnieLLM``.

    Returns:
        ``llm_type`` unchanged, so it can be echoed back to the radio button.
    """
    global llm,extractor
    import gc

    print("change_llm_type",llm_type)

    # Local checkpoints, built lazily so that only the selected one is loaded.
    local_models = {
        "ChatGLM": lambda: ChatGLMLocLLM(model_name="../../models/chatglm-6b"),
        "ChatGLM2": lambda: ChatGLMLocLLM(model_name="../../models/chatglm2-6b"),
        "baichuan-13b": lambda: BaichuanLLM(model_name="../../models/Baichuan-13B-Chat",quantization_bit=8),
    }

    # Rebind BOTH globals before freeing memory. The extractor was built with
    # llm=<old model> and presumably keeps a reference to it, so rebinding
    # only `llm` would leave the old model alive through empty_cache().
    llm = ErnieLLM()
    extractor = ElementsExtractor(llm=llm)
    gc.collect()
    torch.cuda.empty_cache()

    build = local_models.get(llm_type)
    if build is not None:
        llm = build()
        extractor = ElementsExtractor(llm=llm)
    # "Ernie" and unknown types keep the ErnieLLM instance created above
    # instead of constructing a second one.

    return llm_type


def add_element(ele_new):
    """Add a user-supplied element name to the shared ``elements`` list.

    The name is stripped of surrounding whitespace, because ``contract``
    compares stripped keys against this list and a padded name could never
    match. Blank names and names already in the list are ignored, so the
    checkbox group gets no empty or duplicate entries.

    Note that ``elements`` is a module-level list mutated in place, so an
    added element persists for the lifetime of the process.

    Args:
        ele_new: Text entered in the "new element" box (may be empty or None).

    Returns:
        A dict of component updates: refreshed choices for ``ele_group`` and
        an emptied ``ele_new_box``.
    """
    print("add_element",elements,ele_new)
    name = (ele_new or "").strip()
    if name and name not in elements:
        elements.append(name)
    return {ele_group:gr.update(choices=elements),
            ele_new_box:gr.update(value="")}

def reset():
    """Clear the result box and the uploaded file.

    Sets ``output.value`` to an empty string and ``file.value`` to ``None``.
    The button that would call this is currently commented out in the layout.

    NOTE(review): assigning ``.value`` on already-created components
    presumably does not refresh a rendered page — confirm before re-enabling.
    """
    for component, cleared in ((output, ""), (file, None)):
        component.value = cleared

# Build the single-page UI. Components are created inside nested Row/Column
# contexts, so the order of the statements below defines the page layout.
with gr.Blocks() as demo:
    # Page heading ("Commercial contract element extraction").
    gr.HTML("""<h1 align="center">商业合同要素提取</h1>""")

    with gr.Row():
        # Wide column (scale=4): file upload, submit button and result box.
        with gr.Column(scale=4):
            with gr.Row():
                # Upload widget; its value is passed to contract() as `file`.
                file = gr.File(label="上传文件")
            with gr.Row():
                # NOTE(review): confirm that gr.Button accepts a `type` keyword
                # in the installed Gradio version.
                submit_btn=gr.Button("开始提取", type="submit")
                # Reset button, currently disabled:
                # reset_btn=gr.Button("重置", type="reset")
                # reset_btn.click(reset)   

            with gr.Row():
                # Receives the string returned by contract().
                output=gr.Textbox(label="提取结果", type="text", lines=20)

        # Narrow column (scale=1): extraction settings.
        with gr.Column(scale=1):
            with gr.Row():
                # Chunk size passed to contract() as `max_length`:
                # 1000 to 10000 in steps of 1000, default 5000.
                max_length = gr.Slider(1000, 10000, value=5000, step=1000, label="单次提取使用的文本长度", interactive=True)
            with gr.Row():
                # Model selector; the default matches the model loaded at import time.
                llm_type = gr.Radio(llms, label="语言模型类型", value="ChatGLM2", interactive=True)
                # Changing the selection calls change_llm_type(); the radio is
                # both the input and the output of that handler.
                llm_type.change(change_llm_type, inputs=[llm_type],outputs=[llm_type])
            with gr.Row():
                # Elements to extract; every default element starts checked.
                ele_group = gr.CheckboxGroup(choices=elements, label="需要提取的元素", value=elements, interactive=True)
                with gr.Row():
                    # Text box and button for adding a custom element name.
                    ele_new_box = gr.Textbox(label="新增元素", type="text", lines=1)
                    ele_new_btn = gr.Button("新增", type="submit")
                    # add_element() returns updates for both the checkbox group and the text box.
                    ele_new_btn.click(add_element,inputs=[ele_new_box],outputs=[ele_group,ele_new_box])
                            
        # Wire the submit button last, once every input and output component exists.
        submit_btn.click(contract,inputs=[file,ele_group,max_length],outputs=output)
            
# Enable request queuing and start the server as soon as the module runs.
# NOTE(review): share=True asks Gradio for a publicly reachable link — confirm
# this exposure is intended. There is also no `if __name__ == "__main__"`
# guard, so importing this module launches the app.
demo.queue().launch(share=True)