Commit 4f4401f4 by 文靖昊

修改接口,使用fastapi替换flask

parent 7aad8ddc
...@@ -27,3 +27,17 @@ class ChatDeleteRequest(BaseModel): ...@@ -27,3 +27,17 @@ class ChatDeleteRequest(BaseModel):
class ChatReQA(BaseModel): class ChatReQA(BaseModel):
chat_id: str chat_id: str
query: str query: str
class PhoneLoginRequest(BaseModel):
phone: str
class ChatRequest(BaseModel):
sessionID: str = ""
question: str
class ReGenerateRequest(BaseModel):
sessionID: str
question: str
import sys import sys
sys.path.append('../') sys.path.append('../')
import flask from fastapi import FastAPI, Header
from flask import request from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime,timedelta from datetime import datetime,timedelta
from src.pgdb.chat.c_db import UPostgresDB from src.pgdb.chat.c_db import UPostgresDB
from flask_cors import CORS import uvicorn
import json import json
from src.pgdb.chat.crud import CRUD from src.pgdb.chat.crud import CRUD
from src.pgdb.knowledge.similarity import VectorStore_FAISS from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.qa import QA from src.server.qa import QA
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from src.controller.request import (
PhoneLoginRequest,
ChatRequest,
ReGenerateRequest
)
from src.config.consts import ( from src.config.consts import (
CHAT_DB_USER, CHAT_DB_USER,
CHAT_DB_HOST, CHAT_DB_HOST,
...@@ -28,8 +33,14 @@ from src.config.consts import ( ...@@ -28,8 +33,14 @@ from src.config.consts import (
SIMILARITY_SHOW_NUMBER, SIMILARITY_SHOW_NUMBER,
prompt1 prompt1
) )
app=flask.Flask(__name__) app = FastAPI()
CORS(app, supports_credentials=True) app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有域访问,也可以指定特定域名
allow_credentials=True,
allow_methods=["*"], # 允许所有HTTP方法
allow_headers=["*"], # 允许所有HTTP头
)
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD, c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, ) port=CHAT_DB_PORT, )
...@@ -53,11 +64,9 @@ base_llm = ChatOpenAI( ...@@ -53,11 +64,9 @@ base_llm = ChatOpenAI(
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm, my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss) {"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss)
@app.route('/api/login', methods=['POST']) @app.post('/api/login')
def login(): def login(phone_request: PhoneLoginRequest):
raw_data = request.data phone = phone_request.phone
data = json.loads(raw_data)
phone = data["phone"]
crud = CRUD(_db=c_db) crud = CRUD(_db=c_db)
user = crud.user_exist_account(phone) user = crud.user_exist_account(phone)
if not user: if not user:
...@@ -78,9 +87,8 @@ def login(): ...@@ -78,9 +87,8 @@ def login():
} }
@app.route('/api/sessions/<type>', methods=['GET']) @app.get('/api/sessions/{type}')
def get_sessions(type): def get_sessions(type: int,token: str = Header(None)):
token = request.headers.get('token')
if not token: if not token:
return { return {
'code': 404, 'code': 404,
...@@ -91,17 +99,14 @@ def get_sessions(type): ...@@ -91,17 +99,14 @@ def get_sessions(type):
chat_list_str = [] chat_list_str = []
for chat in chat_list: for chat in chat_list:
chat_list_str.append(str(chat[0])) chat_list_str.append(str(chat[0]))
print(chat_list_str)
print(chat_list)
return { return {
'code': 200, 'code': 200,
'data': chat_list_str 'data': chat_list_str
} }
@app.route('/api/session/<session_id>', methods=['GET']) @app.get('/api/session/{session_id}')
def get_history_by_session_id(session_id): def get_history_by_session_id(session_id:str,token: str = Header(None)):
token = request.headers.get('token')
if not token: if not token:
return { return {
'code': 404, 'code': 404,
...@@ -124,9 +129,8 @@ def get_history_by_session_id(session_id): ...@@ -124,9 +129,8 @@ def get_history_by_session_id(session_id):
'data': history_str 'data': history_str
} }
@app.route('/api/session/<session_id>', methods=['DELETE']) @app.delete('/api/session/{session_id}')
def delete_session_by_session_id(session_id): def delete_session_by_session_id(session_id:str,token: str = Header(None)):
token = request.headers.get('token')
if not token: if not token:
return { return {
'code': 404, 'code': 404,
...@@ -140,32 +144,29 @@ def delete_session_by_session_id(session_id): ...@@ -140,32 +144,29 @@ def delete_session_by_session_id(session_id):
} }
@app.route('/api/general/chat', methods=['POST']) @app.post('/api/general/chat')
def question(): def question(chat_request: ChatRequest, token: str = Header(None)):
token = request.headers.get('token')
if not token: if not token:
return { return {
'code': 404, 'code': 404,
'data': '验证失败' 'data': '验证失败'
} }
raw_data = request.data session_id = chat_request.sessionID
data = json.loads(raw_data) question = chat_request.question
if "sessionID" in data:
session_id = data["sessionID"]
else:
session_id = ""
print(session_id)
question = data["question"]
crud = CRUD(_db=c_db) crud = CRUD(_db=c_db)
history = [] history = []
if session_id != "": if session_id =="":
history = crud.get_last_history(str(session_id)) history = crud.get_last_history(str(session_id))
print(history)
# answer = my_chat.chat(question) # answer = my_chat.chat(question)
answer, docs = my_chat.chat(question, with_similarity=True) answer, docs = my_chat.chat(question, with_similarity=True)
print(docs) docs_json = []
for d in docs:
j ={}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
docs_json.append(j)
# answer = "test Answer" # answer = "test Answer"
if session_id == "": if session_id =="":
session_id = crud.create_chat(token, '\t\t', '0') session_id = crud.create_chat(token, '\t\t', '0')
crud.insert_turn_qa(session_id, question, answer, 0, 1) crud.insert_turn_qa(session_id, question, answer, 0, 1)
else: else:
...@@ -177,44 +178,46 @@ def question(): ...@@ -177,44 +178,46 @@ def question():
'data': { 'data': {
'question': question, 'question': question,
'answer': answer, 'answer': answer,
'sessionID': session_id 'sessionID': session_id,
'similarity': docs_json
} }
} }
@app.route('/api/general/regenerate', methods=['POST'])
def re_generate(): @app.post('/api/general/regenerate')
token = request.headers.get('token') def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
if not token: if not token:
return { return {
'code': 404, 'code': 404,
'data': '验证失败' 'data': '验证失败'
} }
raw_data = request.data session_id = chat_request.sessionID
data = json.loads(raw_data) question = chat_request.question
question = data["question"]
crud = CRUD(_db=c_db) crud = CRUD(_db=c_db)
history = []
if "sessionID" in data:
session_id = data["sessionID"]
else:
session_id = ""
if session_id != "":
history = crud.get_last_history(str(session_id))
print(history)
answer = my_chat.chat(question)
# answer = "reGenerate Answer"
last_turn_id = crud.get_last_turn_num(str(session_id)) last_turn_id = crud.get_last_turn_num(str(session_id))
history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id)
answer, docs = my_chat.chat(question, with_similarity=True)
docs_json = []
for d in docs:
j = {}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
docs_json.append(j)
# answer = "reGenerate Answer"
crud.update_turn_last(str(session_id), last_turn_id ) crud.update_turn_last(str(session_id), last_turn_id )
crud.insert_turn_qa(session_id, question, answer, last_turn_id , 1) crud.insert_turn_qa(session_id, question, answer, last_turn_id , 1)
return { return {
'code': 200, 'code': 200,
'data': { 'data': {
'question': question, 'question': question,
'answer': answer, 'answer': answer,
'sessionID': session_id 'sessionID': session_id,
'similarity': docs_json
} }
} }
if __name__ == "__main__": if __name__ == "__main__":
app.run(host="0.0.0.0",port=8088,debug=False) uvicorn.run(app, host='0.0.0.0', port=8088)
\ No newline at end of file
...@@ -97,6 +97,12 @@ class CRUD: ...@@ -97,6 +97,12 @@ class CRUD:
ans = self.db.fetchall() ans = self.db.fetchall()
return ans return ans
def get_last_history_before_turn_id(self, _chat_id,turn_id):
query = f'SELECT question,answer FROM turn_qa WHERE chat_id=(%s) and is_last=1 and turn_number<(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,turn_id))
ans = self.db.fetchall()
return ans
def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last): def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last):
query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s)' query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s)'
self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last)) self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment