Commit 3c3636bd by 陈正乐

gradio添加用户展示,聊天展示

parent 08419d48
...@@ -4,7 +4,7 @@ TABLE_CHAT = """ ...@@ -4,7 +4,7 @@ TABLE_CHAT = """
DROP TABLE IF EXISTS "chat"; DROP TABLE IF EXISTS "chat";
CREATE TABLE chat ( CREATE TABLE chat (
chat_id varchar(1000) PRIMARY KEY, chat_id varchar(1000) PRIMARY KEY,
user_id int, user_id varchar(1000),
info text, info text,
create_time timestamp(6) DEFAULT current_timestamp, create_time timestamp(6) DEFAULT current_timestamp,
deleted int2 deleted int2
...@@ -145,3 +145,13 @@ class CRUD: ...@@ -145,3 +145,13 @@ class CRUD:
query = f'SELECT question FROM turn_qa WHERE chat_id = (%s) AND turn_number = 1' query = f'SELECT question FROM turn_qa WHERE chat_id = (%s) AND turn_number = 1'
self.db.execute_args(query, (_chat_id,)) self.db.execute_args(query, (_chat_id,))
return self.db.fetchone()[0] return self.db.fetchone()[0]
def get_users(self):
query = f'SELECT account FROM c_user'
self.db.execute(query)
return self.db.fetchall()
def get_chats(self, account):
query = f'SELECT chat.chat_id,chat.info FROM chat JOIN c_user ON chat.user_id = c_user.user_id WHERE c_user.account = (%s);'
self.db.execute_args(query, (account,))
return self.db.fetchall()
...@@ -39,16 +39,16 @@ PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=promp ...@@ -39,16 +39,16 @@ PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=promp
class QA: class QA:
def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id, _faiss_db): def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _faiss_db):
self.prompt = _prompt self.prompt = _prompt
self.base_llm = _base_llm self.base_llm = _base_llm
self.llm_kwargs = _llm_kwargs self.llm_kwargs = _llm_kwargs
self.prompt_kwargs = _prompt_kwargs self.prompt_kwargs = _prompt_kwargs
self.db = _db self.db = _db
self.chat_id = _chat_id self.chat_id = None
self.faiss_db = _faiss_db self.faiss_db = _faiss_db
self.crud = CRUD(self.db) self.crud = CRUD(self.db)
self.history = self.crud.get_history(self.chat_id) self.history = None
self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs) self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
self.cur_answer = "" self.cur_answer = ""
self.cur_question = "" self.cur_question = ""
...@@ -100,6 +100,7 @@ class QA: ...@@ -100,6 +100,7 @@ class QA:
await task await task
def get_history(self): def get_history(self):
self.history = self.crud.get_history(self.chat_id)
return self.history return self.history
def update_history(self): def update_history(self):
...@@ -111,6 +112,16 @@ class QA: ...@@ -111,6 +112,16 @@ class QA:
self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer, self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_oquestion, answer=self.cur_answer,
turn_number=len(self.history), is_last=1) turn_number=len(self.history), is_last=1)
def get_users(self):
return self.crud.get_users()
def get_chats(self, user_account):
return self.crud.get_chats(user_account)
def set_chat_id(self, chat_id):
self.chat_id = chat_id
self.history = self.crud.get_history(self.chat_id)
if __name__ == "__main__": if __name__ == "__main__":
# 数据库连接 # 数据库连接
...@@ -126,8 +137,6 @@ if __name__ == "__main__": ...@@ -126,8 +137,6 @@ if __name__ == "__main__":
"password": VEC_DB_PASSWORD}, "password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER, show_number=SIMILARITY_SHOW_NUMBER,
reset=False) reset=False)
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', _faiss_db=vecstore_faiss) my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss)
print(my_chat.chat("什么是低空经济")) my_chat.set_chat_id('1')
my_chat.update_history() print(my_chat.get_history())
time.sleep(20)
print(my_chat.cur_answer)
...@@ -17,7 +17,7 @@ def test(): ...@@ -17,7 +17,7 @@ def test():
port=CHAT_DB_PORT, ) port=CHAT_DB_PORT, )
print(c_db) print(c_db)
crud = CRUD(c_db) crud = CRUD(c_db)
# crud.create_table() crud.create_table()
# crud.insert_turn_qa("2", "wen4", "da1", 1, 0) # crud.insert_turn_qa("2", "wen4", "da1", 1, 0)
# crud.insert_turn_qa("2", "wen4", "da1", 2, 0) # crud.insert_turn_qa("2", "wen4", "da1", 2, 0)
# crud.insert_turn_qa("2", "wen4", "da1", 5, 0) # crud.insert_turn_qa("2", "wen4", "da1", 5, 0)
...@@ -27,7 +27,12 @@ def test(): ...@@ -27,7 +27,12 @@ def test():
# crud.insert_turn_qa("2", "wen4", "da1", 8, 0) # crud.insert_turn_qa("2", "wen4", "da1", 8, 0)
# crud.insert_turn_qa("2", "wen4", "da1", 7, 0) # crud.insert_turn_qa("2", "wen4", "da1", 7, 0)
# crud.insert_turn_qa("2", "wen4", "da1", 9, 0) # crud.insert_turn_qa("2", "wen4", "da1", 9, 0)
crud.insert_c_user('zhangs','111111') crud.insert_c_user('zhangs', '111111')
crud.insert_chat('1', '这是chat_id为1的问答info', 0)
crud.insert_c_user('lis', '111111')
crud.insert_chat('1', '这是chat_id为2的问答info', 0)
crud.insert_chat('2', '这是chat_id为3的问答info', 0)
crud.insert_chat('2', '这是chat_id为4的问答info', 0)
print(crud.get_history('2')) print(crud.get_history('2'))
......
...@@ -54,7 +54,7 @@ def main(): ...@@ -54,7 +54,7 @@ def main():
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")) chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
# base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088') # base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2', my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db,
_faiss_db=vecstore_faiss) _faiss_db=vecstore_faiss)
def clear(): # 清空输入框 def clear(): # 清空输入框
...@@ -69,11 +69,50 @@ def main(): ...@@ -69,11 +69,50 @@ def main():
def restart_btn(): def restart_btn():
return gr.Button(interactive=True) return gr.Button(interactive=True)
def get_users():
return my_chat.get_users()
def get_chats(user_account):
o_chats = my_chat.get_chats(user_account)
chats_l = [item[0]+':'+item[1] for item in o_chats]
return gr.components.Radio(choices=chats_l, label="选择一个对话", value=chats_l[0], interactive=True)
def set_chat_id(chat_id_info):
chat_id = chat_id_info.split(':')[0]
my_chat.set_chat_id(chat_id)
o_users = get_users()
users_l = [item[0] for item in o_users]
o_chats = my_chat.get_chats(users_l[0])
chats_l = [item[0]+':'+item[1] for item in o_chats]
set_chat_id(chats_l[0])
print(my_chat.chat_id)
print(type(my_chat.chat_id))
a = my_chat.get_history()
print(a)
with gr.Blocks() as demo: with gr.Blocks() as demo:
chatbot = gr.Chatbot(bubble_full_width=False, avatar_images=(ICON_PATH+'\\user.png', ICON_PATH+"\\bot.png"), gr.HTML("""<h1 align="center">低空经济知识问答</h1>""")
value=show_history()) with gr.Row():
input_text = gr.Textbox(show_label=True, lines=3, label="文本输入") with gr.Column(scale=2):
sub_btn = gr.Button("提交") users = gr.components.Radio(choices=users_l, label="选择一个用户", value=users_l[0], interactive=True)
chats = gr.components.Radio(choices=chats_l, label="选择一个对话", value=chats_l[0], interactive=True)
with gr.Column(scale=8):
chatbot = gr.Chatbot(bubble_full_width=False,
avatar_images=(ICON_PATH + '\\user.png', ICON_PATH + "\\bot.png"),
value=show_history())
input_text = gr.Textbox(show_label=True, lines=3, label="文本输入")
sub_btn = gr.Button("提交")
users.change(get_chats, [users], [chats]).then(
set_chat_id, [chats], None
).then(
show_history, None, chatbot
)
chats.change(set_chat_id, [chats], None).then(
show_history, None, chatbot
)
sub_btn.click(my_chat.async_chat, [input_text], [chatbot] sub_btn.click(my_chat.async_chat, [input_text], [chatbot]
).then( ).then(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment