Commit 9ad4e078 by 陈正乐

Implement the chat re-ask (re-QA) endpoint

parent cb377359
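
This commit adds a POST /lae/chat/reqa endpoint that re-runs question answering for an existing chat: the handler looks up the question stored for that chat (via the new CRUD.get_last_question helper) and returns a fresh answer. For reference, a minimal client sketch, not part of the commit (assumptions: the service is running through the uvicorn.run call in the diff below on localhost:8889, and the token value is a placeholder in whatever format the login endpoint issues, since the handler only strips '*' characters from it to recover the user id):

import requests

# Hypothetical values: substitute a real token and a chat_id owned by that user.
resp = requests.post(
    "http://localhost:8889/lae/chat/reqa",
    headers={"token": "*1*"},                       # FastAPI maps the 'token' parameter declared with Header(None) to this header
    json={"chat_id": "some-chat-id", "query": ""},  # ChatReQA body; the handler fetches the stored question itself
)
print(resp.status_code, resp.json())                # on success the answer is returned in the response data payload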
...
@@ -13,10 +13,10 @@ from langchain.prompts import PromptTemplate
 from src.controller.request import (
     RegisterRequest,
     LoginRequest,
-    ChatCreateRequest,
     ChatQaRequest,
     ChatDetailRequest,
-    ChatDeleteRequest
+    ChatDeleteRequest,
+    ChatReQA
 )
 from src.config.consts import (
     CHAT_DB_USER,
...
@@ -84,8 +84,7 @@ async def login(request: LoginRequest):
 @app.post("/lae/chat/create", response_model=Response)
-async def create(request: ChatCreateRequest, token: str = Header(None)):
-    print(request)
+async def create(token: str = Header(None)):
     print(token)
     user_id = token.replace('*', '')
     crud = CRUD(_db=c_db)
...
@@ -99,6 +98,7 @@ async def create(request: ChatCreateRequest, token: str = Header(None)):
     }
     return ReturnData(200, '会话创建成功', dict(data)).get_data()

 @app.post("/lae/chat/delete", response_model=Response)
 async def delete(request: ChatDeleteRequest, token: str = Header(None)):
     print(request)
...
@@ -116,13 +116,14 @@ async def delete(request: ChatDeleteRequest, token: str = Header(None)):
     }
     return ReturnData(200, '会话删除成功', dict(data)).get_data()

 @app.post("/lae/chat/qa", response_model=Response)
 async def qa(request: ChatQaRequest, token: str = Header(None)):
     print(request)
     print(token)
     user_id = token.replace('*', '')
     chat_id = request.chat_id
-    question = request.query
+    question = request.question
     crud = CRUD(_db=c_db)
     if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
         return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
...
@@ -144,7 +145,8 @@ async def qa(request: ChatQaRequest, token: str = Header(None)):
     data = {
         "answer": answer
     }
-    return ReturnData(200, '会话创建成功', dict(data)).get_data()
+    return ReturnData(200, '模型问答成功', dict(data)).get_data()

 @app.get("/lae/chat/detail", response_model=Response)
 async def detail(request: ChatDetailRequest, token: str = Header(None)):
...
@@ -175,5 +177,35 @@ async def clist(token: str = Header(None)):
     return ReturnData(200, "会话列表获取成功", dict(data)).get_data()

+@app.post("/lae/chat/reqa", response_model=Response)
+async def reqa(request: ChatReQA, token: str = Header(None)):
+    print(request)
+    print(token)
+    chat_id = request.chat_id
+    user_id = token.replace('*', '')
+    crud = CRUD(_db=c_db)
+    if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
+        return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
+    question = crud.get_last_question(_chat_id=chat_id)
+    vecstore_faiss = VectorStore_FAISS(
+        embedding_model_name=EMBEEDING_MODEL_PATH,
+        store_path=FAISS_STORE_PATH,
+        index_name=INDEX_NAME,
+        info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
+              "password": VEC_DB_PASSWORD},
+        show_number=SIMILARITY_SHOW_NUMBER,
+        reset=False)
+    base_llm = ChatERNIESerLLM(
+        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
+    my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
+                 {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id=chat_id,
+                 _faiss_db=vecstore_faiss)
+    answer = my_chat.chat(question)
+    my_chat.update_history()
+    data = {
+        "answer": answer
+    }
+    return ReturnData(200, '模型重新问答成功', dict(data)).get_data()

 if __name__ == "__main__":
     uvicorn.run(app, host='localhost', port=8889)
...
@@ -11,13 +11,9 @@ class RegisterRequest(BaseModel):
     password: str

-class ChatCreateRequest(BaseModel):
-    id: str

 class ChatQaRequest(BaseModel):
     chat_id: str = None
-    query: str
+    question: str

 class ChatDetailRequest(BaseModel):
...
@@ -26,3 +22,8 @@ class ChatDetailRequest(BaseModel):
 class ChatDeleteRequest(BaseModel):
     chat_id: str

+class ChatReQA(BaseModel):
+    chat_id: str
+    query: str
...
@@ -2,8 +2,7 @@ import os
 import requests
 from typing import Dict, Optional, List, Any, Mapping, Iterator
 from pydantic import root_validator
-from langchain.llms.base import LLM
+from langchain.llms.base import BaseLLM, LLM
 from langchain.cache import InMemoryCache
 from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
 import qianfan
...
...
@@ -132,7 +132,6 @@ class CRUD:
         self.db.execute_args(query, (_user_id,))
         return self.db.fetchall()

     def get_chatinfo_from_chatid(self, _chat_id):
         query = f'SELECT info FROM chat WHERE chat_id = (%s)'
         self.db.execute_args(query, (_chat_id,))
...
@@ -141,3 +140,8 @@ class CRUD:
     def delete_chat(self, _chat_id):
         query = f'UPDATE chat SET deleted = 1 WHERE chat_id = (%s)'
         self.db.execute_args(query, (_chat_id,))

+    def get_last_question(self, _chat_id):
+        query = f'SELECT question FROM turn_qa WHERE chat_id = (%s) AND turn_number = 1'
+        self.db.execute_args(query, (_chat_id,))
+        return self.db.fetchone()[0]
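
A note on the new helper, not part of the commit: get_last_question returns self.db.fetchone()[0], which will raise a TypeError if the chat has no recorded turn. A minimal defensive variant, assuming the db wrapper's fetchone mirrors DB-API and returns None when no row matches:

    def get_last_question(self, _chat_id):
        query = 'SELECT question FROM turn_qa WHERE chat_id = (%s) AND turn_number = 1'
        self.db.execute_args(query, (_chat_id,))
        row = self.db.fetchone()
        return row[0] if row else None  # None instead of an exception when the chat has no stored question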
...
@@ -60,7 +60,7 @@ class QA:
         return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()

     # 一次性直接给出所有的答案
-    async def chat(self, _question):
+    def chat(self, _question):
         self.cur_oquestion = _question
         self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
         self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion))})
...
...
@@ -48,9 +48,9 @@ def main():
              "password": VEC_DB_PASSWORD},
         show_number=SIMILARITY_SHOW_NUMBER,
         reset=False)
-    # base_llm = ChatERNIESerLLM(
-    #     chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
-    base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')
+    base_llm = ChatERNIESerLLM(
+        chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
+    # base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')
     my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2',
                  _faiss_db=vecstore_faiss)
...