Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
cc83b288
Commit
cc83b288
authored
Jul 15, 2024
by
文靖昊
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
多个agent tool结合
parent
47159811
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
10 deletions
+32
-10
rag_agent.py
src/agent/rag_agent.py
+18
-10
get_similarity.py
src/server/get_similarity.py
+14
-0
No files found.
src/agent/rag_agent.py
View file @
cc83b288
...
...
@@ -15,6 +15,7 @@ from src.pgdb.knowledge.k_db import PostgresDB
from
src.pgdb.knowledge.txt_doc_table
import
TxtDoc
from
langchain_core.documents
import
Document
import
json
from
src.agent.tool_divisions
import
AdministrativeDivision
from
src.config.consts
import
(
RERANK_MODEL_PATH
,
CHAT_DB_USER
,
...
...
@@ -1555,12 +1556,12 @@ def jieba_split(text: str) -> str:
class
IssuanceArgs
(
BaseModel
):
question
:
str
=
Field
(
description
=
"对话问题"
)
history
:
list
=
Field
(
description
=
"历史对话记录"
)
history
:
str
=
Field
(
description
=
"历史对话记录"
)
class
RAGQuery
(
BaseTool
):
name
=
"rag_query"
description
=
"""
Query the geological information of corresponding provinces, cities, and counties. Users can query geological information related to specific provinces, cities, and counties
"""
description
=
"""
你是一个区(县)级的水文气象地质知识库,当咨询区(县)的水文气象地质等相关信息的时候,你可以提供数据和资料。注意,这个查询只能获取单个区(县)的水文气象地质等相关信息,当需要查询省市的详细信息时,需要获取改省市下的具体行政规划,并获取具体的区(县)的水文气象地质等相关信息
"""
args_schema
:
Type
[
BaseModel
]
=
IssuanceArgs
rerank
:
Any
# 替换 Any 为适当的类型
rerank_model
:
Any
# 替换 Any 为适当的类型
...
...
@@ -1581,7 +1582,7 @@ class RAGQuery(BaseTool):
def
get_similarity_with_ext_origin
(
self
,
_ext
):
return
GetSimilarityWithExt
(
_question
=
_ext
,
_faiss_db
=
self
.
faiss_db
)
def
_run
(
self
,
question
:
str
,
history
:
list
)
->
str
:
def
_run
(
self
,
question
:
str
,
history
:
str
)
->
str
:
split_str
=
jieba_split
(
question
)
split_list
=
[]
for
l
in
split_str
:
...
...
@@ -1593,18 +1594,15 @@ class RAGQuery(BaseTool):
d
=
Document
(
page_content
=
a
[
0
],
metadata
=
json
.
loads
(
a
[
1
]))
split_docs
.
append
(
d
)
print
(
split_docs
)
result
=
self
.
rerank
.
extend_query
(
question
,
history
)
result
=
self
.
rerank
.
extend_query
_with_str
(
question
,
history
)
matches
=
re
.
findall
(
r'"([^"]+)"'
,
result
.
content
)
if
len
(
matches
)
>
3
:
matches
=
matches
[:
3
]
print
(
matches
)
prompt
=
""
for
h
in
history
:
prompt
+=
"问:{}
\n
答:{}
\n\n
"
.
format
(
h
[
0
],
h
[
1
])
similarity
=
self
.
get_similarity_with_ext_origin
(
matches
)
# cur_similarity = similarity.get_rerank(self.rerank_model)
cur_similarity
=
similarity
.
get_rerank_with_doc
(
self
.
rerank_model
,
split_docs
)
cur_question
=
self
.
prompt
.
format
(
history
=
prompt
,
context
=
cur_similarity
,
question
=
question
)
cur_question
=
self
.
prompt
.
format
(
history
=
history
,
context
=
cur_similarity
,
question
=
question
)
return
cur_question
...
...
@@ -1631,7 +1629,7 @@ k_db.connect()
tools
=
[
RAGQuery
(
vecstore_faiss
,
ext
,
PromptTemplate
(
input_variables
=
[
"history"
,
"context"
,
"question"
],
template
=
prompt_enhancement_history_template
),
_db
=
TxtDoc
(
k_db
))]
tools
=
[
AdministrativeDivision
(),
RAGQuery
(
vecstore_faiss
,
ext
,
PromptTemplate
(
input_variables
=
[
"history"
,
"context"
,
"question"
],
template
=
prompt_enhancement_history_template
),
_db
=
TxtDoc
(
k_db
))]
input_variables
=
[
'agent_scratchpad'
,
'input'
,
'tool_names'
,
'tools'
]
input_types
=
{
'chat_history'
:
List
[
Union
[
langchain_core
.
messages
.
ai
.
AIMessage
,
langchain_core
.
messages
.
human
.
HumanMessage
,
langchain_core
.
messages
.
chat
.
ChatMessage
,
langchain_core
.
messages
.
system
.
SystemMessage
,
langchain_core
.
messages
.
function
.
FunctionMessage
,
langchain_core
.
messages
.
tool
.
ToolMessage
]]}
metadata
=
{
'lc_hub_owner'
:
'hwchase17'
,
'lc_hub_repo'
:
'structured-chat-agent'
,
'lc_hub_commit_hash'
:
'ea510f70a5872eb0f41a4e3b7bb004d5711dc127adee08329c664c6c8be5f13c'
}
...
...
@@ -1651,7 +1649,17 @@ prompt = ChatPromptTemplate(
agent
=
create_structured_chat_agent
(
llm
=
base_llm
,
tools
=
tools
,
prompt
=
prompt
)
agent_executor
=
AgentExecutor
(
agent
=
agent
,
tools
=
tools
,
verbose
=
True
,
handle_parsing_errors
=
True
)
res
=
agent_executor
.
invoke
({
"input"
:
"大通县"
})
history
=
[]
h1
=
[]
h1
.
append
(
"大通县年降雨量"
)
h1
.
append
(
"大通县年雨量平均为30ml"
)
history
.
append
(
h1
)
prompt
=
""
for
h
in
history
:
prompt
+=
"问:{}
\n
答:{}
\n\n
"
.
format
(
h
[
0
],
h
[
1
])
print
(
prompt
)
res
=
agent_executor
.
invoke
({
"input"
:
"以下历史对话记录: "
+
prompt
+
"以下是问题:"
+
"西宁市年平均降雨量"
})
print
(
"====== result: ======"
)
print
(
res
)
src/server/get_similarity.py
View file @
cc83b288
...
...
@@ -120,6 +120,20 @@ class QAExt:
history
+=
f
"Q: {msg[0]}
\n
A: {msg[1]}
\n
"
return
self
.
query_extend
.
invoke
(
input
=
{
"histories"
:
messages
,
"query"
:
question
})
def
extend_query_with_str
(
self
,
question
,
messages
):
"""
question: str
messages: list of tuple (str,str)
eg:
[
("Q1","A1"),
("Q2","A2"),
...
]
"""
return
self
.
query_extend
.
invoke
(
input
=
{
"histories"
:
messages
,
"query"
:
question
})
class
ChatExtend
:
def
__init__
(
self
,
llm
)
->
None
:
self
.
llm
=
llm
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment