Commit 6a841de0 by 文靖昊

适配浦发前端,构建api服务

parent 3a0c2544
import flask
from flask import request
from datetime import datetime,timedelta
from src.pgdb.chat.c_db import UPostgresDB
from flask_cors import CORS
import json
from src.pgdb.chat.crud import CRUD
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.qa import QA
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from src.config.consts import (
CHAT_DB_USER,
CHAT_DB_HOST,
CHAT_DB_PORT,
CHAT_DB_DBNAME,
CHAT_DB_PASSWORD,
EMBEEDING_MODEL_PATH,
FAISS_STORE_PATH,
INDEX_NAME,
VEC_DB_HOST,
VEC_DB_PASSWORD,
VEC_DB_PORT,
VEC_DB_USER,
VEC_DB_DBNAME,
SIMILARITY_SHOW_NUMBER,
prompt1
)
app=flask.Flask(__name__)
CORS(app, supports_credentials=True)
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER, password=CHAT_DB_PASSWORD,
port=CHAT_DB_PORT, )
c_db.connect()
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
openai_api_base='http://192.168.10.14:8000/v1',
model_name='Qwen2-7B',
verbose=True
)
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _faiss_db=vecstore_faiss)
@app.route('/api/login', methods=['POST'])
def login():
raw_data = request.data
data = json.loads(raw_data)
phone = data["phone"]
crud = CRUD(_db=c_db)
user = crud.user_exist_account(phone)
if not user:
crud.insert_c_user(phone,"123456")
user = crud.user_exist_account(phone)
userid = user[0]
expire = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
return{
'code': 200,
'data': {
'accessToken': userid,
'refreshToken': userid,
'accessExpire': expire,
'refreshExpire': expire,
'accessUUID': 'accessUUID',
'refreshUUID': 'refreshUUID',
}
}
@app.route('/api/sessions/<type>', methods=['GET'])
def get_sessions(type):
token = request.headers.get('token')
if not token:
return {
'code': 404,
'data': '验证失败'
}
crud = CRUD(_db=c_db)
chat_list = crud.get_chat_list_userid(token)
chat_list_str = []
for chat in chat_list:
chat_list_str.append(str(chat[0]))
print(chat_list_str)
print(chat_list)
return {
'code': 200,
'data': chat_list_str
}
@app.route('/api/session/<session_id>', methods=['GET'])
def get_history_by_session_id(session_id):
token = request.headers.get('token')
if not token:
return {
'code': 404,
'data': '验证失败'
}
crud = CRUD(_db=c_db)
history = crud.get_history(session_id)
history_json = []
for h in history:
j ={}
j["TurnID"] = h[0]
j["Question"] = h[1]
j["Answer"] = h[2]
j["IsLast"] = h[3]
j["SimilarDocuments"] =[]
history_json.append(j)
history_str = json.dumps(history_json)
return {
'code': 200,
'data': history_str
}
@app.route('/api/session/<session_id>', methods=['DELETE'])
def delete_session_by_session_id(session_id):
token = request.headers.get('token')
if not token:
return {
'code': 404,
'data': '验证失败'
}
crud = CRUD(_db=c_db)
crud.delete_chat(session_id)
return {
'code': 200,
'data': 'success'
}
@app.route('/api/general/chat', methods=['POST'])
def question():
token = request.headers.get('token')
if not token:
return {
'code': 404,
'data': '验证失败'
}
raw_data = request.data
data = json.loads(raw_data)
if "sessionID" in data:
session_id = data["sessionID"]
else:
session_id = ""
print(session_id)
question = data["question"]
crud = CRUD(_db=c_db)
history = []
if session_id != "":
history = crud.get_last_history(str(session_id))
print(history)
answer = my_chat.chat(question)
# answer = "test Answer"
if session_id == "":
session_id = crud.create_chat(token, '\t\t', '0')
crud.insert_turn_qa(session_id, question, answer, 0, 1)
else:
last_turn_id = crud.get_last_turn_num(str(session_id))
crud.insert_turn_qa(session_id, question, answer, last_turn_id+1, 1)
return {
'code': 200,
'data': {
'question': question,
'answer': answer,
'sessionID': session_id
}
}
@app.route('/api/general/regenerate', methods=['POST'])
def re_generate():
token = request.headers.get('token')
if not token:
return {
'code': 404,
'data': '验证失败'
}
raw_data = request.data
data = json.loads(raw_data)
question = data["question"]
crud = CRUD(_db=c_db)
history = []
if "sessionID" in data:
session_id = data["sessionID"]
else:
session_id = ""
if session_id != "":
history = crud.get_last_history(str(session_id))
print(history)
answer = my_chat.chat(question)
# answer = "reGenerate Answer"
last_turn_id = crud.get_last_turn_num(str(session_id))
crud.update_turn_last(str(session_id), last_turn_id )
crud.insert_turn_qa(session_id, question, answer, last_turn_id , 1)
return {
'code': 200,
'data': {
'question': question,
'answer': answer,
'sessionID': session_id
}
}
if __name__ == "__main__":
app.run(host="0.0.0.0",port=8088,debug=False)
\ No newline at end of file
......@@ -35,7 +35,11 @@ def load(filepath, mode: str = None, sentence_size: int = 0, metadata=None, call
else:
loader = UnstructuredFileLoader(filepath, mode=mode or "elements", **kwargs)
if sentence_size > 0:
return split(loader.load(), sentence_size)
try:
return split(loader.load(), sentence_size)
except:
print(filepath, " is wrong ")
return []
return loader.load()
......
......@@ -86,7 +86,13 @@ class CRUD:
self.db.execute(TABLE_USER)
def get_history(self, _chat_id):
query = f'SELECT question,answer FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC'
query = f'SELECT turn_number,question,answer,is_last FROM turn_qa WHERE chat_id=(%s) ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,))
ans = self.db.fetchall()
return ans
def get_last_history(self, _chat_id):
query = f'SELECT question,answer FROM turn_qa WHERE chat_id=(%s) and is_last=1 ORDER BY turn_number ASC'
self.db.execute_args(query, (_chat_id,))
ans = self.db.fetchall()
return ans
......@@ -95,6 +101,18 @@ class CRUD:
query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s)'
self.db.execute_args(query, (chat_id, question, answer, turn_number, is_last))
def insert_turn_qa_return_id(self, chat_id, question, answer, turn_number, is_last):
conn = self.db.conn
cur = conn.cursor
query = f'INSERT INTO turn_qa(chat_id, question, answer, turn_number, is_last) VALUES (%s,%s,%s,%s,%s)'
cur.execute_args(query, (chat_id, question, answer, turn_number, is_last))
inserted_id = cur.fetchone()[0]
conn.commit()
# 关闭游标和数据库连接
cur.close()
conn.close()
return inserted_id
def insert_c_user(self, account, password):
query = f'INSERT INTO c_user(account, password) VALUES (%s,%s)'
self.db.execute_args(query, (account, password))
......@@ -107,6 +125,10 @@ class CRUD:
query = f'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND is_last = 1'
self.db.execute_args(query, (chat_id,))
def update_turn_last(self, chat_id, turn_number):
query = f'UPDATE turn_qa SET is_last = 0 WHERE chat_id = (%s) AND turn_number = (%s)'
self.db.execute_args(query, (chat_id, turn_number))
def user_exist_id(self, _user_id):
query = f'SELECT * FROM c_user WHERE user_id = (%s)'
self.db.execute_args(query, (_user_id,))
......@@ -128,7 +150,7 @@ class CRUD:
return self.db.fetchone()
def get_chat_list_userid(self, _user_id):
query = f'SELECT info FROM chat WHERE user_id = (%s) AND deleted = 0 order by create_time desc'
query = f'SELECT chat_id FROM chat WHERE user_id = (%s) AND deleted = 0 order by create_time desc'
self.db.execute_args(query, (_user_id,))
return self.db.fetchall()
......@@ -180,3 +202,8 @@ class CRUD:
query = f'UPDATE chat SET info = (%s) WHERE chat_id = (%s)'
self.db.execute_args(query, (info, chat_id))
def get_last_turn_num(self,chat_id):
query = f'SELECT max(turn_number) FROM turn_qa WHERE chat_id = (%s)'
self.db.execute_args(query, (chat_id,))
ans = self.db.fetchone()[0]
return ans
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment