import logging
import os

import uvicorn
from fastapi import FastAPI, Header
from fastapi.middleware.cors import CORSMiddleware
from langchain_core.prompts import PromptTemplate
from qianfan import ChatCompletion

from src.pgdb.chat.c_db import UPostgresDB
from src.controller.response import Response
from src.controller.return_data import ReturnData
from src.pgdb.chat.crud import CRUD
from src.server.qa import QA
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from src.controller.request import (
    RegisterRequest,
    LoginRequest,
    ChatQaRequest,
    ChatDetailRequest,
    ChatDeleteRequest,
    ChatReQA,
)
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER,
    prompt1,
)

logger = logging.getLogger(__name__)

app = FastAPI()

# Chat database connection, opened once at import time and shared by every
# request handler below (each handler wraps it in a fresh CRUD object).
c_db = UPostgresDB(host=CHAT_DB_HOST, database=CHAT_DB_DBNAME, user=CHAT_DB_USER,
                   password=CHAT_DB_PASSWORD, port=CHAT_DB_PORT, )
c_db.connect()

# Add the CORS middleware.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed requests; list the real front-end origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin; specific domains can be listed instead
    allow_credentials=True,
    allow_methods=["*"],  # allow every HTTP method
    allow_headers=["*"],  # allow every HTTP header
)

# Baidu Qianfan credentials for the ERNIE model. They are read from the
# environment so they can be rotated without a code change.
# SECURITY(review): the fallback values below were committed to source control
# and must be treated as leaked -- rotate them and delete the fallbacks.
_QIANFAN_AK = os.environ.get("QIANFAN_AK", "pT7sV1smp4AeDl0LjyZuHBV9")
_QIANFAN_SK = os.environ.get("QIANFAN_SK", "b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu")

# Error returned by every chat endpoint when the chat is missing or not owned
# by the caller.
_CHAT_NOT_ACCESSIBLE_MSG = "该会话不存在(用户无法访问非本人创建的对话)"
# Error returned when the request carries no usable `token` header.
_MISSING_TOKEN_MSG = "请求头缺少有效的token,请先登录"


def _user_id_from_token(token):
    """Recover the user id from the token issued by ``login``.

    ``login`` builds the token as ``<user_id>******``, so stripping every
    ``*`` yields the id back.

    Args:
        token: raw value of the ``token`` request header, or None when the
            header was not sent.

    Returns:
        The user id string, or None if the header is missing or contains
        nothing but ``*`` characters.
    """
    if not token:
        return None
    return token.replace('*', '') or None


def _build_qa(chat_id):
    """Assemble the retrieval-augmented QA pipeline for one chat.

    Combines a FAISS-backed vector store, the ERNIE chat model, and the
    ``prompt1`` template (variables ``context`` and ``question``). The pipeline
    is bound to ``chat_id`` and the shared chat DB.

    NOTE(review): the vector store is rebuilt on every request. Consider
    creating it once at startup if VectorStore_FAISS is safe to share.
    """
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME,
              "username": VEC_DB_USER, "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(ak=_QIANFAN_AK, sk=_QIANFAN_SK))
    return QA(PromptTemplate(input_variables=["context", "question"], template=prompt1),
              base_llm, {"temperature": 0.9}, ['context', 'question'],
              _db=c_db, _chat_id=chat_id, _faiss_db=vecstore_faiss)


@app.post("/lae/user/register", response_model=Response)
async def register(request: RegisterRequest):
    """Register a new account.

    Returns code 40000 if the account already exists. Otherwise it stores the
    account/password pair and echoes the account back with code 200.
    """
    # Log only the account: the request object also carries the password.
    logger.info("register request: account=%s", request.account)
    account = request.account
    password = request.password
    crud = CRUD(_db=c_db)
    if crud.user_exist_account(_account=account):
        return ReturnData(40000, "当前用户已注册", {}).get_data()
    # NOTE(review): the password is handed to the DB layer exactly as received;
    # confirm that CRUD.insert_c_user hashes it before storing it.
    crud.insert_c_user(account=account, password=password)
    data = {
        "account": account,
    }
    return ReturnData(200, "用户注册成功", data).get_data()


@app.get("/lae/user/login", response_model=Response)
async def login(request: LoginRequest):
    """Verify account/password and return a session token.

    Returns code 40000 when the credentials do not match a user.

    NOTE(review): this GET route reads a JSON body, which many clients and
    proxies drop. It is kept as GET to preserve the public route.
    """
    # Log only the account: the request object also carries the password.
    logger.info("login request: account=%s", request.account)
    account = request.account
    password = request.password
    crud = CRUD(_db=c_db)
    user_id = crud.user_exist_account_password(_account=account, _password=password)
    if not user_id:
        return ReturnData(40000, "用户未注册或密码错误", {}).get_data()
    # SECURITY(review): the token is just the user id padded with '*', so any
    # client can forge a token for any user. Replace it with a signed token
    # and update _user_id_from_token to match.
    # Formatting (instead of `+`) also accepts a non-str id from the DB layer.
    token = f"{user_id}******"
    data = {
        "account": account,
        "token": token
    }
    return ReturnData(200, "用户登录成功", data).get_data()


@app.post("/lae/chat/create", response_model=Response)
async def create(token: str = Header(None)):
    """Create a new chat for the user identified by the ``token`` header."""
    user_id = _user_id_from_token(token)
    if user_id is None:
        return ReturnData(40000, _MISSING_TOKEN_MSG, {}).get_data()
    logger.info("chat create: user_id=%s", user_id)
    crud = CRUD(_db=c_db)
    if not crud.user_exist_id(_user_id=user_id):
        return ReturnData(40000, "当前用户暂未注册", {}).get_data()
    chat_info = '这是该chat的info'
    crud.insert_chat(user_id=user_id, info=chat_info, deleted=0)
    data = {
        "user_id": user_id,
        "chat_info": chat_info,
    }
    return ReturnData(200, '会话创建成功', data).get_data()


@app.post("/lae/chat/delete", response_model=Response)
async def delete(request: ChatDeleteRequest, token: str = Header(None)):
    """Delete a chat owned by the caller and return its info.

    Returns code 40000 when the chat does not exist or belongs to another user.
    """
    user_id = _user_id_from_token(token)
    if user_id is None:
        return ReturnData(40000, _MISSING_TOKEN_MSG, {}).get_data()
    chat_id = request.chat_id
    logger.info("chat delete: user_id=%s chat_id=%s", user_id, chat_id)
    crud = CRUD(_db=c_db)
    # Reject when the chat is NOT accessible to this user (same check as the
    # other chat endpoints).
    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
        return ReturnData(40000, _CHAT_NOT_ACCESSIBLE_MSG, {}).get_data()
    # Read the info before deleting so it can be echoed back to the client.
    chat_info = crud.get_chatinfo_from_chatid(chat_id)
    crud.delete_chat(chat_id)
    data = {
        "user_id": user_id,
        "chat_info": chat_info,
    }
    return ReturnData(200, '会话删除成功', data).get_data()


@app.post("/lae/chat/qa", response_model=Response)
async def qa(request: ChatQaRequest, token: str = Header(None)):
    """Answer ``request.question`` inside one of the caller's chats.

    The answer is produced by the QA pipeline from ``_build_qa``. The chat
    history is then updated via ``QA.update_history``.
    """
    user_id = _user_id_from_token(token)
    if user_id is None:
        return ReturnData(40000, _MISSING_TOKEN_MSG, {}).get_data()
    chat_id = request.chat_id
    question = request.question
    logger.info("chat qa: user_id=%s chat_id=%s", user_id, chat_id)
    crud = CRUD(_db=c_db)
    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
        return ReturnData(40000, _CHAT_NOT_ACCESSIBLE_MSG, {}).get_data()
    my_chat = _build_qa(chat_id)
    answer = my_chat.chat(question)
    my_chat.update_history()
    data = {
        "answer": answer
    }
    return ReturnData(200, '模型问答成功', data).get_data()


@app.get("/lae/chat/detail", response_model=Response)
async def detail(request: ChatDetailRequest, token: str = Header(None)):
    """Return the stored history of one of the caller's chats.

    NOTE(review): this GET route reads a JSON body; see the note on ``login``.
    """
    user_id = _user_id_from_token(token)
    if user_id is None:
        return ReturnData(40000, _MISSING_TOKEN_MSG, {}).get_data()
    chat_id = request.chat_id
    logger.info("chat detail: user_id=%s chat_id=%s", user_id, chat_id)
    crud = CRUD(_db=c_db)
    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
        return ReturnData(40000, _CHAT_NOT_ACCESSIBLE_MSG, {}).get_data()
    history = crud.get_history(chat_id)
    data = {
        "chat_id": chat_id,
        "history": history
    }
    return ReturnData(200, "会话详情获取成功", data).get_data()


@app.post("/lae/chat/clist", response_model=Response)
async def clist(token: str = Header(None)):
    """List the chats belonging to the user identified by the ``token`` header."""
    user_id = _user_id_from_token(token)
    if user_id is None:
        return ReturnData(40000, _MISSING_TOKEN_MSG, {}).get_data()
    logger.info("chat list: user_id=%s", user_id)
    crud = CRUD(_db=c_db)
    chat_list = crud.get_chat_list_userid(user_id)
    data = {
        "chat_list": chat_list
    }
    return ReturnData(200, "会话列表获取成功", data).get_data()


@app.post("/lae/chat/reqa", response_model=Response)
async def reqa(request: ChatReQA, token: str = Header(None)):
    """Re-answer the last question asked in one of the caller's chats."""
    user_id = _user_id_from_token(token)
    if user_id is None:
        return ReturnData(40000, _MISSING_TOKEN_MSG, {}).get_data()
    chat_id = request.chat_id
    logger.info("chat reqa: user_id=%s chat_id=%s", user_id, chat_id)
    crud = CRUD(_db=c_db)
    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
        return ReturnData(40000, _CHAT_NOT_ACCESSIBLE_MSG, {}).get_data()
    # NOTE(review): behaviour when the chat has no previous question depends on
    # CRUD.get_last_question, which is not visible here -- confirm it cannot
    # return None/empty before it is passed to the model.
    question = crud.get_last_question(_chat_id=chat_id)
    my_chat = _build_qa(chat_id)
    answer = my_chat.chat(question)
    my_chat.update_history()
    data = {
        "answer": answer
    }
    return ReturnData(200, '模型重新问答成功', data).get_data()


if __name__ == "__main__":
    uvicorn.run(app, host='localhost', port=8889)