Commit b0dfbbcf by 陈正乐

restful接口初步实现

parent 54455eae
from fastapi import FastAPI, Header
from src.pgdb.chat.c_db import UPostgresDB
from fastapi.middleware.cors import CORSMiddleware
from src.controller.response import Response
from src.controller.return_data import ReturnData
from src.pgdb.chat.crud import CRUD
import uvicorn
from src.server.qa import QA
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from langchain.prompts import PromptTemplate
from src.controller.request import (
RegisterRequest,
LoginRequest,
ChatCreateRequest,
ChatQaRequest,
ChatDetailRequest,
ChatDeleteRequest
)
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
INDEX_NAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER,
prompt1
)
app = FastAPI()
# 数据库连接
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
c_db.connect()
# 添加 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有域访问,也可以指定特定域名
allow_credentials=True,
allow_methods=["*"], # 允许所有HTTP方法
allow_headers=["*"], # 允许所有HTTP头
)
@app.post("/lae/user/register", response_model=Response)
async def register(request: RegisterRequest):
print(request)
account = request.account
password = request.password
crud = CRUD(_db=c_db)
if crud.user_exist_account(_account=account):
return ReturnData(40000, "当前用户已注册", {}).get_data()
crud.insert_c_user(account=account, password=password)
data = {
"account": account,
}
return ReturnData(200, "用户注册成功", dict(data)).get_data()
@app.get("/lae/user/login", response_model=Response)
async def login(request: LoginRequest):
print(request)
account = request.account
password = request.password
crud = CRUD(_db=c_db)
user_id = crud.user_exist_account_password(_account=account, _password=password)
if not user_id:
return ReturnData(40000, "用户未注册或密码错误", {}).get_data()
token = user_id + '******'
data = {
"account": account,
"token": token
}
return ReturnData(200, "用户登录成功", dict(data)).get_data()
@app.post("/lae/chat/create", response_model=Response)
async def create(request: ChatCreateRequest, token: str = Header(None)):
print(request)
print(token)
user_id = token.replace('*', '')
crud = CRUD(_db=c_db)
if not crud.user_exist_id(_user_id=user_id):
return ReturnData(40000, "当前用户暂未注册", {}).get_data()
chat_info = '这是该chat的info'
crud.insert_chat(user_id=user_id, info='这是该chat的info', deleted=0)
data = {
"user_id": user_id,
"chat_info": chat_info,
}
return ReturnData(200, '会话创建成功', dict(data)).get_data()
@app.post("/lae/chat/delete", response_model=Response)
async def delete(request: ChatDeleteRequest, token: str = Header(None)):
print(request)
print(token)
user_id = token.replace('*', '')
chat_id = request.chat_id
crud = CRUD(_db=c_db)
if crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
chat_info = crud.get_chatinfo_from_chatid(chat_id)
crud.delete_chat(chat_id)
data = {
"user_id": user_id,
"chat_info": chat_info,
}
return ReturnData(200, '会话删除成功', dict(data)).get_data()
@app.post("/lae/chat/qa", response_model=Response)
async def qa(request: ChatQaRequest, token: str = Header(None)):
print(request)
print(token)
user_id = token.replace('*', '')
chat_id = request.chat_id
question = request.query
crud = CRUD(_db=c_db)
if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id=chat_id,
_faiss_db=vecstore_faiss)
answer = my_chat.chat(question)
my_chat.update_history()
data = {
"answer": answer
}
return ReturnData(200, '会话创建成功', dict(data)).get_data()
@app.get("/lae/chat/detail", response_model=Response)
async def detail(request: ChatDetailRequest, token: str = Header(None)):
print(request)
print(token)
user_id = token.replace('*', '')
chat_id = request.chat_id
crud = CRUD(_db=c_db)
if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
history = crud.get_history(chat_id)
data = {
"chat_id": chat_id,
"history": history
}
return ReturnData(200, "会话详情获取成功", dict(data)).get_data()
@app.post("/lae/chat/clist", response_model=Response)
async def clist(token: str = Header(None)):
print(token)
user_id = token.replace('*', '')
crud = CRUD(_db=c_db)
chat_list = crud.get_chat_list_userid(user_id)
data = {
"chat_list": chat_list
}
return ReturnData(200, "会话列表获取成功", dict(data)).get_data()
if __name__ == "__main__":
uvicorn.run(app, host='localhost', port=8889)
from pydantic import BaseModel
class LoginRequest(BaseModel):
account: str
password: str
class RegisterRequest(BaseModel):
account: str
password: str
class ChatCreateRequest(BaseModel):
id: str
class ChatQaRequest(BaseModel):
chat_id: str = None
query: str
class ChatDetailRequest(BaseModel):
chat_id: str
class ChatDeleteRequest(BaseModel):
chat_id: str
from pydantic import BaseModel
class Response(BaseModel):
code: int
message: str
data: dict
\ No newline at end of file
class ReturnData:
def __init__(self, code: int, message: str, data: dict) -> None:
self.code = code
self.message = message
self.data = data
def get_data(self):
return {
"code": self.code,
"message": self.message,
"data": self.data
}
......@@ -106,3 +106,38 @@ class CRUD:
def update_last(self, chat_id):
query = f'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND is_last = 1'
self.db.execute_args(query, (chat_id,))
def user_exist_id(self, _user_id):
query = f'SELECT * FROM c_user WHERE user_id = (%s)'
self.db.execute_args(query, (_user_id,))
return self.db.fetchone()
def user_exist_account(self, _account):
query = f'SELECT * FROM c_user WHERE account = (%s)'
self.db.execute_args(query, (_account,))
return self.db.fetchone()
def user_exist_account_password(self, _account, _password):
query = f'SELECT user_id FROM c_user WHERE account = (%s) AND password = (%s)'
self.db.execute_args(query, (_account, _password))
return self.db.fetchone()
def chat_exist_chatid_userid(self, _chat_id, _user_id):
query = f'SELECT * FROM chat WHERE chat_id = (%s) AND user_id = (%s)'
self.db.execute_args(query, (_chat_id, _user_id))
return self.db.fetchone()
def get_chat_list_userid(self, _user_id):
query = f'SELECT info FROM chat WHERE user_id = (%s) AND deleted = 0 order by create_time desc'
self.db.execute_args(query, (_user_id,))
return self.db.fetchall()
def get_chatinfo_from_chatid(self, _chat_id):
query = f'SELECT info FROM chat WHERE chat_id = (%s)'
self.db.execute_args(query, (_chat_id,))
return self.db.fetchone()
def delete_chat(self, _chat_id):
query = f'UPDATE chat SET deleted = 1 WHERE chat_id = (%s)'
self.db.execute_args(query, (_chat_id,))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment