from fastapi import FastAPI, Header
from src.pgdb.chat.c_db import UPostgresDB
from fastapi.middleware.cors import CORSMiddleware
from src.controller.response import Response
from src.controller.return_data import ReturnData
from src.pgdb.chat.crud import CRUD
import uvicorn
from src.server.qa import QA
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.llm.ernie_with_sdk import ChatERNIESerLLM
from qianfan import ChatCompletion
from langchain_core.prompts import PromptTemplate
from src.controller.request import (
    RegisterRequest,
    LoginRequest,
    ChatQaRequest,
    ChatDetailRequest,
    ChatDeleteRequest,
    ChatReQA
)
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER,
    prompt1
)

app = FastAPI()

# Shared chat-database handle: created and connected once, at import time,
# and used by every endpoint in this module through `CRUD(_db=c_db)`.
c_db = UPostgresDB(
    host=CHAT_DB_HOST,
    port=CHAT_DB_PORT,
    database=CHAT_DB_DBNAME,
    user=CHAT_DB_USER,
    password=CHAT_DB_PASSWORD,
)
c_db.connect()

# CORS policy: any origin, any HTTP method and any header are accepted, and
# credentials are allowed. Replace "*" in allow_origins with explicit
# domains to restrict which sites may call this API.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)


@app.post("/lae/user/register", response_model=Response)
async def register(request: RegisterRequest):
    """Register a new user account.

    Args:
        request: body carrying ``account`` and ``password``.

    Returns:
        ``ReturnData(...).get_data()``: code 40000 if the account already
        exists, otherwise code 200 with ``{"account": account}`` after the
        user row has been inserted via ``CRUD.insert_c_user``.
    """
    account = request.account
    password = request.password
    # Log only the account name: printing the whole request model would
    # write the plaintext password to stdout / server logs.
    print(f"register: account={account}")
    crud = CRUD(_db=c_db)
    if crud.user_exist_account(_account=account):
        return ReturnData(40000, "当前用户已注册", {}).get_data()
    # NOTE(review): the password is handed to insert_c_user exactly as received;
    # nothing in this file hashes it — confirm CRUD does before storing.
    crud.insert_c_user(account=account, password=password)
    return ReturnData(200, "用户注册成功", {"account": account}).get_data()


# GET is kept for existing clients. POST is added because this endpoint reads a
# JSON body, and browser fetch()/XHR refuse to send a body with a GET request.
@app.api_route("/lae/user/login", methods=["GET", "POST"], response_model=Response)
async def login(request: LoginRequest):
    """Check an account/password pair and issue a session token.

    Args:
        request: body carrying ``account`` and ``password``.

    Returns:
        ``ReturnData(...).get_data()``: code 40000 when
        ``CRUD.user_exist_account_password`` returns a falsy id, otherwise
        code 200 with ``{"account": ..., "token": ...}``.
    """
    account = request.account
    password = request.password
    # Never log the request model itself: it contains the plaintext password.
    print(f"login: account={account}")
    crud = CRUD(_db=c_db)
    user_id = crud.user_exist_account_password(_account=account, _password=password)
    if not user_id:
        return ReturnData(40000, "用户未注册或密码错误", {}).get_data()
    # The token is the user id followed by six '*'; other endpoints recover the
    # id with token.replace('*', ''). The f-string also works when the id is an
    # int (plain `+` concatenation raised TypeError in that case).
    # SECURITY NOTE(review): this token is not signed or random — anyone who
    # knows or guesses a user id can impersonate that user. Replace it with a
    # real session/JWT scheme, updating every endpoint that parses the token.
    token = f"{user_id}******"
    data = {
        "account": account,
        "token": token,
    }
    return ReturnData(200, "用户登录成功", data).get_data()


@app.post("/lae/chat/create", response_model=Response)
async def create(token: str = Header(None)):
    """Create a new chat for the user identified by the ``token`` header.

    Args:
        token: value produced by ``login`` (user id padded with '*'); may be
            ``None`` when the header is absent.

    Returns:
        ``ReturnData(...).get_data()``: code 40000 when the token is missing
        or the user does not exist, otherwise code 200 with the user id and
        the chat info string stored for the new chat.
    """
    print(token)
    # A missing header arrives as None; calling .replace on it used to raise
    # AttributeError and surface as HTTP 500.
    user_id = (token or '').replace('*', '')
    if not user_id:
        return ReturnData(40000, "未提供有效的token,请先登录", {}).get_data()
    crud = CRUD(_db=c_db)
    if not crud.user_exist_id(_user_id=user_id):
        return ReturnData(40000, "当前用户暂未注册", {}).get_data()
    chat_info = '这是该chat的info'
    crud.insert_chat(user_id=user_id, info=chat_info, deleted=0)
    data = {
        "user_id": user_id,
        "chat_info": chat_info,
    }
    return ReturnData(200, '会话创建成功', data).get_data()


@app.post("/lae/chat/delete", response_model=Response)
async def delete(request: ChatDeleteRequest, token: str = Header(None)):
    """Delete one of the caller's chats.

    Args:
        request: body carrying ``chat_id``.
        token: value produced by ``login`` (user id padded with '*'); may be
            ``None`` when the header is absent.

    Returns:
        ``ReturnData(...).get_data()``: code 40000 when the token is missing
        or the chat does not belong to the user, otherwise code 200 with the
        user id and the deleted chat's info.
    """
    print(request)
    print(token)
    user_id = (token or '').replace('*', '')
    if not user_id:
        return ReturnData(40000, "未提供有效的token,请先登录", {}).get_data()
    chat_id = request.chat_id
    crud = CRUD(_db=c_db)
    # Reject when the (chat, user) pair does NOT exist. The condition was
    # previously inverted: owned chats were refused, and chats belonging to
    # other users (or missing ones) went on to be deleted.
    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
        return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
    # Read the info before deleting so it can be echoed back to the client.
    chat_info = crud.get_chatinfo_from_chatid(chat_id)
    crud.delete_chat(chat_id)
    data = {
        "user_id": user_id,
        "chat_info": chat_info,
    }
    return ReturnData(200, '会话删除成功', data).get_data()


@app.post("/lae/chat/qa", response_model=Response)
async def qa(request: ChatQaRequest, token: str = Header(None)):
    """Ask a question inside one of the caller's chats.

    Builds a ``VectorStore_FAISS`` and a ``ChatERNIESerLLM``, wires them into
    ``QA`` with the ``prompt1`` template, runs ``QA.chat(question)`` and then
    ``QA.update_history()``.

    Args:
        request: body carrying ``chat_id`` and ``question``.
        token: value produced by ``login`` (user id padded with '*'); may be
            ``None`` when the header is absent.

    Returns:
        ``ReturnData(...).get_data()``: code 40000 when the token is missing
        or the chat does not belong to the user, otherwise code 200 with
        ``{"answer": answer}``.
    """
    import os

    print(request)
    print(token)
    # A missing header arrives as None; guard it instead of crashing with HTTP 500.
    user_id = (token or '').replace('*', '')
    if not user_id:
        return ReturnData(40000, "未提供有效的token,请先登录", {}).get_data()
    chat_id = request.chat_id
    question = request.question
    crud = CRUD(_db=c_db)
    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
        return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
    # NOTE(review): the vector store and LLM client are rebuilt on every
    # request. If construction loads the embedding model, this is costly;
    # consider building them once at startup.
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    # SECURITY: API credentials can now be supplied via QIANFAN_AK / QIANFAN_SK.
    # The hardcoded fallbacks are kept only for backward compatibility; they
    # are exposed in source control and should be rotated and removed.
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(
            ak=os.environ.get("QIANFAN_AK", "pT7sV1smp4AeDl0LjyZuHBV9"),
            sk=os.environ.get("QIANFAN_SK", "b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"),
        ))
    # {"temperature": 0.9} is presumably passed through as generation params — confirm in QA.
    my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
                 {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id=chat_id,
                 _faiss_db=vecstore_faiss)
    # NOTE(review): chat()/update_history() are synchronous calls inside an
    # async endpoint, so they block the event loop while they run.
    answer = my_chat.chat(question)
    my_chat.update_history()
    return ReturnData(200, '模型问答成功', {"answer": answer}).get_data()


# GET is kept for existing clients. POST is added because this endpoint reads a
# JSON body, and browser fetch()/XHR refuse to send a body with a GET request.
@app.api_route("/lae/chat/detail", methods=["GET", "POST"], response_model=Response)
async def detail(request: ChatDetailRequest, token: str = Header(None)):
    """Return the stored history of one of the caller's chats.

    Args:
        request: body carrying ``chat_id``.
        token: value produced by ``login`` (user id padded with '*'); may be
            ``None`` when the header is absent.

    Returns:
        ``ReturnData(...).get_data()``: code 40000 when the token is missing
        or the chat does not belong to the user, otherwise code 200 with the
        chat id and whatever ``CRUD.get_history`` returns.
    """
    print(request)
    print(token)
    # A missing header arrives as None; guard it instead of crashing with HTTP 500.
    user_id = (token or '').replace('*', '')
    if not user_id:
        return ReturnData(40000, "未提供有效的token,请先登录", {}).get_data()
    chat_id = request.chat_id
    crud = CRUD(_db=c_db)
    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
        return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
    history = crud.get_history(chat_id)
    data = {
        "chat_id": chat_id,
        "history": history,
    }
    return ReturnData(200, "会话详情获取成功", data).get_data()


@app.post("/lae/chat/clist", response_model=Response)
async def clist(token: str = Header(None)):
    """List the chats belonging to the user identified by ``token``.

    Args:
        token: value produced by ``login`` (user id padded with '*'); may be
            ``None`` when the header is absent.

    Returns:
        ``ReturnData(...).get_data()``: code 40000 when the token is missing,
        otherwise code 200 with whatever ``CRUD.get_chat_list_userid`` returns.
    """
    print(token)
    # A missing header arrives as None; guard it instead of crashing with HTTP 500.
    user_id = (token or '').replace('*', '')
    if not user_id:
        return ReturnData(40000, "未提供有效的token,请先登录", {}).get_data()
    crud = CRUD(_db=c_db)
    chat_list = crud.get_chat_list_userid(user_id)
    return ReturnData(200, "会话列表获取成功", {"chat_list": chat_list}).get_data()


@app.post("/lae/chat/reqa", response_model=Response)
async def reqa(request: ChatReQA, token: str = Header(None)):
    """Re-ask the most recent question of one of the caller's chats.

    Fetches the question via ``CRUD.get_last_question`` and answers it with
    the same ``VectorStore_FAISS`` + ``ChatERNIESerLLM`` + ``QA`` pipeline
    used by ``qa``, then calls ``QA.update_history()``.

    Args:
        request: body carrying ``chat_id``.
        token: value produced by ``login`` (user id padded with '*'); may be
            ``None`` when the header is absent.

    Returns:
        ``ReturnData(...).get_data()``: code 40000 when the token is missing,
        the chat does not belong to the user, or there is no previous
        question; otherwise code 200 with ``{"answer": answer}``.
    """
    import os

    print(request)
    print(token)
    chat_id = request.chat_id
    # A missing header arrives as None; guard it instead of crashing with HTTP 500.
    user_id = (token or '').replace('*', '')
    if not user_id:
        return ReturnData(40000, "未提供有效的token,请先登录", {}).get_data()
    crud = CRUD(_db=c_db)
    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
        return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
    question = crud.get_last_question(_chat_id=chat_id)
    # Assumes get_last_question yields a falsy value for a chat with no
    # questions yet (TODO confirm); don't send an empty prompt to the LLM.
    if not question:
        return ReturnData(40000, "该会话暂无可重新回答的问题", {}).get_data()
    # NOTE(review): rebuilt on every request — consider building once at startup.
    vecstore_faiss = VectorStore_FAISS(
        embedding_model_name=EMBEEDING_MODEL_PATH,
        store_path=FAISS_STORE_PATH,
        index_name=INDEX_NAME,
        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
              "password": VEC_DB_PASSWORD},
        show_number=SIMILARITY_SHOW_NUMBER,
        reset=False)
    # SECURITY: API credentials can now be supplied via QIANFAN_AK / QIANFAN_SK.
    # The hardcoded fallbacks are kept only for backward compatibility; they
    # are exposed in source control and should be rotated and removed.
    base_llm = ChatERNIESerLLM(
        chat_completion=ChatCompletion(
            ak=os.environ.get("QIANFAN_AK", "pT7sV1smp4AeDl0LjyZuHBV9"),
            sk=os.environ.get("QIANFAN_SK", "b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"),
        ))
    my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
                 {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id=chat_id,
                 _faiss_db=vecstore_faiss)
    # NOTE(review): synchronous calls inside an async endpoint block the event loop.
    answer = my_chat.chat(question)
    my_chat.update_history()
    return ReturnData(200, '模型重新问答成功', {"answer": answer}).get_data()

if __name__ == "__main__":
    # Local development entry point: serve `app` with uvicorn on localhost:8889.
    uvicorn.run(app, port=8889, host='localhost')