Commit 4f4401f4 by 文靖昊

修改接口,使用fastapi替换flask

parent 7aad8ddc
......@@ -27,3 +27,17 @@ class ChatDeleteRequest(BaseModel):
class ChatReQA(BaseModel):
chat_id: str
query: str
class PhoneLoginRequest(BaseModel):
phone: str
class ChatRequest(BaseModel):
sessionID: str = ""
question: str
class ReGenerateRequest(BaseModel):
sessionID: str
question: str
import sys
sys.path.append('../')
import flask
from flask import request
from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime,timedelta
from src.pgdb.chat.c_db import UPostgresDB
from flask_cors import CORS
import uvicorn
import json
from src.pgdb.chat.crud import CRUD
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.qa import QA
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from src.controller.request import (
PhoneLoginRequest,
ChatRequest,
ReGenerateRequest
)
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
......@@ -28,8 +33,14 @@ from src.config.consts import (
SIMILARITY_SHOW_NUMBER,
prompt1
)
app=flask.Flask(__name__)
CORS(app, supports_credentials=True)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有域访问,也可以指定特定域名
allow_credentials=True,
allow_methods=["*"], # 允许所有HTTP方法
allow_headers=["*"], # 允许所有HTTP头
)
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
......@@ -53,11 +64,9 @@ base_llm = ChatOpenAI(
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss)
@app.route('/api/login', methods=['POST'])
def login():
raw_data = request.data
data = json.loads(raw_data)
phone = data["phone"]
@app.post('/api/login')
def login(phone_request: PhoneLoginRequest):
phone = phone_request.phone
crud = CRUD(_db=c_db)
user = crud.user_exist_account(phone)
if not user:
......@@ -78,9 +87,8 @@ def login():
}
@app.route('/api/sessions/<type>', methods=['GET'])
def get_sessions(type):
token = request.headers.get('token')
@app.get('/api/sessions/{type}')
def get_sessions(type: int,token: str = Header(None)):
if not token:
return {
'code': 404,
......@@ -91,17 +99,14 @@ def get_sessions(type):
chat_list_str = []
for chat in chat_list:
chat_list_str.append(str(chat[0]))
print(chat_list_str)
print(chat_list)
return {
'code': 200,
'data': chat_list_str
}
@app.route('/api/session/<session_id>', methods=['GET'])
def get_history_by_session_id(session_id):
token = request.headers.get('token')
@app.get('/api/session/{session_id}')
def get_history_by_session_id(session_id:str,token: str = Header(None)):
if not token:
return {
'code': 404,
......@@ -124,9 +129,8 @@ def get_history_by_session_id(session_id):
'data': history_str
}
@app.route('/api/session/<session_id>', methods=['DELETE'])
def delete_session_by_session_id(session_id):
token = request.headers.get('token')
@app.delete('/api/session/{session_id}')
def delete_session_by_session_id(session_id:str,token: str = Header(None)):
if not token:
return {
'code': 404,
......@@ -140,32 +144,29 @@ def delete_session_by_session_id(session_id):
}
@app.route('/api/general/chat', methods=['POST'])
def question():
token = request.headers.get('token')
@app.post('/api/general/chat')
def question(chat_request: ChatRequest, token: str = Header(None)):
if not token:
return {
'code': 404,
'data': '验证失败'
}
raw_data = request.data
data = json.loads(raw_data)
if "sessionID" in data:
session_id = data["sessionID"]
else:
session_id = ""
print(session_id)
question = data["question"]
session_id = chat_request.sessionID
question = chat_request.question
crud = CRUD(_db=c_db)
history = []
if session_id != "":
if session_id =="":
history = crud.get_last_history(str(session_id))
print(history)
# answer = my_chat.chat(question)
answer, docs = my_chat.chat(question, with_similarity=True)
print(docs)
docs_json = []
for d in docs:
j ={}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
docs_json.append(j)
# answer = "test Answer"
if session_id == "":
if session_id =="":
session_id = crud.create_chat(token, '\t\t', '0')
crud.insert_turn_qa(session_id, question, answer, 0, 1)
else:
......@@ -177,44 +178,46 @@ def question():
'data': {
'question': question,
'answer': answer,
'sessionID': session_id
'sessionID': session_id,
'similarity': docs_json
}
}
@app.route('/api/general/regenerate', methods=['POST'])
def re_generate():
token = request.headers.get('token')
@app.post('/api/general/regenerate')
def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
if not token:
return {
'code': 404,
'data': '验证失败'
}
raw_data = request.data
data = json.loads(raw_data)
question = data["question"]
session_id = chat_request.sessionID
question = chat_request.question
crud = CRUD(_db=c_db)
history = []
if "sessionID" in data:
session_id = data["sessionID"]
else:
session_id = ""
if session_id != "":
history = crud.get_last_history(str(session_id))
print(history)
answer = my_chat.chat(question)
# answer = "reGenerate Answer"
last_turn_id = crud.get_last_turn_num(str(session_id))
history = crud.get_last_history_before_turn_id(str(session_id),last_turn_id)
answer, docs = my_chat.chat(question, with_similarity=True)
docs_json = []
for d in docs:
j = {}
j["page_content"] = d.page_content
j["from_file"] = d.metadata["filename"]
docs_json.append(j)
# answer = "reGenerate Answer"
crud.update_turn_last(str(session_id), last_turn_id )
crud.insert_turn_qa(session_id, question, answer, last_turn_id , 1)
return {
'code': 200,
'data': {
'question': question,
'answer': answer,
'sessionID': session_id
'sessionID': session_id,
'similarity': docs_json
}
}
if __name__ == "__main__":
app.run(host="0.0.0.0",port=8088,debug=False)
\ No newline at end of file
uvicorn.run(app, host='0.0.0.0', port=8088)
......@@ -97,6 +97,12 @@ class CRUD:
ans = self.db.fetchall()
return ans
def get_last_history_before_turn_id(self, _chat_id,turn_id):
query = f'SELECT question,answer FROM turn_qa WHERE chat_id=(%s) and is_last=1 and turn_number<(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,turn_id))
ans = self.db.fetchall()
return ans
def insert_turn_qa(self, chat_id, question, answer, turn_number, is_last):
query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s)'
self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment