Commit 8ddfdec2 by 陈正乐

实现基于面向对象的模型QA服务

parent 8efec361
...@@ -15,13 +15,7 @@ COMMENT ON TABLE "c_user" IS '用户表'; ...@@ -15,13 +15,7 @@ COMMENT ON TABLE "c_user" IS '用户表';
class CUser: class CUser:
def __init__(self, db: UPostgresDB) -> None: def __init__(self, _user_id, _account, _password) -> None:
self.db = db self.user_id = _user_id
self.account = _account
def insert(self, value): self.password = _password
query = f"INSERT INTO c_user(user_id, account, password) VALUES (%s,%s,%s)"
self.db.execute_args(query, (value[0], value[1], value[2]))
def create_table(self):
query = TABLE_USER
self.db.execute(query)
...@@ -19,15 +19,9 @@ COMMENT ON TABLE "chat" IS '会话信息表'; ...@@ -19,15 +19,9 @@ COMMENT ON TABLE "chat" IS '会话信息表';
class Chat: class Chat:
def __init__(self, db: UPostgresDB) -> None: def __init__(self, _chat_id, _user_id, _info, _create_time, _deleted) -> None:
self.db = db self.chat_id = _chat_id
self.user_id = _user_id
# 插入数据 self.info = _info
def insert(self, value): self.create_time = _create_time
query = f"INSERT INTO chat(chat_id, user_id, info, deleted) VALUES (%s,%s,%s,%s)" self.deleted = _deleted
self.db.execute_args(query, (value[0], value[1], value[2], value[3]))
# 创建表
def create_table(self):
query = TABLE_CHAT
self.db.execute(query)
from .c_db import UPostgresDB
TABLE_CHAT = """
DROP TABLE IF EXISTS "chat";
CREATE TABLE chat (
chat_id varchar(1000) PRIMARY KEY,
user_id int,
info text,
create_time timestamp(6) DEFAULT current_timestamp,
deleted int2
);
COMMENT ON COLUMN "chat"."chat_id" IS '会话id';
COMMENT ON COLUMN "chat"."user_id" IS '会话创建用户id';
COMMENT ON COLUMN "chat"."info" IS '会话简介';
COMMENT ON COLUMN "chat"."create_time" IS '会话创建时间,默认为当前时间';
COMMENT ON COLUMN "chat"."deleted" IS '是否删除:0=否,1=是';
COMMENT ON TABLE "chat" IS '会话信息表';
DROP SEQUENCE IF EXISTS "chat_seq";
CREATE SEQUENCE "chat_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE chat ALTER COLUMN chat_id SET DEFAULT nextval('chat_seq'::regclass);
"""
TABLE_TURN_QA = """
DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa (
turn_id varchar(1000) PRIMARY KEY,
chat_id varchar(1000),
question text,
answer text,
create_time timestamp(6) DEFAULT current_timestamp,
turn_number int,
is_last int2
);
COMMENT ON COLUMN "turn_qa"."turn_id" IS '会话轮次id';
COMMENT ON COLUMN "turn_qa"."chat_id" IS '会话id';
COMMENT ON COLUMN "turn_qa"."question" IS '该轮会话问题';
COMMENT ON COLUMN "turn_qa"."answer" IS '该轮会话答案';
COMMENT ON COLUMN "turn_qa"."create_time" IS '该轮会话创建时间,默认为当前时间';
COMMENT ON COLUMN "turn_qa"."turn_number" IS '会话轮数';
COMMENT ON COLUMN "turn_qa"."is_last" IS '是否为最后一轮对话:0=否,1=是';
COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
DROP SEQUENCE IF EXISTS "turn_qa_seq";
CREATE SEQUENCE "turn_qa_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE turn_qa ALTER COLUMN turn_id SET DEFAULT nextval('turn_qa_seq'::regclass);
"""
TABLE_USER = """
DROP TABLE IF EXISTS "c_user";
CREATE TABLE c_user (
user_id varchar(1000) PRIMARY KEY,
account varchar(20) NOT NULL,
password varchar(50) NOT NULL
);
COMMENT ON COLUMN "c_user"."user_id" IS '用户id';
COMMENT ON COLUMN "c_user"."account" IS '用户帐户';
COMMENT ON COLUMN "c_user"."password" IS '用户密码';
COMMENT ON TABLE "c_user" IS '用户表';
DROP SEQUENCE IF EXISTS "c_user_seq";
CREATE SEQUENCE "c_user_seq"
INCREMENT 1
MINVALUE 1
MAXVALUE 2147483647
START 1
CACHE 1;
ALTER TABLE c_user ALTER COLUMN user_id SET DEFAULT nextval('c_user_seq'::regclass);
"""
class CRUD:
def __init__(self, _db: UPostgresDB):
self.db = _db
def create_table(self):
self.db.execute(TABLE_CHAT)
self.db.execute(TABLE_TURN_QA)
self.db.execute(TABLE_USER)
def get_history(self, _chat_id):
query = f'SELECT question,answer FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,))
ans = self.db.fetchall()
return ans
def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last):
query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s)'
self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last))
def insert_c_user(self, account, password):
query = f'INSERT INTO c_user(account, password) VALUES (%s,%s)'
self.db.execute_args(query, (account, password))
def insert_chat(self, user_id, info, deleted):
query = f'INSERT INTO chat(user_id, info, deleted) VALUES (%s,%s,%s)'
self.db.execute_args(query, (user_id, info, deleted))
def update_last(self, chat_id):
query = f'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND is_last = 1'
self.db.execute_args(query, (chat_id,))
from .c_db import UPostgresDB TABLE_TURN_QA = """
TABLE_CHAT = """
DROP TABLE IF EXISTS "turn_qa"; DROP TABLE IF EXISTS "turn_qa";
CREATE TABLE turn_qa ( CREATE TABLE turn_qa (
turn_id varchar(1000) PRIMARY KEY, turn_id varchar(1000) PRIMARY KEY,
...@@ -23,15 +21,11 @@ COMMENT ON TABLE "turn_qa" IS '会话轮次信息表'; ...@@ -23,15 +21,11 @@ COMMENT ON TABLE "turn_qa" IS '会话轮次信息表';
class TurnQa: class TurnQa:
def __init__(self, db: UPostgresDB) -> None: def __init__(self, _turn_id, _chat_id, _question, _answer, _create_time, _turn_number, _is_last) -> None:
self.db = db self.turn_id = _turn_id
self.chat_id = _chat_id
# 插入数据 self.question = _question
def insert(self, value): self.answer = _answer
query = f"INSERT INTO turn_qa(turn_id, chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s,%s)" self.create_time = _create_time
self.db.execute_args(query, (value[0], value[1], value[2], value[3], value[4], value[5])) self.turn_number = _turn_number
self.is_last = _is_last
# 创建表 \ No newline at end of file
def create_table(self):
query = TABLE_CHAT
self.db.execute(query)
...@@ -63,11 +63,13 @@ class RE_FAISS(FAISS): ...@@ -63,11 +63,13 @@ class RE_FAISS(FAISS):
@staticmethod @staticmethod
def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]: def _tuple_deduplication(tuple_input: List[Tuple[Document, float]]) -> List[Tuple[Document, float]]:
deduplicated_dict = OrderedDict() deduplicated_dict = OrderedDict()
print("--------------oedereddict type--------------", type(deduplicated_dict))
for doc, scores in tuple_input: for doc, scores in tuple_input:
page_content = doc.page_content page_content = doc.page_content
metadata = doc.metadata metadata = doc.metadata
if page_content not in deduplicated_dict: if page_content not in deduplicated_dict:
deduplicated_dict[page_content] = (metadata, scores) deduplicated_dict[page_content] = (metadata, scores)
print("--------------------------du--------------------------\n", deduplicated_dict)
deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in deduplicated_documents = [(Document(page_content=key, metadata=value[0]), value[1]) for key, value in
deduplicated_dict.items()] deduplicated_dict.items()]
return deduplicated_documents return deduplicated_documents
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys import sys
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from typing import Awaitable from typing import Awaitable
import asyncio import asyncio
from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.callbacks import AsyncIteratorCallbackHandler
from src.llm.ernie_with_sdk import ChatERNIESerLLM from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion from qianfan import ChatCompletion
from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD
)
sys.path.append("../..") sys.path.append("../..")
prompt1 = """''' prompt1 = """'''
...@@ -19,42 +26,68 @@ prompt1 = """''' ...@@ -19,42 +26,68 @@ prompt1 = """'''
PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1) PROMPT1 = PromptTemplate(input_variables=["context", "question"], template=prompt1)
def chat(context, question): class QA:
base_llm = ChatERNIESerLLM( def __init__(self, _prompt, _base_llm, _llm_kwargs, _prompt_kwargs, _db, _chat_id):
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")) self.prompt = _prompt
erniellm = LLMChain(llm=base_llm, prompt=PROMPT1, llm_kwargs={"temperature": 0.9}) self.base_llm = _base_llm
if not context and not question: self.llm_kwargs = _llm_kwargs
return "" self.prompt_kwargs = _prompt_kwargs
result = erniellm.run({"context": context, "question": question}) self.db = _db
return result self.chat_id = _chat_id
self.crud = CRUD(self.db)
self.history = self.crud.get_history(self.chat_id)
self.llm = LLMChain(llm=self.base_llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
self.cur_answer = ""
self.cur_question = ""
# 一次性直接给出所有的答案
def chat(self, *args):
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)})
self.cur_answer = ""
if not args:
return ""
self.cur_answer = self.llm.run({k: v for k, v in zip(self.prompt_kwargs, args)})
return self.cur_answer
async def async_chat_stc(context, question): # 异步输出,逐渐输出答案
base_llm = ChatERNIESerLLM( async def async_chat_stc(self, *args):
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")) self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, args)})
erniellm = LLMChain(llm=base_llm, prompt=PROMPT1, llm_kwargs={"temperature": 0.9}) callback = AsyncIteratorCallbackHandler()
callback = AsyncIteratorCallbackHandler() async def wrap_done(fn: Awaitable, event: asyncio.Event):
try:
async def wrap_done(fn: Awaitable, event: asyncio.Event): await fn
try: except Exception as e:
await fn import traceback
except Exception as e: traceback.print_exc()
import traceback print(f"Caught exception: {e}")
traceback.print_exc() finally:
print(f"Caught exception: {e}") event.set()
finally: task = asyncio.create_task(
event.set() wrap_done(self.llm.arun({k: v for k, v in zip(self.prompt_kwargs, args)}, callbacks=[callback]),
callback.done))
task = asyncio.create_task( self.cur_answer = ""
wrap_done(erniellm.arun({"context": context, "question": question}, callbacks=[callback]), callback.done)) async for token in callback.aiter():
print("*" * 20) self.cur_answer = self.cur_answer + token
text = "" yield f"{self.cur_answer}"
async for token in callback.aiter(): await task
text = text + token
yield f"{text}"
await task def get_history(self):
return self.history
def updata_history(self):
self.history.append((self.cur_question, self.cur_answer))
self.crud.update_last(chat_id=self.chat_id)
self.crud.insert_turn_qa(chat_id=self.chat_id, question=self.cur_question, answer=self.cur_answer,
turn_number=len(self.history), is_last=1)
if __name__ == "__main__": if __name__ == "__main__":
print("main函数begin") # 数据库连接
print(chat("当别人想你说你好的时候,你也应该说你好", "你好")) c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2')
print(my_chat.chat("当别人想你说你好的时候,你也应该说你好", "你好"))
my_chat.updata_history()
import sys import sys
from src.pgdb.chat.c_db import UPostgresDB from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.chat_table import Chat from src.pgdb.chat.chat_table import Chat
from src.pgdb.chat.c_user_table import CUser from src.pgdb.chat.c_user_table import CUser
from src.pgdb.chat.turn_qa_table import TurnQa from src.pgdb.chat.turn_qa_table import TurnQa
from src.pgdb.chat.crud import CRUD
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD
)
sys.path.append("../") sys.path.append("../")
"""测试会话相关数据可的连接""" """测试会话相关数据可的连接"""
def test(): def test():
c_db = UPostgresDB(host="localhost", database="laechat", user="postgres", password="chenzl", port=5432) c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
chat = Chat(db=c_db) port=CHAT_DB_PORT, )
c_user = CUser(db=c_db)
turn_qa = TurnQa(db=c_db) crud = CRUD(c_db)
crud.create_table()
chat.create_table() crud.insert_turn_qa("2", "wen4", "da1", 1, 0)
c_user.create_table() crud.insert_turn_qa("2", "wen4", "da1", 2, 0)
turn_qa.create_table() crud.insert_turn_qa("2", "wen4", "da1", 5, 0)
crud.insert_turn_qa("2", "wen4", "da1", 4, 0)
# chat_id, user_id, info, deleted crud.insert_turn_qa("2", "wen4", "da1", 3, 0)
chat.insert(["3333", "1111", "没有info", 0]) crud.insert_turn_qa("2", "wen4", "da1", 6, 0)
crud.insert_turn_qa("2", "wen4", "da1", 8, 0)
# user_id, account, password crud.insert_turn_qa("2", "wen4", "da1", 7, 0)
c_user.insert(["111", "zhangsan", "111111"]) crud.insert_turn_qa("2", "wen4", "da1", 9, 0)
# turn_id, chat_id, question, answer, turn_number, is_last print(crud.get_history('2'))
turn_qa.insert(["222", "1111", "nihao", "nihao", 1, 0])
if __name__ == "__main__": if __name__ == "__main__":
......
import gradio as gr
with gr.Blocks() as demo:
gr.HTML("""<h1 align="center">辅助生成知识库</h1>""")
# with gr.Row():
# input_text = gr.Textbox(show_label=True, placeholder="输入需要处理的文档...", lines=10)
with gr.Row():
input_text = gr.Textbox(show_label=True, placeholder="输入需要处理的文档...", lines=10, scale=9)
model_selector = gr.Dropdown(choices=["ernie", "chatglm3"], label="请选择一个模型", scale=1, min_width=50,
value="chatglm3")
with gr.Row():
num_selector = gr.Slider(minimum=0, maximum=10, value=5, label="请选择问题数量", step=1)
with gr.Row():
qaBtn = gr.Button("QA问答对生成")
demo.queue().launch(share=False, inbrowser=True,server_name="192.168.100.76",server_port=8888)
\ No newline at end of file
from functools import reduce
def add_three(x,y):
return x + y
li = [1,2,3,5]
reduce(add_three, li)#=> 11
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment