Commit 9c21eae7 authored Jul 16, 2024 by 文靖昊
Add hybrid ranking; add an administrative division name parameter to the RAG tool
parent cc83b288

Showing 5 changed files with 64 additions and 43 deletions:
src/agent/rag_agent.py                +32  -22
src/config/consts.py                   +1   -1
src/pgdb/knowledge/txt_doc_table.py    +1   -1
src/server/get_similarity.py          +25  -14
src/server/rerank.py                   +5   -5

src/agent/rag_agent.py (view file @ 9c21eae7)

@@ -16,6 +16,8 @@ from src.pgdb.knowledge.txt_doc_table import TxtDoc
 from langchain_core.documents import Document
 import json
 from src.agent.tool_divisions import AdministrativeDivision
+from src.llm.ernie_with_sdk import ChatERNIESerLLM
+from qianfan import ChatCompletion
 from src.config.consts import (
     RERANK_MODEL_PATH,
     CHAT_DB_USER,

@@ -1557,6 +1559,7 @@ def jieba_split(text: str) -> str:
 class IssuanceArgs(BaseModel):
     question: str = Field(description="对话问题")
     history: str = Field(description="历史对话记录")
+    location: list = Field(description="行政区名称")
 
 class RAGQuery(BaseTool):

@@ -1582,22 +1585,24 @@ class RAGQuery(BaseTool):
     def get_similarity_with_ext_origin(self, _ext):
         return GetSimilarityWithExt(_question=_ext, _faiss_db=self.faiss_db)
 
-    def _run(self, question: str, history: str) -> str:
-        split_str = jieba_split(question)
-        split_list = []
-        for l in split_str:
-            split_list.append(l)
-        answer = self.db.find_like_doc(split_list)
-        print(answer)
+    def _run(self, question: str, history: str, location: list) -> str:
+        print(location)
+        # split_str = jieba_split(question)
+        # split_list = []
+        # for l in split_str:
+        #     split_list.append(l)
+        answer = self.db.find_like_doc(location)
         split_docs = []
         for a in answer:
             d = Document(page_content=a[0], metadata=json.loads(a[1]))
             split_docs.append(d)
-        print(split_docs)
+        print(len(split_docs))
+        if len(split_docs) > 100:
+            split_docs = split_docs[:100]
         result = self.rerank.extend_query_with_str(question, history)
-        matches = re.findall(r'"([^"]+)"', result.content)
-        if len(matches) > 3:
-            matches = matches[:3]
+        print(result)
+        matches = re.findall(r'"([^"]+)"', result)
         print(matches)
         similarity = self.get_similarity_with_ext_origin(matches)
         # cur_similarity = similarity.get_rerank(self.rerank_model)

@@ -1606,13 +1611,15 @@ class RAGQuery(BaseTool):
         return cur_question
 
-base_llm = ChatOpenAI(
-    openai_api_key='xxxxxxxxxxxxx',
-    openai_api_base='http://192.168.10.14:8000/v1',
-    model_name='Qwen2-7B',
-    verbose=True,
-    temperature=0
-)
+# base_llm = ChatOpenAI(
+#     openai_api_key='xxxxxxxxxxxxx',
+#     openai_api_base='http://192.168.10.14:8000/v1',
+#     model_name='Qwen2-7B',
+#     verbose=True,
+#     temperature=0
+# )
+base_llm = ChatERNIESerLLM(chat_completion=ChatCompletion(ak="pT7sV1smp4AeDl0LjyZuHBV9", sk="b3N0ibo1IKTLZlSs7weZc8jdR0oHjyMu"))
 
 vecstore_faiss = VectorStore_FAISS(
     embedding_model_name=EMBEEDING_MODEL_PATH,

@@ -1651,15 +1658,18 @@ agent = create_structured_chat_agent(llm=base_llm, tools=tools, prompt=prompt)
 agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)
 
 history = []
 h1 = []
-h1.append("大通县年降雨量")
-h1.append("大通县年雨量平均为30ml")
+h1.append("攸县年降雨量")
+h1.append("攸县年雨量平均为30ml")
+history.append(h1)
+h1 = []
+h1.append("长沙县年降雨量")
+h1.append("长沙县年雨量平均为50ml")
 history.append(h1)
 prompt = ""
 for h in history:
     prompt += "问:{}\n答:{}\n\n".format(h[0], h[1])
 print(prompt)
-res = agent_executor.invoke({"input": "以下历史对话记录: " + prompt + "以下是问题:" + "西宁市年平均降雨量"})
+res = agent_executor.invoke({"input": "以下历史对话记录: " + prompt + "以下是问题:" + "攸县、长沙县、大通县和化隆县谁的年平均降雨量大"})
 print("====== result: ======")
 print(res)
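
A note on how the new location argument is supplied: with create_structured_chat_agent, the tool's args_schema (IssuanceArgs above) tells the LLM which named arguments to emit, and AgentExecutor forwards them to _run. Below is a minimal, self-contained sketch of that wiring, assuming current langchain-core and pydantic imports (exact import paths vary by langchain version); the real RAGQuery also wires in the DB, FAISS store, and reranker, which are stubbed out here, and RAGQueryDemo is an illustrative name:

from typing import Type
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool

class IssuanceArgs(BaseModel):
    question: str = Field(description="对话问题")
    history: str = Field(description="历史对话记录")
    location: list = Field(description="行政区名称")  # new in this commit

class RAGQueryDemo(BaseTool):
    """Illustrative stand-in for RAGQuery with the retrieval stubbed out."""
    name: str = "rag_query_demo"
    description: str = "Answer a question using region-scoped retrieval."
    args_schema: Type[BaseModel] = IssuanceArgs

    def _run(self, question: str, history: str, location: list) -> str:
        # The agent must now emit all three fields; location scopes the lookup.
        return f"q={question!r} history={history!r} location={location!r}"

# Direct invocation with a structured tool input, as the executor would do:
print(RAGQueryDemo().run({"question": "年降雨量", "history": "", "location": ["大通县"]}))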

src/config/consts.py (view file @ 9c21eae7)

@@ -35,7 +35,7 @@ LLM_SERVER_URL = '192.168.10.102:8002'
 # =============================
 # FAISS相似性查找配置
 # =============================
-SIMILARITY_SHOW_NUMBER = 10
+SIMILARITY_SHOW_NUMBER = 30
 SIMILARITY_THRESHOLD = 0.8
 # =============================

src/pgdb/knowledge/txt_doc_table.py (view file @ 9c21eae7)

@@ -74,7 +74,7 @@ class TxtDoc:
         print(item)
         query = "select text,matadate FROM txt_doc WHERE text like '%" + i0 + "%' "
         for i in item:
-            query += " and text like '%" + i + "%' "
+            query += " or text like '%" + i + "%' "
         print(query)
         self.db.execute(query)
         answer = self.db.fetchall()
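
The change above relaxes find_like_doc from requiring every keyword (and) to matching any of them (or), which fits the new usage of passing a list of region names. Note that both variants splice raw strings into the SQL. If the underlying db wrapper is DB-API compatible (an assumption, it is not shown in this diff), the same OR'd filter can be built with placeholders; a hedged sketch, with build_like_query and first_term as illustrative names ('matadate' is spelled as in the actual table):

def build_like_query(first_term: str, terms: list):
    # Same shape as the query in TxtDoc.find_like_doc, but parameterized:
    # one LIKE clause per term, OR'd together, values passed separately.
    clauses = ["text like %s"]
    params = ["%" + first_term + "%"]
    for t in terms:
        clauses.append("text like %s")
        params.append("%" + t + "%")
    sql = "select text, matadate FROM txt_doc WHERE " + " or ".join(clauses)
    return sql, params

sql, params = build_like_query("降雨量", ["大通县", "长沙县"])
# cursor.execute(sql, params)  # with a psycopg2-style cursor
print(sql)
print(params)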

src/server/get_similarity.py (view file @ 9c21eae7)

 from src.pgdb.knowledge.similarity import VectorStore_FAISS
 from src.config.prompts import PROMPT_QUERY_EXTEND, PROMPT_QA_EXTEND_QUESTION
-from .rerank import BgeRerank
+from src.server.rerank import BgeRerank, reciprocal_rank_fusion
 from langchain_core.prompts import PromptTemplate

@@ -54,20 +54,31 @@ class GetSimilarityWithExt:
     def get_rerank_with_doc(self, reranker: BgeRerank, split_doc: list, top_k=5):
         question = '\n'.join(self.question)
         print(question)
-        split_doc.extend(self.similarity_docs)
-        content_set = set()
-        unique_documents = []
-        for doc in split_doc:
-            content = hash(doc.page_content)
-            if content not in content_set:
-                unique_documents.append(doc)
-                content_set.add(content)
-        rerank_docs = reranker.compress_documents(unique_documents, question)
+        rerank_docs1 = reranker.compress_documents(split_doc, question)
+        rerank_docs2 = reranker.compress_documents(self.similarity_docs, question)
+        rerank_docs1_hash = []
+        rerank_docs2_hash = []
+        m = {}
+        for doc in rerank_docs1:
+            m[hash(doc.page_content)] = doc
+            rerank_docs1_hash.append(hash(doc.page_content))
+        for doc in rerank_docs2:
+            m[hash(doc.page_content)] = doc
+            rerank_docs2_hash.append(hash(doc.page_content))
+        result = []
+        result.append((60, rerank_docs1_hash))
+        result.append((55, rerank_docs2_hash))
+        print(len(rerank_docs1_hash))
+        print(len(rerank_docs2_hash))
+        rrf_doc = reciprocal_rank_fusion(result)
+        print(rrf_doc)
         d_list = []
-        for d in rerank_docs[:top_k]:
-            d_list.append(d)
-        self.rerank_docs = rerank_docs[:top_k]
-        return self.faiss_db.join_document(d_list)
+        for key in rrf_doc:
+            d_list.append(m[key])
+        self.rerank_docs = d_list[:top_k]
+        return self.faiss_db.join_document(d_list[:top_k])
 
     def get_similarity_doc(self):
         return self.similarity_doc_txt
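
get_rerank_with_doc now reranks the keyword-retrieved documents and the vector-similarity documents as two separate lists, then merges the two ranked hash lists with reciprocal_rank_fusion: the hybrid ranking named in the commit message. The real implementation lives in src/server/rerank.py (only its tail is visible below, with signature reciprocal_rank_fusion(results: list[set])). As a rough sketch of the idea, assuming each tuple's first element is that list's RRF constant k and each hash list is ordered best-first; rrf_sketch is an illustrative name, not the project's function:

def rrf_sketch(results: list) -> dict:
    # results: [(k, [doc_hash, ...]), ...], each hash list ranked best-first.
    # Classic reciprocal rank fusion: score(d) = sum over lists of 1/(k + rank).
    scores = {}
    for k, ranked_hashes in results:
        for rank, h in enumerate(ranked_hashes):
            scores[h] = scores.get(h, 0.0) + 1.0 / (k + rank)
    # Highest fused score first; iterating the dict yields hashes, matching
    # the "for key in rrf_doc: d_list.append(m[key])" consumption above.
    return dict(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))

fused = rrf_sketch([(60, [101, 102, 103]), (55, [102, 104])])
print(list(fused))  # [102, 104, 101, 103]; 102 appears in both lists, so it wins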

src/server/rerank.py (view file @ 9c21eae7)

@@ -115,11 +115,11 @@ def reciprocal_rank_fusion(results: list[set]):
     ]
     # for TEST (print reranked documentsand scores)
-    print("Reranked documents: ", len(reranked_results))
-    for doc in reranked_results:
-        print('---')
-        print('Docs: ', ' '.join(doc[0].page_content[:100].split()))
-        print('RRF score: ', doc[1])
+    # print("Reranked documents: ", len(reranked_results))
+    # for doc in reranked_results:
+    #     print('---')
+    #     print('Docs: ', ' '.join(doc[0].page_content[:100].split()))
+    #     print('RRF score: ', doc[1])
     # return only documents
     return [x[0] for x in reranked_results[:MAX_DOCS_FOR_CONTEXT]]