# web.py
import json
import os
import sys
from datetime import datetime, timedelta

import uvicorn
from fastapi import FastAPI, Header, Query
from fastapi.middleware.cors import CORSMiddleware
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

# Make the parent directory importable so the project-local `src.*` packages below
# resolve. A relative entry is resolved against the current working directory.
sys.path.append('../')

from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.server.qa import QA
from src.controller.request import (
    PhoneLoginRequest,
    ChatRequest,
    ReGenerateRequest
)
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER,
    prompt_enhancement_history_template
)
app = FastAPI()

# CORS is fully open: every origin (specific domains could be listed instead),
# every HTTP method and every HTTP header are accepted, and credentials are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let any
# website issue credentialed requests to this API; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Chat database connection. It is opened once at import time, and every request
# handler builds its CRUD helper on top of this single object.
# NOTE(review): FastAPI runs plain `def` endpoints in a thread pool, so this one
# connection is used from several threads; confirm UPostgresDB tolerates that.
c_db = UPostgresDB(
    host=CHAT_DB_HOST,
    port=CHAT_DB_PORT,
    database=CHAT_DB_DBNAME,
    user=CHAT_DB_USER,
    password=CHAT_DB_PASSWORD,
)
c_db.connect()

# Connection details of the vector database, handed to the FAISS store as `info`.
_vec_db_info = {
    "port": VEC_DB_PORT,
    "host": VEC_DB_HOST,
    "dbname": VEC_DB_DBNAME,
    "username": VEC_DB_USER,
    "password": VEC_DB_PASSWORD,
}

# Similarity store used by the QA chain for document retrieval.
# `show_number` caps how many similar documents are surfaced (presumably per query),
# and `reset=False` presumably reuses an existing index rather than rebuilding it.
vecstore_faiss = VectorStore_FAISS(
    embedding_model_name=EMBEEDING_MODEL_PATH,
    store_path=FAISS_STORE_PATH,
    index_name=INDEX_NAME,
    info=_vec_db_info,
    show_number=SIMILARITY_SHOW_NUMBER,
    reset=False,
)

# Chat model served behind an OpenAI-compatible HTTP endpoint.
# The endpoint, API key and model name used to be hard-coded (including an internal
# IP address). They can now be overridden through the environment, and the old
# values remain the defaults, so existing deployments behave exactly as before.
# App-specific variable names are used on purpose: reusing OPENAI_API_KEY /
# OPENAI_API_BASE could silently send an unrelated real OpenAI key to this server.
base_llm = ChatOpenAI(
    openai_api_key=os.environ.get('LLM_API_KEY', 'xxxxxxxxxxxxx'),
    openai_api_base=os.environ.get('LLM_API_BASE', 'http://192.168.10.14:8000/v1'),
    model_name=os.environ.get('LLM_MODEL_NAME', 'Qwen2-7B'),
    verbose=True
)
# Variables the prompt template is filled with: prior dialogue, retrieved context
# and the user's question. The same names are given to the QA chain.
_qa_input_variables = ["history", "context", "question"]

# Retrieval-augmented QA chain. It combines the prompt, the LLM (called with
# temperature 0.9), the chat DB and the FAISS store; `rerank=True` presumably
# re-orders retrieved documents before use (confirm in src.server.qa).
my_chat = QA(
    PromptTemplate(
        input_variables=list(_qa_input_variables),
        template=prompt_enhancement_history_template,
    ),
    base_llm,
    {"temperature": 0.9},
    list(_qa_input_variables),
    _db=c_db,
    _faiss_db=vecstore_faiss,
    rerank=True,
)

@app.post('/api/login')
def login(phone_request: PhoneLoginRequest):
    """Log in by phone number, registering the phone on its first use.

    The account is looked up with ``CRUD.user_exist_account``. When none exists,
    one is created with ``insert_c_user`` and looked up again. The first column of
    the account row (the user id) is returned as both the access and the refresh
    token. Both expire one day from now, formatted ``'%Y-%m-%d %H:%M:%S'`` in
    server local time.

    NOTE(review): the "token" is simply the user id, and every new account gets
    the fixed password "123456". Any client that knows a user id can therefore
    act as that user; this is not real authentication.

    Returns:
        ``{'code': 200, 'data': {...tokens...}}`` on success. Returns the app's
        failure envelope ``{'code': 404, 'data': '验证失败'}`` when the phone is
        empty or no account row can be found even after registering.
    """
    phone = phone_request.phone
    if not phone:
        # Never register or log in a blank phone number.
        return {
            'code': 404,
            'data': '验证失败'
        }

    crud = CRUD(_db=c_db)
    user = crud.user_exist_account(phone)
    if not user:
        # First login for this phone number: register it with the fixed initial password.
        crud.insert_c_user(phone, "123456")
        user = crud.user_exist_account(phone)
    if not user:
        # Registration produced no row; report failure instead of crashing on user[0].
        return {
            'code': 404,
            'data': '验证失败'
        }

    userid = user[0]
    expire = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
    return {
        'code': 200,
        'data': {
            'accessToken': userid,
            'refreshToken': userid,
            'accessExpire': expire,
            'refreshExpire': expire,
            'accessUUID': 'accessUUID',
            'refreshUUID': 'refreshUUID',
        }
    }


@app.get('/api/sessions/chat/')
def get_sessions(timestamp: int = Query(None, alias="_"), token: str = Header(None)):
    """List the IDs of the chat sessions that belong to ``token``.

    The ``_`` query parameter is accepted but unused (presumably a client-side
    cache-buster timestamp).

    Returns:
        ``{'code': 200, 'data': [session_id_str, ...]}``. Each entry is the first
        column of a row from ``CRUD.get_chat_list_userid``, converted to ``str``.
        Returns ``{'code': 404, 'data': '验证失败'}`` when no token header is sent.
    """
    if token:
        rows = CRUD(_db=c_db).get_chat_list_userid(token)
        return {
            'code': 200,
            'data': [str(row[0]) for row in rows]
        }
    return {
        'code': 404,
        'data': '验证失败'
    }


@app.get('/api/session/{session_id}')
def get_history_by_session_id(session_id: str, token: str = Header(None)):
    """Return the stored turns of one of the caller's chat sessions.

    Previously any non-empty token could read any session. Now the session must be
    listed for ``token`` by ``CRUD.get_chat_list_userid``, which is the same list
    ``/api/sessions/chat/`` hands to the client. Otherwise the failure envelope is
    returned.

    Returns:
        ``{'code': 200, 'data': <JSON string>}``. The JSON string encodes a list of
        ``{"TurnID", "Question", "Answer", "IsLast", "SimilarDocuments"}`` objects
        built from columns 0-3 of each ``CRUD.get_history`` row.
        ``SimilarDocuments`` is always empty here. The payload is a *string*
        (double-encoded JSON), which clients already expect. Returns
        ``{'code': 404, 'data': '验证失败'}`` for a missing token or a session the
        caller does not own.
    """
    if not token:
        return {
            'code': 404,
            'data': '验证失败'
        }

    crud = CRUD(_db=c_db)
    owned_ids = {str(chat[0]) for chat in crud.get_chat_list_userid(token)}
    if session_id not in owned_ids:
        # Not one of the caller's sessions: don't leak another user's history.
        return {
            'code': 404,
            'data': '验证失败'
        }

    history_json = [
        {
            "TurnID": h[0],
            "Question": h[1],
            "Answer": h[2],
            "IsLast": h[3],
            "SimilarDocuments": [],
        }
        for h in crud.get_history(session_id)
    ]
    return {
        'code': 200,
        'data': json.dumps(history_json)
    }

@app.delete('/api/session/{session_id}')
def delete_session_by_session_id(session_id: str, token: str = Header(None)):
    """Delete one of the caller's chat sessions.

    Previously any non-empty token could delete any session. Now only sessions
    listed for ``token`` by ``CRUD.get_chat_list_userid`` (the list exposed by
    ``/api/sessions/chat/``) can be deleted.

    Returns:
        ``{'code': 200, 'data': 'success'}`` after ``CRUD.delete_chat`` has been
        called. Returns ``{'code': 404, 'data': '验证失败'}`` for a missing token or
        a session the caller does not own; nothing is deleted in that case.
    """
    if not token:
        return {
            'code': 404,
            'data': '验证失败'
        }

    crud = CRUD(_db=c_db)
    owned_ids = {str(chat[0]) for chat in crud.get_chat_list_userid(token)}
    if session_id not in owned_ids:
        return {
            'code': 404,
            'data': '验证失败'
        }

    crud.delete_chat(session_id)
    return {
        'code': 200,
        'data': 'success'
    }


@app.post('/api/general/chat')
def question(chat_request: ChatRequest, token: str = Header(None)):
    """Answer a question, optionally continuing an existing chat session.

    With an empty or missing ``sessionID`` a new session is created for ``token``
    and the exchange is stored as turn 0. Otherwise the session must belong to the
    caller (be listed by ``CRUD.get_chat_list_userid(token)``). Its recent history
    (``get_last_history``) is rendered as ``"问:<q>\\n答:<a>\\n\\n"`` blocks and
    passed to the QA chain, and the exchange is stored as the next turn number.

    Returns:
        ``{'code': 200, 'data': {'question', 'answer', 'sessionID', 'similarity'}}``.
        ``similarity`` lists the retrieved documents as
        ``{"page_content", "from_file"}``; ``from_file`` is ``None`` when a document
        has no ``filename`` metadata. Returns ``{'code': 404, 'data': '验证失败'}``
        for a missing token or a session the caller does not own.
    """
    if not token:
        return {
            'code': 404,
            'data': '验证失败'
        }

    session_id = chat_request.sessionID
    question_text = chat_request.question
    crud = CRUD(_db=c_db)

    # Treat both "" and None as "start a new session". Previously None slipped
    # through the `!= ""` test and ended in lookups for the session "None".
    is_new_session = not session_id

    history = []
    if not is_new_session:
        owned_ids = {str(chat[0]) for chat in crud.get_chat_list_userid(token)}
        if str(session_id) not in owned_ids:
            # Refuse before calling the model: callers may not append to others' sessions.
            return {
                'code': 404,
                'data': '验证失败'
            }
        history = crud.get_last_history(str(session_id))

    # Rows carry (question, answer) in their first two columns.
    prompt = "".join("问:{}\n答:{}\n\n".format(h[0], h[1]) for h in history)
    answer, docs = my_chat.chat_with_history(question_text, history=prompt, with_similarity=True)

    docs_json = [
        {"page_content": d.page_content, "from_file": d.metadata.get("filename")}
        for d in docs
    ]

    if is_new_session:
        # The '\t\t' and '0' arguments are passed through unchanged from the original
        # code. NOTE(review): their meaning is defined in CRUD.create_chat.
        session_id = crud.create_chat(token, '\t\t', '0')
        crud.insert_turn_qa(session_id, question_text, answer, 0, 1)
    else:
        last_turn_id = crud.get_last_turn_num(str(session_id))
        crud.insert_turn_qa(session_id, question_text, answer, last_turn_id + 1, 1)

    return {
        'code': 200,
        'data': {
            'question': question_text,
            'answer': answer,
            'sessionID': session_id,
            'similarity': docs_json
        }
    }


@app.post('/api/general/regenerate')
def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
    """Generate a new answer for the latest turn of one of the caller's sessions.

    The session's last turn number is read with ``get_last_turn_num``. The history
    returned by ``get_last_history_before_turn_id`` for that turn is rendered as
    ``"问:<q>\\n答:<a>\\n\\n"`` blocks and sent to the QA chain with the request's
    question. ``update_turn_last`` is then called for that turn (presumably
    clearing the previous answer's IsLast flag; confirm in CRUD). Finally the new
    answer is inserted under the *same* turn number with flag ``1``.

    The session must be listed by ``CRUD.get_chat_list_userid(token)``. Otherwise
    the failure envelope is returned and neither the model nor the database is
    touched.

    Returns:
        The same envelope as ``/api/general/chat``. ``from_file`` is ``None`` for
        retrieved documents without ``filename`` metadata. Returns
        ``{'code': 404, 'data': '验证失败'}`` for a missing token or a session the
        caller does not own.
    """
    if not token:
        return {
            'code': 404,
            'data': '验证失败'
        }

    session_id = chat_request.sessionID
    question_text = chat_request.question
    session_key = str(session_id)

    crud = CRUD(_db=c_db)
    owned_ids = {str(chat[0]) for chat in crud.get_chat_list_userid(token)}
    if session_key not in owned_ids:
        # Without this check any token could overwrite another user's latest answer.
        return {
            'code': 404,
            'data': '验证失败'
        }

    last_turn_id = crud.get_last_turn_num(session_key)
    history = crud.get_last_history_before_turn_id(session_key, last_turn_id)
    prompt = "".join("问:{}\n答:{}\n\n".format(h[0], h[1]) for h in history)
    answer, docs = my_chat.chat_with_history(question_text, history=prompt, with_similarity=True)

    docs_json = [
        {"page_content": d.page_content, "from_file": d.metadata.get("filename")}
        for d in docs
    ]

    crud.update_turn_last(session_key, last_turn_id)
    crud.insert_turn_qa(session_id, question_text, answer, last_turn_id, 1)

    return {
        'code': 200,
        'data': {
            'question': question_text,
            'answer': answer,
            'sessionID': session_id,
            'similarity': docs_json
        }
    }

if __name__ == "__main__":
    # Run the API directly with uvicorn, listening on every interface on port 8088.
    uvicorn.run(app, port=8088, host='0.0.0.0')