Commit 77a22b17 by 陈正乐

添加新建对话功能

parent 7e960d29
......@@ -152,6 +152,31 @@ class CRUD:
return self.db.fetchall()
def get_chats(self, account):
query = f'SELECT chat.chat_id,chat.info FROM chat JOIN c_user ON chat.user_id = c_user.user_id WHERE c_user.account = (%s);'
query = f'SELECT chat.chat_id,chat.info FROM chat JOIN c_user ON chat.user_id = c_user.user_id WHERE c_user.account = (%s) ORDER BY chat.create_time DESC;'
self.db.execute_args(query, (account,))
return self.db.fetchall()
def create_chat(self, user_id, info, deleted):
query = f'INSERT INTO chat(user_id, info, deleted) VALUES (%s,%s,%s) RETURNING chat_id'
self.db.execute_args(query, (user_id, info, deleted))
ans = self.db.fetchall()[0][0]
return ans
def get_uersid_from_account(self, account):
query = f'SELECT user_id FROM c_user WHERE account = (%s)'
self.db.execute_args(query, (account, ))
ans = self.db.fetchall()[0][0]
print(ans)
return ans
def get_chat_info(self, chat_id):
query = f'SELECT info FROM chat WHERE chat_id = (%s)'
self.db.execute_args(query, (chat_id,))
ans = self.db.fetchall()[0][0]
print(ans)
return ans
def set_info(self, chat_id, info):
query = f'UPDATE chat SET info = (%s) WHERE chat_id = (%s)'
self.db.execute_args(query, (info, chat_id))
......@@ -149,6 +149,17 @@ class QA:
self.chat_id = chat_id
self.history = self.crud.get_history(self.chat_id)
def create_chat(self, user_account):
user_id = self.crud.get_uersid_from_account(user_account)
self.chat_id = self.crud.create_chat(user_id, '\t\t', '0')
def set_info(self, question):
info = self.crud.get_chat_info(self.chat_id)
if info == '\t\t':
n_info = '这是一个info'
self.crud.set_info(self.chat_id, n_info)
if __name__ == "__main__":
# 数据库连接
......
# -*- coding: utf-8 -*-
import sys
sys.path.append('../')
import gradio as gr
from langchain.prompts import PromptTemplate
......@@ -38,8 +41,11 @@ prompt1 = """'''
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
global o_users, users_l, o_chats, chats_l
def main():
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
vecstore_faiss = VectorStore_FAISS(
......@@ -72,19 +78,29 @@ def main():
def get_users():
return my_chat.get_users()
def create_chat(user_account):
my_chat.create_chat(user_account)
def get_chats(user_account):
global o_users, users_l, o_chats, chats_l
print(chats_l)
o_chats = my_chat.get_chats(user_account)
chats_l = [item[0]+':'+item[1] for item in o_chats]
print(o_chats)
chats_l = [item[0] + ':' + item[1] for item in o_chats]
return gr.components.Radio(choices=chats_l, label="选择一个对话", value=chats_l[0], interactive=True)
def set_info(question):
my_chat.set_info(question)
def set_chat_id(chat_id_info):
chat_id = chat_id_info.split(':')[0]
my_chat.set_chat_id(chat_id)
global o_users, users_l, o_chats, chats_l
o_users = get_users()
users_l = [item[0] for item in o_users]
o_chats = my_chat.get_chats(users_l[0])
chats_l = [item[0]+':'+item[1] for item in o_chats]
chats_l = [item[0] + ':' + item[1] for item in o_chats]
set_chat_id(chats_l[0])
print(my_chat.chat_id)
......@@ -95,14 +111,23 @@ def main():
gr.HTML("""<h1 align="center">低空经济知识问答</h1>""")
with gr.Row():
with gr.Column(scale=2):
users = gr.components.Radio(choices=users_l, label="选择一个用户", value=users_l[0], interactive=True, visible=False)
chats = gr.components.Radio(choices=chats_l, label="选择一个对话", value=chats_l[0], interactive=True)
users = gr.components.Radio(choices=users_l, label="选择一个用户", value=users_l[0], interactive=True,
visible=False, show_label=False)
chats = gr.components.Radio(choices=chats_l, label="选择一个对话", value=chats_l[0], interactive=True,
show_label=False)
new_chat_btn = gr.Button("新建对话")
with gr.Column(scale=8):
chatbot = gr.Chatbot(bubble_full_width=False,
avatar_images=(ICON_PATH + '\\user.png', ICON_PATH + "\\bot.png"),
value=show_history(), height=400)
input_text = gr.Textbox(show_label=True, lines=3, label="文本输入")
sub_btn = gr.Button("提交")
avatar_images=(ICON_PATH + '\\user2.png', ICON_PATH + "\\bot2.png"),
value=show_history(), height=400, show_copy_button=True,
show_label=False, line_breaks=True)
with gr.Row():
input_text = gr.Textbox(show_label=False, lines=1, label="文本输入", scale=9)
sub_btn = gr.Button("提交", scale=1)
new_chat_btn.click(create_chat, [users], []).then(
get_chats, [users], [chats]
)
users.change(get_chats, [users], [chats]).then(
set_chat_id, [chats], None
......@@ -118,6 +143,10 @@ def main():
).then(
stop_btn, None, sub_btn
).then(
set_info, [input_text], []
).then(
get_chats, [users], [chats]
).then(
my_chat.update_history, None, None
).then(
show_history, None, chatbot
......@@ -127,7 +156,7 @@ def main():
restart_btn, None, sub_btn
)
demo.queue().launch(share=False, inbrowser=True, server_name='192.168.22.127', server_port=GR_PORT)
demo.queue().launch(share=False, inbrowser=True, server_name='192.168.22.80', server_port=GR_PORT)
if __name__ == "__main__":
......
......@@ -78,9 +78,9 @@ def test_faiss_load():
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
print(vecstore_faiss.join_document(vecstore_faiss.get_text_similarity("什么是低空飞行")))
print(vecstore_faiss.join_document(vecstore_faiss.get_text_similarity("我国什么时候全面开放低空领域")))
if __name__ == "__main__":
test_faiss_from_dir()
# test_faiss_from_dir()
test_faiss_load()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment