Commit 9ad4e078 by 陈正乐

重新问答接口实现

parent cb377359
......@@ -13,10 +13,10 @@ from langchain.prompts import PromptTemplate
from src.controller.request import (
RegisterRequest,
LoginRequest,
ChatCreateRequest,
ChatQaRequest,
ChatDetailRequest,
ChatDeleteRequest
ChatDeleteRequest,
ChatReQA
)
from src.config.consts import (
CHAT_DB_USER,
......@@ -84,8 +84,7 @@ async def login(request: LoginRequest):
@app.post("/lae/chat/create", response_model=Response)
async def create(request: ChatCreateRequest, token: str = Header(None)):
print(request)
async def create(token: str = Header(None)):
print(token)
user_id = token.replace('*', '')
crud = CRUD(_db=c_db)
......@@ -99,6 +98,7 @@ async def create(request: ChatCreateRequest, token: str = Header(None)):
}
return ReturnData(200, '会话创建成功', dict(data)).get_data()
@app.post("/lae/chat/delete", response_model=Response)
async def delete(request: ChatDeleteRequest, token: str = Header(None)):
print(request)
......@@ -116,13 +116,14 @@ async def delete(request: ChatDeleteRequest, token: str = Header(None)):
}
return ReturnData(200, '会话删除成功', dict(data)).get_data()
@app.post("/lae/chat/qa", response_model=Response)
async def qa(request: ChatQaRequest, token: str = Header(None)):
print(request)
print(token)
user_id = token.replace('*', '')
chat_id = request.chat_id
question = request.query
question = request.question
crud = CRUD(_db=c_db)
if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
......@@ -144,7 +145,8 @@ async def qa(request: ChatQaRequest, token: str = Header(None)):
data = {
"answer": answer
}
return ReturnData(200, '会话创建成功', dict(data)).get_data()
return ReturnData(200, '模型问答成功', dict(data)).get_data()
@app.get("/lae/chat/detail", response_model=Response)
async def detail(request: ChatDetailRequest, token: str = Header(None)):
......@@ -175,5 +177,35 @@ async def clist(token: str = Header(None)):
return ReturnData(200, "会话列表获取成功", dict(data)).get_data()
@app.post("/lae/chat/reqa", response_model=Response)
async def reqa(request: ChatReQA, token: str = Header(None)):
print(request)
print(token)
chat_id = request.chat_id
user_id = token.replace('*', '')
crud = CRUD(_db=c_db)
if not crud.chat_exist_chatid_userid(_chat_id=chat_id, _user_id=user_id):
return ReturnData(40000, "该会话不存在(用户无法访问非本人创建的对话)", {}).get_data()
question = crud.get_last_question(_chat_id=chat_id)
vecstore_faiss = VectorStore_FAISS(
embedding_model_name=EMBEEDING_MODEL_PATH,
store_path=FAISS_STORE_PATH,
index_name=INDEX_NAME,
info={"port": VEC_DB_PORT, "host": VEC_DB_HOST, "dbname": VEC_DB_DBNAME, "username": VEC_DB_USER,
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
my_chat = QA(PromptTemplate(input_variables=["context", "question"], template=prompt1), base_llm,
{"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id=chat_id,
_faiss_db=vecstore_faiss)
answer = my_chat.chat(question)
my_chat.update_history()
data = {
"answer": answer
}
return ReturnData(200, '模型重新问答成功', dict(data)).get_data()
if __name__ == "__main__":
uvicorn.run(app, host='localhost', port=8889)
......@@ -11,13 +11,9 @@ class RegisterRequest(BaseModel):
password: str
class ChatCreateRequest(BaseModel):
id: str
class ChatQaRequest(BaseModel):
chat_id: str = None
query: str
question: str
class ChatDetailRequest(BaseModel):
......@@ -26,3 +22,8 @@ class ChatDetailRequest(BaseModel):
class ChatDeleteRequest(BaseModel):
chat_id: str
class ChatReQA(BaseModel):
chat_id: str
query: str
......@@ -2,8 +2,7 @@ import os
import requests
from typing import Dict, Optional, List, Any, Mapping, Iterator
from pydantic import root_validator
from langchain.llms.base import BaseLLM, LLM
from langchain.llms.base import LLM
from langchain.cache import InMemoryCache
from langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks, AsyncCallbackManagerForLLMRun
import qianfan
......
......@@ -132,7 +132,6 @@ class CRUD:
self.db.execute_args(query, (_user_id,))
return self.db.fetchall()
def get_chatinfo_from_chatid(self, _chat_id):
query = f'SELECT info FROM chat WHERE chat_id = (%s)'
self.db.execute_args(query, (_chat_id,))
......@@ -141,3 +140,8 @@ class CRUD:
def delete_chat(self, _chat_id):
query = f'UPDATE chat SET deleted = 1 WHERE chat_id = (%s)'
self.db.execute_args(query, (_chat_id,))
def get_last_question(self, _chat_id):
query = f'SELECT question FROM turn_qa WHERE chat_id = (%s) AND turn_number = 1'
self.db.execute_args(query, (_chat_id,))
return self.db.fetchone()[0]
......@@ -60,7 +60,7 @@ class QA:
return GetSimilarity(_question=_aquestion, _faiss_db=self.faiss_db).get_similarity_doc()
# 一次性直接给出所有的答案
async def chat(self, _question):
def chat(self, _question):
self.cur_oquestion = _question
self.cur_similarity = self.get_similarity(_aquestion=self.cur_oquestion)
self.cur_question = self.prompt.format(**{k: v for k, v in zip(self.prompt_kwargs, (self.cur_similarity, self.cur_oquestion))})
......
......@@ -48,9 +48,9 @@ def main():
"password": VEC_DB_PASSWORD},
show_number=SIMILARITY_SHOW_NUMBER,
reset=False)
# base_llm = ChatERNIESerLLM(
# chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')
base_llm = ChatERNIESerLLM(
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
# base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')
my_chat = QA(PROMPT1, base_llm, {"temperature": 0.9}, ['context', 'question'], _db=c_db, _chat_id='2',
_faiss_db=vecstore_faiss)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment