Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
df832ca9
Commit
df832ca9
authored
Nov 08, 2024
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
工具调试
parent
23a22fdc
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
80 additions
and
57 deletions
+80
-57
http_tools.py
src/agent/http_tools.py
+2
-1
tool_rate.py
src/agent/tool_rate.py
+33
-42
api.py
src/controller/api.py
+34
-7
agent_rate.py
src/server/agent_rate.py
+11
-7
No files found.
src/agent/http_tools.py
View file @
df832ca9
...
@@ -7,7 +7,8 @@ from urllib.parse import urljoin
...
@@ -7,7 +7,8 @@ from urllib.parse import urljoin
# 泛型类型定义
# 泛型类型定义
T
=
TypeVar
(
'T'
)
T
=
TypeVar
(
'T'
)
const_base_url
=
"http://172.30.0.37:30007"
# const_base_url = "http://172.30.0.37:30007"
const_base_url
=
"http://localhost:5001"
const_url_point
=
"/cigem/getMonitorPointAll"
const_url_point
=
"/cigem/getMonitorPointAll"
const_url_rate
=
"/cigem/getAvgOnlineRate"
const_url_rate
=
"/cigem/getAvgOnlineRate"
const_url_rate_ranking
=
"/cigem/getOnlineRateRank"
const_url_rate_ranking
=
"/cigem/getOnlineRateRank"
...
...
src/agent/tool_rate.py
View file @
df832ca9
...
@@ -62,12 +62,12 @@ class BaseRateTool(BaseTool):
...
@@ -62,12 +62,12 @@ class BaseRateTool(BaseTool):
class
RegionRateArgs
(
BaseModel
):
class
RegionRateArgs
(
BaseModel
):
"""地区在线率查询参数"""
"""地区在线率查询参数"""
region_name
:
str
=
Field
(
...
,
description
=
"地区名称"
)
region_name
:
str
=
Field
(
...
,
description
=
"地区名称
,如果要查询全国数据,请输入空字符串
"
)
start_time
:
str
=
Field
(
...
,
description
=
"开始时间 (YYYY-MM-DD)"
)
start_time
:
str
=
Field
(
...
,
description
=
"开始时间 (YYYY-MM-DD)"
)
end_time
:
str
=
Field
(
...
,
description
=
"结束时间 (YYYY-MM-DD)"
)
end_time
:
str
=
Field
(
...
,
description
=
"结束时间 (YYYY-MM-DD)"
)
class
RegionRateTool
(
BaseRateTool
):
class
RegionRateTool
(
BaseRateTool
):
"""查询特定地区设备在线率的工具"""
"""查询
全国或者
特定地区设备在线率的工具"""
name
=
"region_online_rate"
name
=
"region_online_rate"
description
=
"查询指定地区在指定时间段内的设备在线率"
description
=
"查询指定地区在指定时间段内的设备在线率"
args_schema
:
Type
[
BaseModel
]
=
RegionRateArgs
args_schema
:
Type
[
BaseModel
]
=
RegionRateArgs
...
@@ -77,27 +77,33 @@ class RegionRateTool(BaseRateTool):
...
@@ -77,27 +77,33 @@ class RegionRateTool(BaseRateTool):
super
()
.
__init__
(
**
data
)
super
()
.
__init__
(
**
data
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
def
_run
(
self
,
region_name
:
str
,
start_time
:
str
,
end_time
:
str
)
->
Dict
[
str
,
Any
]:
def
_run
(
self
,
start_time
:
str
,
end_time
:
str
,
region_name
:
str
=
""
)
->
Dict
[
str
,
Any
]:
return
self
.
get_region_online_rate
(
region_name
,
start_time
,
end_ti
me
)
return
self
.
get_region_online_rate
(
start_time
,
end_time
,
region_na
me
)
def
get_region_online_rate
(
self
,
region_name
:
str
,
start_time
:
str
,
end_time
:
str
)
->
Dict
[
str
,
Any
]:
def
get_region_online_rate
(
self
,
start_time
:
str
,
end_time
:
str
,
region_name
:
str
=
""
)
->
Dict
[
str
,
Any
]:
# 查询数据
# 查询数据
code
=
code_tool
.
find_code
(
region_name
)
print
(
f
"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}"
)
if
not
code
:
code
=
""
if
region_name
!=
""
:
codes
=
code_tool
.
find_code
(
region_name
)
if
codes
is
None
or
len
(
codes
)
==
0
:
return
{
return
{
'code'
:
400
,
'code'
:
400
,
'message'
:
f
'未找到匹配的区域代码: {region_name}'
'message'
:
f
'未找到匹配的区域代码: {region_name}'
}
}
df
=
self
.
client
.
query_rates_sync
({
code
=
codes
[
0
][
1
]
'startDate'
:
start_time
,
df
=
self
.
client
.
query_rates_sync
(
code
,
start_time
,
end_time
)
'endDate'
:
end_time
,
print
(
f
"地区在线率接口调用结果: {df}"
)
'areaCode'
:
code
[
0
][
1
]
})
# 准备数据
# 准备数据
if
df
.
type
!=
1
or
len
(
df
.
resultdata
)
==
0
:
return
{
'code'
:
400
,
'message'
:
f
'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
}
print
(
f
"地区在线率查询结果: {df.resultdata}"
)
data
=
{
data
=
{
'region'
:
region_name
,
'region'
:
region_name
,
'region_code'
:
code
[
0
][
1
]
,
'region_code'
:
code
,
'rate_data'
:
df
.
resultdata
,
'rate_data'
:
df
.
resultdata
,
'date_range'
:
{
'date_range'
:
{
'start'
:
start_time
,
'start'
:
start_time
,
...
@@ -136,6 +142,12 @@ class RankingRateTool(BaseRateTool):
...
@@ -136,6 +142,12 @@ class RankingRateTool(BaseRateTool):
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
1
)
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
1
)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
# 准备数据
if
df
.
type
!=
1
or
len
(
df
.
resultdata
)
==
0
:
return
{
'code'
:
400
,
'message'
:
f
'未找到省份在线率排名数据,请检查是否有相关数据权限'
}
print
(
f
"省份在线率排名数据: {df.resultdata}"
)
data
=
{
data
=
{
'rankings'
:
df
.
resultdata
,
'rankings'
:
df
.
resultdata
,
'total_provinces'
:
len
(
df
.
resultdata
),
'total_provinces'
:
len
(
df
.
resultdata
),
...
@@ -150,9 +162,14 @@ class RankingRateTool(BaseRateTool):
...
@@ -150,9 +162,14 @@ class RankingRateTool(BaseRateTool):
def
_get_manufacturer_ranking
(
self
)
->
Dict
[
str
,
Any
]:
def
_get_manufacturer_ranking
(
self
)
->
Dict
[
str
,
Any
]:
"""获取厂商在线率排名"""
"""获取厂商在线率排名"""
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
2
)
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
2
)
print
(
"厂商数据:"
,
df
.
resultdata
)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
# 准备数据
if
df
.
type
!=
1
or
len
(
df
.
resultdata
)
==
0
:
return
{
'code'
:
400
,
'message'
:
f
'未找到厂商在线率排名数据,请检查是否有相关数据权限'
}
print
(
f
"厂商在线率排名数据: {df.resultdata}"
)
data
=
{
data
=
{
'rankings'
:
df
.
resultdata
,
'rankings'
:
df
.
resultdata
,
'total_manufacturers'
:
len
(
df
.
resultdata
),
'total_manufacturers'
:
len
(
df
.
resultdata
),
...
@@ -163,29 +180,3 @@ class RankingRateTool(BaseRateTool):
...
@@ -163,29 +180,3 @@ class RankingRateTool(BaseRateTool):
}
}
return
data
return
data
\ No newline at end of file
class
NationalTrendArgs
(
BaseModel
):
"""全国趋势查询参数"""
start_time
:
str
=
Field
(
...
,
description
=
"开始时间 (YYYY-MM-DD)"
)
end_time
:
str
=
Field
(
...
,
description
=
"结束时间 (YYYY-MM-DD)"
)
class
NationalTrendTool
(
BaseRateTool
):
"""查询全国在线率趋势的工具"""
name
=
"national_online_trend"
description
=
"查询全国范围内设备在线率的变化趋势"
args_schema
:
Type
[
BaseModel
]
=
NationalTrendArgs
def
_run
(
self
,
start_time
:
str
,
end_time
:
str
)
->
Dict
[
str
,
Any
]:
return
self
.
get_national_trend
(
start_time
,
end_time
)
def
get_national_trend
(
self
,
start_time
:
str
,
end_time
:
str
)
->
Dict
[
str
,
Any
]:
"""
接口三:获取全国在线率趋势
"""
# 准备数据
data
=
{
}
return
self
.
format_response
(
data
,
fig
)
\ No newline at end of file
src/controller/api.py
View file @
df832ca9
...
@@ -18,13 +18,37 @@ app.add_middleware(
...
@@ -18,13 +18,37 @@ app.add_middleware(
allow_headers
=
[
"*"
],
# 允许所有HTTP头
allow_headers
=
[
"*"
],
# 允许所有HTTP头
)
)
global
base_llm
global
base_llm
,
tool_base_url
base_llm
=
None
base_llm
=
None
tool_base_url
=
None
class
AgentManager
:
def
__init__
(
self
):
self
.
llm
=
None
self
.
agent
=
None
def
initialize
(
self
,
api_key
:
str
,
api_base
:
str
,
model_name
:
str
,
tool_base_url
:
str
):
self
.
llm
=
ChatOpenAI
(
openai_api_key
=
api_key
,
openai_api_base
=
api_base
,
model_name
=
model_name
,
verbose
=
True
)
self
.
agent
=
new_rate_agent
(
self
.
llm
,
verbose
=
True
,
tool_base_url
=
tool_base_url
)
def
get_llm
(
self
):
return
self
.
llm
def
get_agent
(
self
):
return
self
.
agent
agent_manager
=
AgentManager
()
@app.post
(
'/api/agent/rate'
)
@app.post
(
'/api/agent/rate'
)
def
rate
(
chat_request
:
GeoAgentRateRequest
,
token
:
str
=
Header
(
None
)):
def
rate
(
chat_request
:
GeoAgentRateRequest
,
token
:
str
=
Header
(
None
)):
agent
=
new_rate_agent
(
base_llm
,
verbose
=
True
)
agent
=
agent_manager
.
get_agent
(
)
try
:
try
:
res
=
agent
.
exec
(
prompt_args
=
{
"input"
:
chat_request
.
query
})
res
=
agent
.
exec
(
prompt_args
=
{
"input"
:
chat_request
.
query
})
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -38,6 +62,8 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
...
@@ -38,6 +62,8 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
'data'
:
res
'data'
:
res
}
}
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
# 参数解析
# 参数解析
parser
=
argparse
.
ArgumentParser
(
description
=
"启动API服务"
)
parser
=
argparse
.
ArgumentParser
(
description
=
"启动API服务"
)
...
@@ -45,13 +71,14 @@ if __name__ == "__main__":
...
@@ -45,13 +71,14 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
'0.0.0.0'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
'0.0.0.0'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--llm"
,
type
=
str
,
default
=
'Qwen2-7B'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--llm"
,
type
=
str
,
default
=
'Qwen2-7B'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--api_base"
,
type
=
str
,
default
=
'http://192.168.10.14:8000/v1'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--api_base"
,
type
=
str
,
default
=
'http://192.168.10.14:8000/v1'
,
help
=
"API服务地址"
)
parser
.
add_argument
(
"--tool_base_url"
,
type
=
str
,
default
=
'http://localhost:5001'
,
help
=
"API服务地址"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
base_llm
=
ChatOpenAI
(
agent_manager
.
initialize
(
openai_api_key
=
'xxxxxxxxxxxxx'
,
api_key
=
'xxxxxxxxxxxxx'
,
openai_api_base
=
args
.
api_base
,
api_base
=
args
.
api_base
,
model_name
=
args
.
llm
,
model_name
=
args
.
llm
,
verbose
=
True
tool_base_url
=
args
.
tool_base_url
)
)
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
)
uvicorn
.
run
(
app
,
host
=
args
.
host
,
port
=
args
.
port
)
src/server/agent_rate.py
View file @
df832ca9
...
@@ -14,7 +14,7 @@ from langchain.agents.format_scratchpad.openai_tools import (
...
@@ -14,7 +14,7 @@ from langchain.agents.format_scratchpad.openai_tools import (
from
langchain
import
hub
from
langchain
import
hub
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
,
NationalTrendTool
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_monitor
import
MonitorPointTool
from
src.agent.tool_monitor
import
MonitorPointTool
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
...
@@ -78,7 +78,7 @@ ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及
...
@@ -78,7 +78,7 @@ ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及
注意事项:
注意事项:
- 时间格式统一使用:YYYY-MM-DD
- 时间格式统一使用:YYYY-MM-DD
- 地区名称需要包含行政级别(如:福建省、厦门市)
- 地区名称需要包含行政级别(如:福建省、厦门市)
- 数据展示优先使用markdown 格式的数据表格,并配合文字说明
- 数据展示优先使用
markdown 格式的数据表格,并配合文字说明
- 百分比数据保留两位小数
- 百分比数据保留两位小数
您可以使用以下工具:
您可以使用以下工具:
...
@@ -119,7 +119,7 @@ Action:
...
@@ -119,7 +119,7 @@ Action:
你的回复格式为 Action:```$JSON_BLOB```然后 Observation。
你的回复格式为 Action:```$JSON_BLOB```然后 Observation。
"""
"""
PROMPT_AGENT_HUMAN
=
"""{input}
\n\n
{agent_scratchpad}
\n
(请注意,无论如何都要以 JSON 对象回复。
你的主要目标是帮助用户快速理解和分析设备在线率数据,提供准确、直观的分析结果
)"""
PROMPT_AGENT_HUMAN
=
"""{input}
\n\n
{agent_scratchpad}
\n
(请注意,无论如何都要以 JSON 对象回复。
工具返回的数据必须使用表格展示,包含在最终输出中
)"""
PROMPT_AGENT_SYS_VARS
=
[
"tool_names"
,
"tools"
]
PROMPT_AGENT_SYS_VARS
=
[
"tool_names"
,
"tools"
]
class
RateAgentV2
:
class
RateAgentV2
:
...
@@ -143,11 +143,15 @@ class RateAgentV2:
...
@@ -143,11 +143,15 @@ class RateAgentV2:
def
new_rate_agent
(
llm
,
verbose
:
bool
=
False
,
**
args
):
def
new_rate_agent
(
llm
,
verbose
:
bool
=
False
,
**
args
):
if
args
[
'tool_base_url'
]:
tool_base_url
=
args
[
'tool_base_url'
]
else
:
tool_base_url
=
const_base_url
tools
=
[
tools
=
[
RegionRateTool
(),
RegionRateTool
(
base_url
=
tool_base_url
),
RankingRateTool
(),
RankingRateTool
(
base_url
=
tool_base_url
),
NationalTrendTool
(),
MonitorPointTool
(
base_url
=
tool_base_url
)
MonitorPointTool
()
]
]
# 使用 LangChain 的工具调用代理
# 使用 LangChain 的工具调用代理
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment