gradio_test.py 5.1 KB
Newer Older
1
# -*- coding: utf-8 -*-
陈正乐 committed
2 3 4
import sys

sys.path.append('../')
5 6
import gradio as gr
from langchain.prompts import PromptTemplate
7 8

from src.llm.chatglm import ChatGLMSerLLM
9 10
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
11
import os
陈正乐 committed
12
import asyncio
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
陈正乐 committed
29 30
    SIMILARITY_SHOW_NUMBER,
    GR_PORT,
陈正乐 committed
31 32
    GR_SERVER_NAME,
    ICON_PATH
33 34 35 36 37 38 39 40 41 42 43
)
from src.server.qa import QA

# Retrieval-augmented QA prompt: the retrieved passages are fenced between
# ''' markers, followed by an instruction (in Chinese) to answer the question
# using only that material.
prompt1 = (
    "'''\n"
    "{context}\n"
    "'''\n"
    "请你根据上述已知资料回答下面的问题,问题如下:\n"
    "{question}"
)

PROMPT1 = PromptTemplate(template=prompt1, input_variables=["context", "question"])

陈正乐 committed
44 45

def main():
    """Build and launch the Gradio knowledge-QA web UI.

    Wires a PostgreSQL chat store, a FAISS vector store and an ERNIE LLM into
    a ``QA`` object, then serves a ``gr.Blocks`` app with a (hidden) user
    selector, a chat selector, a "new chat" button, a chatbot view and a text
    input with a submit button.

    Environment:
        QIANFAN_AK, QIANFAN_SK: Baidu Qianfan credentials for the ERNIE model.

    Raises:
        RuntimeError: if the Qianfan credentials are not set in the environment.
    """
    # Credentials must come from the environment, never from source control.
    # The key pair that used to be hard-coded here should be treated as leaked
    # and rotated. Checked first so we fail before opening any DB connections.
    qianfan_ak = os.environ.get("QIANFAN_AK")
    qianfan_sk = os.environ.get("QIANFAN_SK")
    if not qianfan_ak or not qianfan_sk:
        raise RuntimeError("QIANFAN_AK and QIANFAN_SK environment variables must be set")

    c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
                       port=CHAT_DB_PORT, )
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    base_llm = ChatERNIESerLLM(chat_completion=ChatCompletion(ak=qianfan_ak, sk=qianfan_sk))
    # Alternative backend: base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')

    my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db,
                 _faiss_db=vecstore_faiss)

    def clear():
        """Return an empty string to reset the input textbox."""
        return ''

    def show_history():
        """Return the current chat's history for the Chatbot component."""
        return my_chat.get_history()

    def stop_btn():
        """Disable the submit button while a request is being processed."""
        return gr.Button(interactive=False)

    def restart_btn():
        """Re-enable the submit button."""
        return gr.Button(interactive=True)

    def get_users():
        """Return a user Radio built from the DB and the preselected user.

        The first user is preselected; when there are no users the Radio is
        empty and the preselected user is ``None``.
        """
        users_l = [item[0] for item in my_chat.get_users()]
        first_user = users_l[0] if users_l else None
        return (gr.components.Radio(choices=users_l, label="选择一个用户", value=first_user, interactive=True),
                first_user)

    def create_chat(user_account):
        """Create a new chat for ``user_account``."""
        my_chat.create_chat(user_account)

    def get_chats(user_account):
        """Return a chat Radio for ``user_account`` with entries ``"<id>:<info>"``.

        The first chat is preselected; with no user or no chats the Radio is
        empty and nothing is selected.
        """
        if user_account is None:
            chats_l = []
        else:
            # f-string gives the same label as '+' for strings, but does not
            # raise TypeError if the DB returns a non-str id.
            chats_l = [f"{item[0]}:{item[1]}" for item in my_chat.get_chats(user_account)]
        return gr.components.Radio(choices=chats_l, label="选择一个对话",
                                   value=chats_l[0] if chats_l else None, interactive=True)

    def set_info(question):
        """Pass the submitted question to ``my_chat.set_info``."""
        my_chat.set_info(question)

    def set_chat_id(chat_id_info):
        """Select the chat whose id is the part of ``chat_id_info`` before the first ':'.

        Does nothing when no chat is selected (``None`` / empty value).
        """
        if not chat_id_info:
            return
        my_chat.set_chat_id(chat_id_info.split(':')[0])

    def load():
        """Populate the user and chat selectors when the page loads."""
        l_users, t_user = get_users()
        return l_users, get_chats(t_user)

    with gr.Blocks(css='index(1).css') as demo:
        gr.HTML("""<h1 align="center">低空经济知识问答</h1>""")
        with gr.Row():
            with gr.Column(scale=2):
                users = gr.components.Radio(choices=[], label="选择一个用户", interactive=True,
                                            visible=False, show_label=False)
                chats = gr.components.Radio(choices=[], label="选择一个对话", interactive=True,
                                            show_label=False)
                new_chat_btn = gr.Button("新建对话")
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(bubble_full_width=False,
                                     # os.path.join instead of hard-coded '\\' so the icons resolve off Windows too.
                                     avatar_images=(os.path.join(ICON_PATH, 'user2.png'),
                                                    os.path.join(ICON_PATH, 'bot2.png')),
                                     value=show_history(), height=400, show_copy_button=True,
                                     show_label=False, line_breaks=True)
                with gr.Row():
                    input_text = gr.Textbox(show_label=False, lines=1, label="文本输入", scale=9, container=False)
                    sub_btn = gr.Button("提交", scale=1)

        demo.load(load, [], [users, chats])

        new_chat_btn.click(create_chat, [users], []).then(
            get_chats, [users], [chats]
        )

        users.change(get_chats, [users], [chats]).then(
            set_chat_id, [chats], None
        ).then(
            show_history, None, chatbot
        )

        chats.change(set_chat_id, [chats], None).then(
            show_history, None, chatbot
        )

        # Disable the button *before* the (possibly slow) LLM call so the user
        # cannot submit twice. Every step is chained with .then(), which runs
        # even if an earlier step raised, so the button is always re-enabled.
        sub_btn.click(stop_btn, None, sub_btn).then(
            my_chat.async_chat, [input_text], [chatbot]
        ).then(
            set_info, [input_text], []
        ).then(
            get_chats, [users], [chats]
        ).then(
            my_chat.update_history, None, None
        ).then(
            show_history, None, chatbot
        ).then(
            clear, None, [input_text]
        ).then(
            restart_btn, None, sub_btn
        )

    # Bind to the configured host rather than a hard-coded IP address.
    demo.queue().launch(share=False, inbrowser=True, server_name=GR_SERVER_NAME, server_port=GR_PORT)
152 153 154


# Script entry point: start the web UI only when run directly, not on import.
if __name__ == "__main__":
    main()