Commit 08419d48 by 陈正乐

gradio显示效果优化

parent 6b86858b
...@@ -62,3 +62,8 @@ prompt1 = """''' ...@@ -62,3 +62,8 @@ prompt1 = """'''
# NLP_BERT模型路径配置 # NLP_BERT模型路径配置
# ============================= # =============================
NLP_BERT_PATH = 'C:\\Users\\15663\\AI\\models\\nlp_bert_document-segmentation_chinese-base' NLP_BERT_PATH = 'C:\\Users\\15663\\AI\\models\\nlp_bert_document-segmentation_chinese-base'
# =============================
# ICON配置
# =============================
ICON_PATH = 'C:\\Users\\15663\\code\\dkjj-llm\\LAE\\LAE\\test\\icon'
...@@ -90,10 +90,13 @@ class QA: ...@@ -90,10 +90,13 @@ class QA:
task = asyncio.create_task( task = asyncio.create_task(
wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion))}, callbacks=[callback]), wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion))}, callbacks=[callback]),
callback.done)) callback.done))
history = self.get_history()
self.cur_answer = "" self.cur_answer = ""
history.append((self.cur_oquestion, self.cur_answer))
async for token in callback.aiter(): async for token in callback.aiter():
self.cur_answer = self.cur_answer + token self.cur_answer = self.cur_answer + token
yield f"{self.cur_answer}" history[-1] = (self.cur_oquestion, self.cur_answer)
yield history
await task await task
def get_history(self): def get_history(self):
......
...@@ -6,6 +6,7 @@ from src.llm.chatglm import ChatGLMSerLLM ...@@ -6,6 +6,7 @@ from src.llm.chatglm import ChatGLMSerLLM
from src.llm.ernie_with_sdk import ChatERNIESerLLM from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion from qianfan import ChatCompletion
import os import os
import asyncio
from src.pgdb.chat.c_db import UPostgresDB from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.knowledge.similarity import VectorStore_FAISS from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.config.consts import ( from src.config.consts import (
...@@ -24,7 +25,8 @@ from src.config.consts import ( ...@@ -24,7 +25,8 @@ from src.config.consts import (
VEC_DB_DBNAME, VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER, SIMILARITY_SHOW_NUMBER,
GR_PORT, GR_PORT,
GR_SERVER_NAME GR_SERVER_NAME,
ICON_PATH
) )
from src.server.qa import QA from src.server.qa import QA
...@@ -55,19 +57,37 @@ def main(): ...@@ -55,19 +57,37 @@ def main():
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2',
_faiss_db=vecstore_faiss) _faiss_db=vecstore_faiss)
def clear_q_a(): def clear(): # 清空输入框
return '', '' return ''
def show_history(): def show_history(): # 显示对话历史
return my_chat.get_history() return my_chat.get_history()
def stop_btn():
return gr.Button(interactive=False)
def restart_btn():
return gr.Button(interactive=True)
with gr.Blocks() as demo: with gr.Blocks() as demo:
with gr.Row(): chatbot = gr.Chatbot(bubble_full_width=False, avatar_images=(ICON_PATH+'\\user.png', ICON_PATH+"\\bot.png"),
inn = gr.Textbox(show_label=True, lines=10) value=show_history())
with gr.Row(): input_text = gr.Textbox(show_label=True, lines=3, label="文本输入")
qabtn = gr.Button("SUBMIT") sub_btn = gr.Button("提交")
out = gr.Textbox(show_label=True, lines=10)
qabtn.click(my_chat.async_chat, [inn], [out]) sub_btn.click(my_chat.async_chat, [input_text], [chatbot]
).then(
stop_btn, None, sub_btn
).then(
my_chat.update_history, None, None
).then(
show_history, None, chatbot
).then(
clear, None, [input_text]
).then(
restart_btn, None, sub_btn
)
demo.queue().launch(share=False, inbrowser=True, server_name=GR_SERVER_NAME, server_port=GR_PORT) demo.queue().launch(share=False, inbrowser=True, server_name=GR_SERVER_NAME, server_port=GR_PORT)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment