Commit 8ddfdec2 by 陈正乐

实现基于面向对象的模型QA服务

parent 8efec361
......@@ -15,13 +15,7 @@ COMMENT ON TABLE "c_user" IS '用户表';
class CUser:
def __init__(self, db: UPostgresDB) -> None:
self.db = db
def insert(self, value):
query = f"INSERT INTO c_user(user_id, account, password) VALUES (%s,%s,%s)"
self.db.execute_args(query, (value[0], value[1], value[2]))
def create_table(self):
query = TABLE_USER
self.db.execute(query)
def __init__(self, _user_id, _account, _password) -> None:
self.user_id = _user_id
self.account = _account
self.password = _password
......@@ -19,15 +19,9 @@ COMMENT ON TABLE "chat" IS '会话信息表';
class Chat:
def __init__(self, db: UPostgresDB) -> None:
self.db = db
# 插入数据
def insert(self, value):
query = f"INSERT INTO chat(chat_id, user_id, info, deleted) VALUES (%s,%s,%s,%s)"
self.db.execute_args(query, (value[0], value[1], value[2], value[3]))
# 创建表
def create_table(self):
query = TABLE_CHAT
self.db.execute(query)
def __init__(self, _chat_id, _user_id, _info, _create_time, _deleted) -> None:
self.chat_id = _chat_id
self.user_id = _user_id
self.info = _info
self.create_time = _create_time
self.deleted = _deleted
from .c_db import UPostgresDB
TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
chat_id varchar(1000) PRIMARY KEY,
user_id int,
info text,
create_time timestamp(6) DEFAULT current_timestamp,
deleted int2
);
COMMENT ON COLUMN "chat"."chat_id" IS '会话id';
COMMENT ON COLUMN "chat"."user_id" IS '会话创建用户id';
COMMENT ON COLUMN "chat"."info" IS '会话简介';
COMMENT ON COLUMN "chat"."create_time" IS '会话创建时间,默认为当前时间';
COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
DROP SEQUENCE IF EXISTS "chat_seq";
CREATE SEQUENCE "chat_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE chat ALTER COLUMN chat_id SET DEFAULT nextval('chat_seq'::regclass);
"""
TABLE_TURN_QA = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
turn_id varchar(1000) PRIMARY KEY,
chat_id varchar(1000),
question text,
answer text,
create_time timestamp(6) DEFAULT current_timestamp,
turn_number int,
is_last int2
);
COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
DROP SEQUENCE IF EXISTS "turn_qa_seq";
CREATE SEQUENCE "turn_qa_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE turn_qa ALTER COLUMN turn_id SET DEFAULT nextval('turn_qa_seq'::regclass);
"""
TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
user_id varchar(1000) PRIMARY KEY,
account varchar(20) NOT NULL,
password varchar(50) NOT NULL
);
COMMENT ON COLUMN "c_user"."user_id" IS '用户id';
COMMENT ON COLUMN "c_user"."account" IS '用户帐户';
COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
DROP SEQUENCE IF EXISTS "c_user_seq";
CREATE SEQUENCE "c_user_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE c_user ALTER COLUMN user_id SET DEFAULT nextval('c_user_seq'::regclass);
"""
class CRUD:
def __init__(self, _db: UPostgresDB):
self.db = _db
def create_table(self):
self.db.execute(TABLE_CHAT)
self.db.execute(TABLE_TURN_QA)
self.db.execute(TABLE_USER)
def get_history(self, _chat_id):
query = f'SELECT question,answer FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,))
ans = self.db.fetchall()
return ans
def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last):
query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s)'
self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last))
def insert_c_user(self, account, password):
query = f'INSERT INTO c_user(account, password) VALUES (%s,%s)'
self.db.execute_args(query, (account, password))
def insert_chat(self, user_id, info, deleted):
query = f'INSERT INTO chat(user_id, info, deleted) VALUES (%s,%s,%s)'
self.db.execute_args(query, (user_id, info, deleted))
def update_last(self, chat_id):
query = f'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND is_last = 1'
self.db.execute_args(query, (chat_id,))
from .c_db import UPostgresDB
TABLE_CHAT = """
TABLE_TURN_QA = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
turn_id varchar(1000) PRIMARY KEY,
......@@ -23,15 +21,11 @@ COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
class TurnQa:
def __init__(self, db: UPostgresDB) -> None:
self.db = db
# 插入数据
def insert(self, value):
query = f"INSERT INTO turn_qa(turn_id, chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s,%s)"
self.db.execute_args(query, (value[0], value[1], value[2], value[3], value[4], value[5]))
# 创建表
def create_table(self):
query = TABLE_CHAT
self.db.execute(query)
def __init__(self, _turn_id, _chat_id, _question, _answer, _create_time, _turn_number, _is_last) -> None:
self.turn_id = _turn_id
self.chat_id = _chat_id
self.question = _question
self.answer = _answer
self.create_time = _create_time
self.turn_number = _turn_number
self.is_last = _is_last
\ No newline at end of file
......@@ -63,11 +63,13 @@ class RE_FAISS(FAISS):
@staticmethod
def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
deduplicated_dict = OrderedDict()
print("--------------oedereddict type--------------", type(deduplicated_dict))
for doc, scores in tuple_input:
page_content = doc.page_content
metadata = doc.metadata
if page_content not in deduplicated_dict:
deduplicated_dict[page_content] = (metadata, scores)
print("--------------------------du--------------------------\n", deduplicated_dict)
deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in
deduplicated_dict.items()]
return deduplicated_documents
......
# -*- coding: utf-8 -*-
import sys
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from typing import Awaitable
import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD
)
sys.path.append("../..")
prompt1 = """'''
......@@ -19,22 +26,33 @@ prompt1 = """'''
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
def chat(context, question):
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
erniellm = LLMChain(llm=base_llm, prompt=PROMPT1, llm_kwargs={"temperature": 0.9})
if not context and not question:
return ""
result = erniellm.run({"context": context, "question": question})
return result
class QA:
def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id):
self.prompt = _prompt
self.base_llm = _base_llm
self.llm_kwargs = _llm_kwargs
self.prompt_kwargs = _prompt_kwargs
self.db = _db
self.chat_id = _chat_id
self.crud = CRUD(self.db)
self.history = self.crud.get_history(self.chat_id)
self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
self.cur_answer = ""
self.cur_question = ""
# 一次性直接给出所有的答案
def chat(self, *args):
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)})
self.cur_answer = ""
if not args:
return ""
self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, args)})
return self.cur_answer
async def async_chat_stc(context, question):
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
erniellm = LLMChain(llm=base_llm, prompt=PROMPT1, llm_kwargs={"temperature": 0.9})
# 异步输出,逐渐输出答案
async def async_chat_stc(self, *args):
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)})
callback = AsyncIteratorCallbackHandler()
async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
await fn
......@@ -44,17 +62,32 @@ async def async_chat_stc(context, question):
print(f"Caught exception: {e}")
finally:
event.set()
task = asyncio.create_task(
wrap_done(erniellm.arun({"context": context, "question": question}, callbacks=[callback]), callback.done))
print("*" * 20)
text = ""
wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, args)}, callbacks=[callback]),
callback.done))
self.cur_answer = ""
async for token in callback.aiter():
text = text + token
yield f"{text}"
self.cur_answer = self.cur_answer + token
yield f"{self.cur_answer}"
await task
def get_history(self):
return self.history
def updata_history(self):
self.history.append((self.cur_question, self.cur_answer))
self.crud.update_last(chat_id=self.chat_id)
self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_question, answer=self.cur_answer,
turn_number=len(self.history), is_last=1)
if __name__ == "__main__":
print("main函数begin")
print(chat("当别人想你说你好的时候,你也应该说你好", "你好"))
# 数据库连接
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2')
print(my_chat.chat("当别人想你说你好的时候,你也应该说你好", "你好"))
my_chat.updata_history()
import sys
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa
from src.pgdb.chat.crud import CRUD
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD
)
sys.path.append("../")
"""测试会话相关数据可的连接"""
def test():
c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432)
chat = Chat(db=c_db)
c_user = CUser(db=c_db)
turn_qa = TurnQa(db=c_db)
chat.create_table()
c_user.create_table()
turn_qa.create_table()
# chat_id, user_id, info, deleted
chat.insert(["3333", "1111", "没有info", 0])
# user_id, account, password
c_user.insert(["111", "zhangsan", "111111"])
# turn_id, chat_id, question, answer, turn_number, is_last
turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
crud = CRUD(c_db)
crud.create_table()
crud.insert_turn_qa("2", "wen4", "da1", 1, 0)
crud.insert_turn_qa("2", "wen4", "da1", 2, 0)
crud.insert_turn_qa("2", "wen4", "da1", 5, 0)
crud.insert_turn_qa("2", "wen4", "da1", 4, 0)
crud.insert_turn_qa("2", "wen4", "da1", 3, 0)
crud.insert_turn_qa("2", "wen4", "da1", 6, 0)
crud.insert_turn_qa("2", "wen4", "da1", 8, 0)
crud.insert_turn_qa("2", "wen4", "da1", 7, 0)
crud.insert_turn_qa("2", "wen4", "da1", 9, 0)
print(crud.get_history('2'))
if __name__ == "__main__":
......
import gradio as gr
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">辅助生成知识库</h1>""")
# with gr.Row():
# input_text = gr.Textbox(show_label=True, placeholder="输入需要处理的文档...", lines=10)
with gr.Row():
input_text = gr.Textbox(show_label=True, placeholder="输入需要处理的文档...", lines=10, scale=9)
model_selector = gr.Dropdown(choices=["ernie", "chatglm3"], label="请选择一个模型", scale=1, min_width=50,
value="chatglm3")
with gr.Row():
num_selector = gr.Slider(minimum=0, maximum=10, value=5, label="请选择问题数量", step=1)
with gr.Row():
qaBtn = gr.Button("QA问答对生成")
demo.queue().launch(share=False, inbrowser=True,server_name="192.168.100.76",server_port=8888)
\ No newline at end of file
from functools import reduce
def add_three(x,y):
return x + y
li = [1,2,3,5]
reduce(add_three, li)#=> 11
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment