"""FastAPI service exposing login, chat-session and RAG question-answering endpoints."""
import json
import logging
import os
import sys
from datetime import datetime, timedelta

# Make the project-level ``src`` package importable. A relative entry in
# sys.path is resolved against the current working directory, so this
# presumably expects the server to be started one directory below the
# project root -- TODO confirm.
sys.path.append('../')

import uvicorn
from fastapi import FastAPI, Header, Query
from fastapi.middleware.cors import CORSMiddleware
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI

from src.pgdb.chat.c_db import UPostgresDB
from src.pgdb.chat.crud import CRUD
from src.pgdb.knowledge.similarity import VectorStore_FAISS
from src.pgdb.knowledge.k_db import PostgresDB
from src.pgdb.knowledge.txt_doc_table import TxtDoc
from src.server.rag_query import RagQuery
from src.controller.request import (
    PhoneLoginRequest,
    ChatRequest,
    ReGenerateRequest,
)
from src.config.consts import (
    CHAT_DB_USER,
    CHAT_DB_HOST,
    CHAT_DB_PORT,
    CHAT_DB_DBNAME,
    CHAT_DB_PASSWORD,
    EMBEEDING_MODEL_PATH,
    FAISS_STORE_PATH,
    INDEX_NAME,
    VEC_DB_HOST,
    VEC_DB_PASSWORD,
    VEC_DB_PORT,
    VEC_DB_USER,
    VEC_DB_DBNAME,
    SIMILARITY_SHOW_NUMBER,
)

logger = logging.getLogger(__name__)

# LLM endpoint settings. They can be overridden through the environment so the
# key / URL need not live in source; the defaults are the previously
# hard-coded values, so behavior is unchanged when the variables are unset.
LLM_API_KEY = os.environ.get('RAG_LLM_API_KEY', 'xxxxxxxxxxxxx')
LLM_API_BASE = os.environ.get('RAG_LLM_API_BASE', 'http://192.168.10.14:8000/v1')
LLM_MODEL_NAME = os.environ.get('RAG_LLM_MODEL', 'Qwen2-7B')

app = FastAPI()
# NOTE(security): fully permissive CORS combined with credentials; restrict
# allow_origins to known front-end domains in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow every origin; specific domains may be listed instead
    allow_credentials=True,
    allow_methods=["*"],  # allow every HTTP method
    allow_headers=["*"],  # allow every HTTP header
)

# Chat-history database (users, sessions, turns); connected at import time.
c_db = UPostgresDB(
    host=CHAT_DB_HOST,
    database=CHAT_DB_DBNAME,
    user=CHAT_DB_USER,
    password=CHAT_DB_PASSWORD,
    port=CHAT_DB_PORT,
)
c_db.connect()

# Knowledge database holding the source text documents; connected at import time.
k_db = PostgresDB(
    host=VEC_DB_HOST,
    database=VEC_DB_DBNAME,
    user=VEC_DB_USER,
    password=VEC_DB_PASSWORD,
    port=VEC_DB_PORT,
)
k_db.connect()

vecstore_faiss = VectorStore_FAISS(
    embedding_model_name=EMBEEDING_MODEL_PATH,
    store_path=FAISS_STORE_PATH,
    index_name=INDEX_NAME,
    info={
        "port": VEC_DB_PORT,
        "host": VEC_DB_HOST,
        "dbname": VEC_DB_DBNAME,
        "username": VEC_DB_USER,
        "password": VEC_DB_PASSWORD,
    },
    show_number=SIMILARITY_SHOW_NUMBER,
    reset=False,
)

base_llm = ChatOpenAI(
    openai_api_key=LLM_API_KEY,
    openai_api_base=LLM_API_BASE,
    model_name=LLM_MODEL_NAME,
    verbose=True,
)

rag_query = RagQuery(base_llm=base_llm, _faiss_db=vecstore_faiss, _db=TxtDoc(k_db))


def _auth_failed():
    """Return the payload every endpoint sends when the ``token`` header is missing."""
    # '验证失败' means "verification failed"; the 404 code is what existing
    # clients expect, so it is kept even though 401 would be more accurate.
    return {
        'code': 404,
        'data': '验证失败',
    }


def _build_history_prompt(history):
    """Render earlier turns as the Q/A transcript passed to the RAG query.

    ``history`` is an iterable of rows whose items 0 and 1 are a previous
    question and its answer (as returned by the CRUD history queries).
    """
    return ''.join("Q: {}\nA:{}\n".format(h[0], h[1]) for h in history)


def _run_rag(question_text, history):
    """Answer ``question_text`` with the RAG pipeline given earlier ``history`` rows.

    Returns ``(answer, docs_json, hash_str)``: the model answer, the decoded
    list of retrieved documents, and a comma-joined string of the ``"hash"``
    values of those documents ("" when none carry a hash).
    """
    res = rag_query.query(question=question_text, history=_build_history_prompt(history))
    answer = res["answer"]
    # res["docs"] is a JSON string; strict=False tolerates control characters
    # (e.g. raw newlines) inside the document text.
    docs_json = json.loads(res["docs"], strict=False)
    logger.debug("RAG query returned %d similar documents", len(docs_json))
    hash_str = ",".join(d["hash"] for d in docs_json if "hash" in d)
    return answer, docs_json, hash_str


@app.post('/api/login')
def login(phone_request: PhoneLoginRequest):
    """Log in by phone number, creating the user on first login.

    The user id is returned as both access and refresh token, with a one-day
    expiry timestamp.
    """
    phone = phone_request.phone
    crud = CRUD(_db=c_db)
    user = crud.user_exist_account(phone)
    if not user:
        # NOTE(security): every auto-created account gets the same fixed password.
        crud.insert_c_user(phone, "123456")
        user = crud.user_exist_account(phone)
    userid = user[0]
    expire = (datetime.now() + timedelta(days=1)).strftime('%Y-%m-%d %H:%M:%S')
    # NOTE(security): the "token" is just the user id -- it is neither signed nor
    # verified anywhere below, so any caller can impersonate any user.
    return {
        'code': 200,
        'data': {
            'accessToken': userid,
            'refreshToken': userid,
            'accessExpire': expire,
            'refreshExpire': expire,
            'accessUUID': 'accessUUID',
            'refreshUUID': 'refreshUUID',
        },
    }


@app.get('/api/sessions/chat/')
def get_sessions(token: str = Header(None)):
    """List the ids (as strings) of the chat sessions owned by the token's user."""
    if not token:
        return _auth_failed()
    crud = CRUD(_db=c_db)
    chat_list = crud.get_chat_list_userid(token)
    return {
        'code': 200,
        'data': [str(chat[0]) for chat in chat_list],
    }


@app.get('/api/session/{session_id}')
def get_history_by_session_id(session_id: str, token: str = Header(None)):
    """Return the turns of a session as a JSON-encoded string.

    Each turn carries TurnID, Question, Answer, IsLast and the similar
    documents resolved from the stored hash list.
    """
    if not token:
        return _auth_failed()
    # NOTE(security): the session is not checked against the token's user.
    crud = CRUD(_db=c_db)
    history = crud.get_history(session_id)
    history_json = [
        {
            "TurnID": h[0],
            "Question": h[1],
            "Answer": h[2],
            "IsLast": h[3],
            "SimilarDocuments": get_similarity_doc(h[4]),
        }
        for h in history
    ]
    # The client expects the history as a JSON string nested inside the response.
    return {
        'code': 200,
        'data': json.dumps(history_json),
    }


@app.delete('/api/session/{session_id}')
def delete_session_by_session_id(session_id: str, token: str = Header(None)):
    """Delete a chat session by id."""
    if not token:
        return _auth_failed()
    # NOTE(security): the session is not checked against the token's user.
    crud = CRUD(_db=c_db)
    crud.delete_chat(session_id)
    return {
        'code': 200,
        'data': 'success',
    }


@app.post('/api/general/chat')
def question(chat_request: ChatRequest, token: str = Header(None)):
    """Answer a question, creating a new session when no session id is given.

    For an existing session the previous turns are supplied as history and the
    new turn is stored with the next turn number; for a new session it is
    stored as turn 0.
    """
    if not token:
        return _auth_failed()
    session_id = chat_request.sessionID
    question_text = chat_request.question
    crud = CRUD(_db=c_db)

    # Both "" and a missing (None) id mean "start a new session"; previously a
    # None id was looked up as the literal session "None".
    is_new_session = session_id is None or session_id == ""
    history = [] if is_new_session else crud.get_last_history(str(session_id))

    answer, docs_json, hash_str = _run_rag(question_text, history)

    if is_new_session:
        session_id = crud.create_chat(token, '\t\t', '0')
        crud.insert_turn_qa(session_id, question_text, answer, 0, 1, hash_str)
    else:
        last_turn_id = crud.get_last_turn_num(str(session_id))
        crud.insert_turn_qa(session_id, question_text, answer, last_turn_id + 1, 1, hash_str)
    return {
        'code': 200,
        'data': {
            'question': question_text,
            'answer': answer,
            'sessionID': session_id,
            'similarity': docs_json,
        },
    }


@app.post('/api/general/regenerate')
def re_generate(chat_request: ReGenerateRequest, token: str = Header(None)):
    """Regenerate the answer for the latest turn of a session.

    History is taken from the turns before the latest one; the previous row
    for that turn is updated via ``update_turn_last`` (presumably clearing its
    "is last" flag -- TODO confirm) and the new answer is inserted with the
    same turn number.
    """
    if not token:
        return _auth_failed()
    session_id = chat_request.sessionID
    question_text = chat_request.question
    crud = CRUD(_db=c_db)
    last_turn_id = crud.get_last_turn_num(str(session_id))
    history = crud.get_last_history_before_turn_id(str(session_id), last_turn_id)

    answer, docs_json, hash_str = _run_rag(question_text, history)

    crud.update_turn_last(str(session_id), last_turn_id)
    crud.insert_turn_qa(session_id, question_text, answer, last_turn_id, 1, hash_str)
    return {
        'code': 200,
        'data': {
            'question': question_text,
            'answer': answer,
            'sessionID': session_id,
            'similarity': docs_json,
        },
    }


def get_similarity_doc(similarity_doc_hash: str):
    """Resolve a comma-separated list of document hashes into JSON-ready dicts.

    Empty hash pieces and hashes not found in the text-document table are
    skipped. Returns [] for an empty/None input.
    """
    if not similarity_doc_hash:
        return []
    hashes = [h for h in similarity_doc_hash.split(",") if h]
    if not hashes:
        return []
    txt_doc = TxtDoc(k_db)
    docs = []
    for h in hashes:
        row = txt_doc.search(h)
        if row is None:
            continue
        # row[0] is the page text, row[1] the JSON-encoded metadata.
        docs.append(Document(page_content=row[0], metadata=json.loads(row[1])))
    return docs_to_json(docs)


def docs_to_json(docs):
    """Convert LangChain documents into the dicts the front end displays."""
    return [
        {
            "page_content": d.page_content,
            # Missing metadata used to raise KeyError and fail the whole request.
            "from_file": d.metadata.get("filename", ""),
            # Page numbers are not tracked; always reported as 0.
            "page_number": 0,
        }
        for d in docs
    ]


if __name__ == "__main__":
    uvicorn.run(app, host='0.0.0.0', port=8088)