Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
9ad4e078
Commit
9ad4e078
authored
May 11, 2024
by
陈正乐
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
重新问答接口实现
parent
cb377359
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
55 additions
and
20 deletions
+55
-20
api.py
src/controller/api.py
+38
-6
request.py
src/controller/request.py
+6
-5
ernie_with_sdk.py
src/llm/ernie_with_sdk.py
+1
-2
crud.py
src/pgdb/chat/crud.py
+6
-3
qa.py
src/server/qa.py
+1
-1
gradio_test.py
test/gradio_test.py
+3
-3
No files found.
src/controller/api.py
View file @
9ad4e078
...
...
@@ -13,10 +13,10 @@ from langchain.prompts import PromptTemplate
from
src.controller.request
import
(
RegisterRequest
,
LoginRequest
,
ChatCreateRequest
,
ChatQaRequest
,
ChatDetailRequest
,
ChatDeleteRequest
ChatDeleteRequest
,
ChatReQA
)
from
src.config.consts
import
(
CHAT_DB_USER
,
...
...
@@ -84,8 +84,7 @@ async def login(request: LoginRequest):
@app.post
(
"/lae/chat/create"
,
response_model
=
Response
)
async
def
create
(
request
:
ChatCreateRequest
,
token
:
str
=
Header
(
None
)):
print
(
request
)
async
def
create
(
token
:
str
=
Header
(
None
)):
print
(
token
)
user_id
=
token
.
replace
(
'*'
,
''
)
crud
=
CRUD
(
_db
=
c_db
)
...
...
@@ -99,6 +98,7 @@ async def create(request: ChatCreateRequest, token: str = Header(None)):
}
return
ReturnData
(
200
,
'会话创建成功'
,
dict
(
data
))
.
get_data
()
@app.post
(
"/lae/chat/delete"
,
response_model
=
Response
)
async
def
delete
(
request
:
ChatDeleteRequest
,
token
:
str
=
Header
(
None
)):
print
(
request
)
...
...
@@ -116,13 +116,14 @@ async def delete(request: ChatDeleteRequest, token: str = Header(None)):
}
return
ReturnData
(
200
,
'会话删除成功'
,
dict
(
data
))
.
get_data
()
@app.post
(
"/lae/chat/qa"
,
response_model
=
Response
)
async
def
qa
(
request
:
ChatQaRequest
,
token
:
str
=
Header
(
None
)):
print
(
request
)
print
(
token
)
user_id
=
token
.
replace
(
'*'
,
''
)
chat_id
=
request
.
chat_id
question
=
request
.
que
ry
question
=
request
.
que
stion
crud
=
CRUD
(
_db
=
c_db
)
if
not
crud
.
chat_exist_chatid_userid
(
_chat_id
=
chat_id
,
_user_id
=
user_id
):
return
ReturnData
(
40000
,
"该会话不存在(用户无法访问非本人创建的对话)"
,
{})
.
get_data
()
...
...
@@ -144,7 +145,8 @@ async def qa(request: ChatQaRequest, token: str = Header(None)):
data
=
{
"answer"
:
answer
}
return
ReturnData
(
200
,
'会话创建成功'
,
dict
(
data
))
.
get_data
()
return
ReturnData
(
200
,
'模型问答成功'
,
dict
(
data
))
.
get_data
()
@app.get
(
"/lae/chat/detail"
,
response_model
=
Response
)
async
def
detail
(
request
:
ChatDetailRequest
,
token
:
str
=
Header
(
None
)):
...
...
@@ -175,5 +177,35 @@ async def clist(token: str = Header(None)):
return
ReturnData
(
200
,
"会话列表获取成功"
,
dict
(
data
))
.
get_data
()
@app.post
(
"/lae/chat/reqa"
,
response_model
=
Response
)
async
def
reqa
(
request
:
ChatReQA
,
token
:
str
=
Header
(
None
)):
print
(
request
)
print
(
token
)
chat_id
=
request
.
chat_id
user_id
=
token
.
replace
(
'*'
,
''
)
crud
=
CRUD
(
_db
=
c_db
)
if
not
crud
.
chat_exist_chatid_userid
(
_chat_id
=
chat_id
,
_user_id
=
user_id
):
return
ReturnData
(
40000
,
"该会话不存在(用户无法访问非本人创建的对话)"
,
{})
.
get_data
()
question
=
crud
.
get_last_question
(
_chat_id
=
chat_id
)
vecstore_faiss
=
VectorStore_FAISS
(
embedding_model_name
=
EMBEEDING_MODEL_PATH
,
store_path
=
FAISS_STORE_PATH
,
index_name
=
INDEX_NAME
,
info
=
{
"port"
:
VEC_DB_PORT
,
"host"
:
VEC_DB_HOST
,
"dbname"
:
VEC_DB_DBNAME
,
"username"
:
VEC_DB_USER
,
"password"
:
VEC_DB_PASSWORD
},
show_number
=
SIMILARITY_SHOW_NUMBER
,
reset
=
False
)
base_llm
=
ChatERNIESerLLM
(
chat_completion
=
ChatCompletion
(
ak
=
"pT7sV1smp4AeDl0LjyZuHBV9"
,
sk
=
"b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"
))
my_chat
=
QA
(
PromptTemplate
(
input_variables
=
[
"context"
,
"question"
],
template
=
prompt1
),
base_llm
,
{
"temperature"
:
0.9
},
[
'context'
,
'question'
],
_db
=
c_db
,
_chat_id
=
chat_id
,
_faiss_db
=
vecstore_faiss
)
answer
=
my_chat
.
chat
(
question
)
my_chat
.
update_history
()
data
=
{
"answer"
:
answer
}
return
ReturnData
(
200
,
'模型重新问答成功'
,
dict
(
data
))
.
get_data
()
if
__name__
==
"__main__"
:
uvicorn
.
run
(
app
,
host
=
'localhost'
,
port
=
8889
)
src/controller/request.py
View file @
9ad4e078
...
...
@@ -11,13 +11,9 @@ class RegisterRequest(BaseModel):
password
:
str
class
ChatCreateRequest
(
BaseModel
):
id
:
str
class
ChatQaRequest
(
BaseModel
):
chat_id
:
str
=
None
que
ry
:
str
que
stion
:
str
class
ChatDetailRequest
(
BaseModel
):
...
...
@@ -26,3 +22,8 @@ class ChatDetailRequest(BaseModel):
class
ChatDeleteRequest
(
BaseModel
):
chat_id
:
str
class
ChatReQA
(
BaseModel
):
chat_id
:
str
query
:
str
src/llm/ernie_with_sdk.py
View file @
9ad4e078
...
...
@@ -2,8 +2,7 @@ import os
import
requests
from
typing
import
Dict
,
Optional
,
List
,
Any
,
Mapping
,
Iterator
from
pydantic
import
root_validator
from
langchain.llms.base
import
BaseLLM
,
LLM
from
langchain.llms.base
import
LLM
from
langchain.cache
import
InMemoryCache
from
langchain.callbacks.manager
import
CallbackManagerForLLMRun
,
Callbacks
,
AsyncCallbackManagerForLLMRun
import
qianfan
...
...
src/pgdb/chat/crud.py
View file @
9ad4e078
...
...
@@ -132,7 +132,6 @@ class CRUD:
self
.
db
.
execute_args
(
query
,
(
_user_id
,))
return
self
.
db
.
fetchall
()
def
get_chatinfo_from_chatid
(
self
,
_chat_id
):
query
=
f
'SELECT info FROM chat WHERE chat_id = (
%
s)'
self
.
db
.
execute_args
(
query
,
(
_chat_id
,))
...
...
@@ -140,4 +139,9 @@ class CRUD:
def
delete_chat
(
self
,
_chat_id
):
query
=
f
'UPDATE chat SET deleted = 1 WHERE chat_id = (
%
s)'
self
.
db
.
execute_args
(
query
,
(
_chat_id
,))
\ No newline at end of file
self
.
db
.
execute_args
(
query
,
(
_chat_id
,))
def
get_last_question
(
self
,
_chat_id
):
query
=
f
'SELECT question FROM turn_qa WHERE chat_id = (
%
s) AND turn_number = 1'
self
.
db
.
execute_args
(
query
,
(
_chat_id
,))
return
self
.
db
.
fetchone
()[
0
]
src/server/qa.py
View file @
9ad4e078
...
...
@@ -60,7 +60,7 @@ class QA:
return
GetSimilarity
(
_question
=
_aquestion
,
_faiss_db
=
self
.
faiss_db
)
.
get_similarity_doc
()
# 一次性直接给出所有的答案
async
def
chat
(
self
,
_question
):
def
chat
(
self
,
_question
):
self
.
cur_oquestion
=
_question
self
.
cur_similarity
=
self
.
get_similarity
(
_aquestion
=
self
.
cur_oquestion
)
self
.
cur_question
=
self
.
prompt
.
format
(
**
{
k
:
v
for
k
,
v
in
zip
(
self
.
prompt_kwargs
,
(
self
.
cur_similarity
,
self
.
cur_oquestion
))})
...
...
test/gradio_test.py
View file @
9ad4e078
...
...
@@ -48,9 +48,9 @@ def main():
"password"
:
VEC_DB_PASSWORD
},
show_number
=
SIMILARITY_SHOW_NUMBER
,
reset
=
False
)
#
base_llm = ChatERNIESerLLM(
#
chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
base_llm
=
ChatGLMSerLLM
(
url
=
'http://192.168.22.106:8088'
)
base_llm
=
ChatERNIESerLLM
(
chat_completion
=
ChatCompletion
(
ak
=
"pT7sV1smp4AeDl0LjyZuHBV9"
,
sk
=
"b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"
))
#
base_llm = ChatGLMSerLLM(url='http://192.168.22.106:8088')
my_chat
=
QA
(
PROMPT1
,
base_llm
,
{
"temperature"
:
0.9
},
[
'context'
,
'question'
],
_db
=
c_db
,
_chat_id
=
'2'
,
_faiss_db
=
vecstore_faiss
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment