Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
5f4098e7
Commit
5f4098e7
authored
8 months ago
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
监测点工具更新
parent
ca5f0328
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
139 additions
and
64 deletions
+139
-64
__init__.py
src/__init__.py
+0
-4
http_tools.py
src/agent/http_tools.py
+50
-14
tool_monitor.py
src/agent/tool_monitor.py
+52
-32
api.py
src/controller/api.py
+18
-3
logger.py
src/utils/logger.py
+2
-2
run_tool_picker_monitor.py
test/run_tool_picker_monitor.py
+17
-9
No files found.
src/__init__.py
View file @
5f4098e7
from
.utils.logger
import
setup_logging
# 设置日志配置
setup_logging
()
This diff is collapsed.
Click to expand it.
src/agent/http_tools.py
View file @
5f4098e7
...
@@ -11,9 +11,12 @@ const_base_url = "http://localhost:5001"
...
@@ -11,9 +11,12 @@ const_base_url = "http://localhost:5001"
const_url_point
=
"/cigem/getMonitorPointAll"
const_url_point
=
"/cigem/getMonitorPointAll"
const_url_rate
=
"/cigem/getAvgOnlineRate"
const_url_rate
=
"/cigem/getAvgOnlineRate"
const_url_rate_ranking
=
"/cigem/getOnlineRateRank"
const_url_rate_ranking
=
"/cigem/getOnlineRateRank"
const_url_rate_month
=
"/cigem/get
Avg
OnlineRateOfMonth"
const_url_rate_month
=
"/cigem/getOnlineRateOfMonth"
const_url_device_list
=
"/cigem/getMonitorDeviceList"
const_url_device_list
=
"/cigem/getMonitorDeviceList"
const_url_warning
=
"/cigem/getWarningStatistics"
const_url_device_and_sensor
=
"/cigem/getDeviceAndSensorCount"
const_url_warning
=
"/cigem/getWarnMsgDisposeRate"
const_url_warning_month
=
"/cigem/getWarnMsgDisposeRateOfMonth"
class
BaseResponse
(
BaseModel
,
Generic
[
T
]):
class
BaseResponse
(
BaseModel
,
Generic
[
T
]):
"""通用响应模型"""
"""通用响应模型"""
...
@@ -93,12 +96,27 @@ class MonitorClient(BaseHttpClient):
...
@@ -93,12 +96,27 @@ class MonitorClient(BaseHttpClient):
)
)
return
BaseResponse
[
List
](
**
data
)
return
BaseResponse
[
List
](
**
data
)
def
query_points_sync
(
self
,
key
:
str
)
->
BaseResponse
[
List
]:
def
query_points_sync
(
self
,
key
:
str
,
year
:
str
,
monitor_type
:
str
,
three_d_model
:
str
,
ortho_image
:
str
,
disaster_threat_people_range_start
:
str
,
disaster_threat_people_range_end
:
str
,
disaster_scale_start
:
str
,
disaster_scale_end
:
str
,
device_type
:
str
)
->
BaseResponse
[
List
]:
"""同步查询监测点信息"""
"""同步查询监测点信息"""
params
=
{
"key"
:
key
,
"year"
:
year
,
"MONITORTYPE"
:
monitor_type
,
"MODELEXIST"
:
three_d_model
,
"DOMEXIST"
:
ortho_image
,
"STARTTHREATSPOPULATION"
:
disaster_threat_people_range_start
,
"ENDTHREATSPOPULATION"
:
disaster_threat_people_range_end
,
"STARTDISASTERSCALE"
:
disaster_scale_start
,
}
print
(
f
"查询参数: {params}"
)
data
=
self
.
_request_sync
(
data
=
self
.
_request_sync
(
"POST"
,
"POST"
,
const_url_point
,
const_url_point
,
json
=
{
"key"
:
key
}
json
=
params
)
)
return
BaseResponse
[
List
](
**
data
)
return
BaseResponse
[
List
](
**
data
)
...
@@ -121,10 +139,24 @@ class MonitorClient(BaseHttpClient):
...
@@ -121,10 +139,24 @@ class MonitorClient(BaseHttpClient):
)
)
return
BaseResponse
[
List
](
**
data
)
return
BaseResponse
[
List
](
**
data
)
def
query_device_and_sensor
(
self
,
area_code
:
str
,
start_time
:
str
,
end_time
:
str
,
device_type
:
str
)
->
BaseResponse
[
List
]:
"""同步查询设备和传感器数量"""
data
=
self
.
_request_sync
(
"POST"
,
const_url_device_and_sensor
,
json
=
{
"startTime"
:
start_time
,
"endTime"
:
end_time
,
"areaCode"
:
area_code
,
"deviceType"
:
device_type
}
)
return
BaseResponse
[
List
](
**
data
)
# 示例:添加新的数据接口客户端
# 示例:添加新的数据接口客户端
class
RateClient
(
BaseHttpClient
):
class
RateClient
(
BaseHttpClient
):
"""在线率查询客户端"""
"""在线率查询客户端"""
async
def
query_rates
(
self
,
areacode
:
str
,
startDate
:
str
,
endDate
:
str
)
->
BaseResponse
[
List
]:
async
def
query_rates
(
self
,
areacode
:
str
,
startDate
:
str
,
endDate
:
str
,
manufacturer_name
:
str
,
typeArr
:
str
)
->
BaseResponse
[
List
]:
"""异步查询在线率信息"""
"""异步查询在线率信息"""
data
=
await
self
.
_request_async
(
data
=
await
self
.
_request_async
(
"POST"
,
"POST"
,
...
@@ -132,12 +164,14 @@ class RateClient(BaseHttpClient):
...
@@ -132,12 +164,14 @@ class RateClient(BaseHttpClient):
json
=
{
json
=
{
'areaCode'
:
areacode
,
'areaCode'
:
areacode
,
'startDate'
:
startDate
,
'startDate'
:
startDate
,
'endDate'
:
endDate
'endDate'
:
endDate
,
'manufacturerName'
:
manufacturer_name
,
'typeArr'
:
typeArr
}
}
)
)
return
BaseResponse
[
List
](
**
data
)
return
BaseResponse
[
List
](
**
data
)
def
query_rates_sync
(
self
,
areacode
:
str
,
startDate
:
str
,
endDate
:
str
)
->
BaseResponse
[
List
]:
def
query_rates_sync
(
self
,
areacode
:
str
,
startDate
:
str
,
endDate
:
str
,
manufacturer_name
:
str
,
typeArr
:
str
)
->
BaseResponse
[
List
]:
"""同步查询在线率信息"""
"""同步查询在线率信息"""
data
=
self
.
_request_sync
(
data
=
self
.
_request_sync
(
"POST"
,
"POST"
,
...
@@ -145,7 +179,9 @@ class RateClient(BaseHttpClient):
...
@@ -145,7 +179,9 @@ class RateClient(BaseHttpClient):
json
=
{
json
=
{
'areaCode'
:
areacode
,
'areaCode'
:
areacode
,
'startDate'
:
startDate
,
'startDate'
:
startDate
,
'endDate'
:
endDate
'endDate'
:
endDate
,
'manufacturerName'
:
manufacturer_name
,
'typeArr'
:
typeArr
}
}
)
)
return
BaseResponse
[
List
](
**
data
)
return
BaseResponse
[
List
](
**
data
)
...
@@ -168,7 +204,7 @@ class RateClient(BaseHttpClient):
...
@@ -168,7 +204,7 @@ class RateClient(BaseHttpClient):
)
)
return
BaseResponse
[
List
](
**
data
)
return
BaseResponse
[
List
](
**
data
)
def
query_rates_month_sync
(
self
,
year
:
str
,
areaCode
:
str
,
typeArr
:
List
[
int
]
)
->
BaseResponse
[
List
]:
def
query_rates_month_sync
(
self
,
year
:
str
,
areaCode
:
str
,
typeArr
:
str
)
->
BaseResponse
[
List
]:
"""同步查询按月度统计的在线率信息"""
"""同步查询按月度统计的在线率信息"""
data
=
self
.
_request_sync
(
data
=
self
.
_request_sync
(
"POST"
,
"POST"
,
...
@@ -184,15 +220,15 @@ class RateClient(BaseHttpClient):
...
@@ -184,15 +220,15 @@ class RateClient(BaseHttpClient):
class
WarningClient
(
BaseHttpClient
):
class
WarningClient
(
BaseHttpClient
):
"""预警查询客户端"""
"""预警查询客户端"""
def
query_warning_statistics
(
self
,
start
Date
:
str
,
endDate
:
str
,
areaC
ode
:
str
)
->
BaseResponse
[
List
]:
def
query_warning_statistics
(
self
,
start
_time
:
str
,
end_time
:
str
,
area_c
ode
:
str
)
->
BaseResponse
[
List
]:
"""同步查询预警统计信息"""
"""同步查询预警统计信息"""
data
=
self
.
_request_sync
(
data
=
self
.
_request_sync
(
"POST"
,
"POST"
,
const_url_warning
,
const_url_warning
,
json
=
{
json
=
{
'start
Date'
:
startDat
e
,
'start
Time'
:
start_tim
e
,
'end
Date'
:
endDat
e
,
'end
Time'
:
end_tim
e
,
'areaCode'
:
area
C
ode
'areaCode'
:
area
_c
ode
}
}
)
)
return
BaseResponse
[
List
](
**
data
)
return
BaseResponse
[
List
](
**
data
)
...
@@ -201,7 +237,7 @@ class WarningClient(BaseHttpClient):
...
@@ -201,7 +237,7 @@ class WarningClient(BaseHttpClient):
"""同步查询按月度统计的预警统计信息"""
"""同步查询按月度统计的预警统计信息"""
data
=
self
.
_request_sync
(
data
=
self
.
_request_sync
(
"POST"
,
"POST"
,
const_url_warning
,
const_url_warning
_month
,
json
=
{
json
=
{
'year'
:
year
,
'year'
:
year
,
'areaCode'
:
areaCode
'areaCode'
:
areaCode
...
...
This diff is collapsed.
Click to expand it.
src/agent/tool_monitor.py
View file @
5f4098e7
...
@@ -15,14 +15,15 @@ class MonitorPointResponse():
...
@@ -15,14 +15,15 @@ class MonitorPointResponse():
class
MonitorPointArgs
(
BaseModel
):
class
MonitorPointArgs
(
BaseModel
):
"""监测点查询参数"""
"""监测点查询参数"""
key
:
str
=
Field
(
...
,
description
=
"行政区划名称(省/市级别均可,只需要最后一级,如长沙市,不需要湖南省)"
)
key
:
str
=
Field
(
...
,
description
=
"行政区划名称(省/市级别均可,只需要最后一级,如长沙市,不需要湖南省)"
)
year
:
int
=
Field
(
None
,
description
=
"年份,未提及则为今年"
)
year
:
str
=
Field
(
""
,
description
=
"年份,未提及则为空"
)
disaster_type
:
str
=
Field
(
None
,
description
=
"灾害类型,如崩塌、滑坡、泥石流、地面塌陷、地面沉降、地裂缝等,未提及则为空"
)
disaster_type
:
str
=
Field
(
""
,
description
=
"灾害类型,如崩塌、滑坡、泥石流、地面塌陷、地面沉降、地裂缝等,未提及则为空"
)
three_d_model
:
bool
=
Field
(
False
,
description
=
"是否需要三维模型,默认不需要"
)
three_d_model
:
str
=
Field
(
""
,
description
=
"是否需要三维模型,需要为有,不需要为无,默认为空"
)
ortho_image
:
bool
=
Field
(
False
,
description
=
"是否需要正射影像,默认不需要"
)
ortho_image
:
str
=
Field
(
""
,
description
=
"是否需要正射影像,需要为有,不需要为无,默认为空"
)
disaster_threat_people_range_start
:
int
=
Field
(
None
,
description
=
"灾害威胁人数范围起始值,如100,未提及则为空"
)
disaster_threat_people_range_start
:
str
=
Field
(
""
,
description
=
"灾害威胁人数范围起始值,如100,未提及则为空"
)
disaster_threat_people_range_end
:
int
=
Field
(
None
,
description
=
"灾害威胁人数范围结束值,如200,未提及则为空"
)
disaster_threat_people_range_end
:
str
=
Field
(
""
,
description
=
"灾害威胁人数范围结束值,如200,未提及则为空"
)
disaster_scale
:
str
=
Field
(
None
,
description
=
"灾害规模,灾害为崩塌、滑坡、泥石流时表示体积,灾害为地面塌陷、地面沉降时表示面积,为地裂缝时表示长度,未提及则为空"
)
disaster_scale_start
:
str
=
Field
(
""
,
description
=
"灾害规模范围起始值,灾害为崩塌、滑坡、泥石流时表示体积,灾害为地面塌陷、地面沉降时表示面积,为地裂缝时表示长度,未提及则为空"
)
device_type
:
str
=
Field
(
None
,
description
=
"设备类型(例如 加速度、位移、温度、湿度、裂缝计等),默认为空"
)
disaster_scale_end
:
str
=
Field
(
""
,
description
=
"灾害规模范围结束值,灾害为崩塌、滑坡、泥石流时表示体积,灾害为地面塌陷、地面沉降时表示面积,为地裂缝时表示长度,未提及则为空"
)
device_type
:
str
=
Field
(
""
,
description
=
"设备类型(例如 加速度、位移、温度、湿度、裂缝计等),默认为空"
)
class
MonitorPointTool
(
BaseTool
):
class
MonitorPointTool
(
BaseTool
):
"""查询监测点信息的工具"""
"""查询监测点信息的工具"""
...
@@ -49,7 +50,10 @@ class MonitorPointTool(BaseTool):
...
@@ -49,7 +50,10 @@ class MonitorPointTool(BaseTool):
self
.
client
=
MonitorClient
(
base_url
=
base_url
)
self
.
client
=
MonitorClient
(
base_url
=
base_url
)
self
.
logger
=
get_logger
(
"MonitorPointTool"
)
self
.
logger
=
get_logger
(
"MonitorPointTool"
)
def
_run
(
self
,
key
:
str
,
device_required
:
bool
=
False
,
device_type
:
str
=
None
)
->
Dict
[
str
,
Any
]:
def
_run
(
self
,
key
:
str
,
year
:
str
=
""
,
disaster_type
:
str
=
""
,
three_d_model
:
str
=
""
,
ortho_image
:
str
=
""
,
disaster_threat_people_range_start
:
str
=
""
,
disaster_threat_people_range_end
:
str
=
""
,
disaster_scale_start
:
str
=
""
,
disaster_scale_end
:
str
=
""
,
device_type
:
str
=
""
)
->
Dict
[
str
,
Any
]:
"""
"""
执行监测点查询
执行监测点查询
...
@@ -61,7 +65,8 @@ class MonitorPointTool(BaseTool):
...
@@ -61,7 +65,8 @@ class MonitorPointTool(BaseTool):
ortho_image: 是否需要正射影像
ortho_image: 是否需要正射影像
disaster_threat_people_range_start: 灾害威胁人数范围起始值
disaster_threat_people_range_start: 灾害威胁人数范围起始值
disaster_threat_people_range_end: 灾害威胁人数范围结束值
disaster_threat_people_range_end: 灾害威胁人数范围结束值
disaster_scale: 灾害规模
disaster_scale_start: 灾害规模范围起始值
disaster_scale_end: 灾害规模范围结束值
device_required: 是否需要设备相关信息
device_required: 是否需要设备相关信息
device_type: 设备类型
device_type: 设备类型
Returns:
Returns:
...
@@ -69,18 +74,21 @@ class MonitorPointTool(BaseTool):
...
@@ -69,18 +74,21 @@ class MonitorPointTool(BaseTool):
"""
"""
try
:
try
:
self
.
logger
.
info
(
f
"开始查询监测点信息,区域: {key}"
)
self
.
logger
.
info
(
f
"开始查询监测点信息,区域: {key}"
)
code
=
""
#
code = ""
if
key
!=
""
:
#
if key != "":
self
.
logger
.
debug
(
f
"查找区域代码: {key}"
)
#
self.logger.debug(f"查找区域代码: {key}")
codes
=
code_tool
.
find_code
(
key
)
#
codes = code_tool.find_code(key)
if
codes
is
None
or
len
(
codes
)
==
0
:
#
if codes is None or len(codes) == 0:
error_msg
=
f
'未找到匹配的区域代码: {key}'
#
error_msg = f'未找到匹配的区域代码: {key}'
self
.
logger
.
warning
(
error_msg
)
#
self.logger.warning(error_msg)
return
{
'code'
:
400
,
'message'
:
error_msg
}
#
return {'code': 400, 'message': error_msg}
code
=
codes
[
0
][
1
]
#
code = codes[0][1]
self
.
logger
.
debug
(
f
"找到区域代码: {code}"
)
#
self.logger.debug(f"找到区域代码: {code}")
response
=
self
.
client
.
query_points_sync
(
key
)
response
=
self
.
client
.
query_points_sync
(
key
,
year
,
disaster_type
,
three_d_model
,
ortho_image
,
disaster_threat_people_range_start
,
disaster_threat_people_range_end
,
disaster_scale_start
,
disaster_scale_end
,
device_type
)
self
.
logger
.
debug
(
f
"API响应: {response}"
)
self
.
logger
.
debug
(
f
"API响应: {response}"
)
if
response
.
type
!=
1
or
len
(
response
.
resultdata
)
==
0
:
if
response
.
type
!=
1
or
len
(
response
.
resultdata
)
==
0
:
...
@@ -91,30 +99,42 @@ class MonitorPointTool(BaseTool):
...
@@ -91,30 +99,42 @@ class MonitorPointTool(BaseTool):
'message'
:
error_msg
'message'
:
error_msg
}
}
# 提取关键信息并格式化
# 提取关键信息并格式化
points_info
=
[]
points_info
=
[]
for
point
in
response
.
resultdata
:
for
point
in
response
.
resultdata
:
point_data
=
{
point_data
=
{
"名称"
:
f
"{point['MONITORPOINTNAME']}"
if
point
[
"MONITORPOINTNAME"
]
else
""
,
"监测点编号"
:
f
"{point['MONITORPOINTCODE']}"
if
point
[
"MONITORPOINTCODE"
]
else
""
,
"位置"
:
f
"{point['LOCATION']}"
if
point
[
"LOCATION"
]
else
""
,
"监测点名称"
:
f
"{point['MONITORPOINTNAME']}"
if
point
[
"MONITORPOINTNAME"
]
else
""
,
"地理位置"
:
f
"{point['LOCATION']}"
if
point
[
"LOCATION"
]
else
""
,
"经度"
:
f
"{point['LONGITUDE']}"
if
point
[
"LONGITUDE"
]
else
""
,
"经度"
:
f
"{point['LONGITUDE']}"
if
point
[
"LONGITUDE"
]
else
""
,
"纬度"
:
f
"{point['LATITUDE']}"
if
point
[
"LATITUDE"
]
else
""
,
"纬度"
:
f
"{point['LATITUDE']}"
if
point
[
"LATITUDE"
]
else
""
,
"海拔"
:
f
"{point['ELEVATION']}"
if
point
[
"ELEVATION"
]
else
""
,
"高程"
:
f
"{point['ELEVATION']}"
if
point
[
"ELEVATION"
]
else
""
,
"建设单位"
:
f
"{point['BUILDUNIT']}"
if
point
[
"BUILDUNIT"
]
else
""
,
"监测责任部门"
:
f
"{point['MONITORUNIT']}"
if
point
[
"MONITORUNIT"
]
else
""
,
"监测单位"
:
f
"{point['MONITORUNIT']}"
if
point
[
"MONITORUNIT"
]
else
""
,
"监测建设部门"
:
f
"{point['BUILDUNIT']}"
if
point
[
"BUILDUNIT"
]
else
""
,
"监测类型"
:
f
"{point['MONITORTYPE']}"
if
point
[
"MONITORTYPE"
]
else
""
"监测运维部门"
:
f
"{point['YWUNIT']}"
if
point
[
"YWUNIT"
]
else
""
,
"设备厂商"
:
f
"{point['MANUFACTURER']}"
if
point
[
"MANUFACTURER"
]
else
""
,
"灾害类型"
:
f
"{point['MONITORTYPE']}"
if
point
[
"MONITORTYPE"
]
else
""
,
"有无三维模型"
:
f
"{point['MODELEXIST']}"
if
point
[
"MODELEXIST"
]
else
""
,
"有无正射影像"
:
f
"{point['DOMEXIST']}"
if
point
[
"DOMEXIST"
]
else
""
,
"威胁人数"
:
f
"{point['THREATSPOPULATION']}"
if
point
[
"THREATSPOPULATION"
]
else
""
,
"规模等级"
:
f
"{point['DISASTERSCALE']}"
if
point
[
"DISASTERSCALE"
]
else
""
}
}
if
point
.
get
(
"SGDW"
)
or
point
.
get
(
"SGDW"
)
!=
"null"
:
point_data
[
"施工单位"
]
=
point
[
"SGDW"
]
if
point
.
get
(
"WARNLEVEL"
)
or
point
.
get
(
"WARNLEVEL"
)
!=
"null"
:
point_data
[
"预警等级"
]
=
point
[
"WARNLEVEL"
]
points_info
.
append
(
point_data
)
points_info
.
append
(
point_data
)
self
.
logger
.
debug
(
f
"处理监测点数据: {point_data['名称']} {point_data}"
)
self
.
logger
.
debug
(
f
"处理监测点数据: {point_data['
监测点
名称']} {point_data}"
)
self
.
logger
.
info
(
f
"成功获取 {len(points_info)} 个监测点数据"
)
self
.
logger
.
info
(
f
"成功获取 {len(points_info)} 个监测点数据"
)
markdown
=
self
.
to_markdown
(
points_info
)
#
markdown = self.to_markdown(points_info)
result
=
{
result
=
{
'code'
:
200
,
'code'
:
200
,
'message'
:
f
"在{key}找到{len(points_info)}个监测点"
,
'message'
:
f
"在{key}找到{len(points_info)}个监测点
信息
"
,
'points'
:
points_info
,
'points'
:
points_info
,
'markdown'
:
markdown
#
'markdown': markdown
}
}
self
.
logger
.
info
(
"数据处理完成,返回结果"
)
self
.
logger
.
info
(
"数据处理完成,返回结果"
)
return
result
return
result
...
...
This diff is collapsed.
Click to expand it.
src/controller/api.py
View file @
5f4098e7
...
@@ -6,7 +6,7 @@ import argparse
...
@@ -6,7 +6,7 @@ import argparse
from
typing
import
Optional
from
typing
import
Optional
from
fastapi
import
FastAPI
,
Header
from
fastapi
import
FastAPI
,
Header
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
import
logging
import
uvicorn
import
uvicorn
...
@@ -14,7 +14,7 @@ from src.server.agent_rate import new_rate_agent, RateAgentV3
...
@@ -14,7 +14,7 @@ from src.server.agent_rate import new_rate_agent, RateAgentV3
from
src.server.classify
import
new_router_llm
from
src.server.classify
import
new_router_llm
from
src.server.rewrite
import
new_re_rewriter_llm
from
src.server.rewrite
import
new_re_rewriter_llm
from
src.controller.request
import
GeoAgentRateRequest
from
src.controller.request
import
GeoAgentRateRequest
from
src.utils.logger
import
setup_logging
from
langchain_openai
import
ChatOpenAI
from
langchain_openai
import
ChatOpenAI
# 默认配置
# 默认配置
...
@@ -159,6 +159,7 @@ def main():
...
@@ -159,6 +159,7 @@ def main():
parser
.
add_argument
(
"--api_base"
,
type
=
str
,
help
=
"OpenAI API基础地址"
)
parser
.
add_argument
(
"--api_base"
,
type
=
str
,
help
=
"OpenAI API基础地址"
)
parser
.
add_argument
(
"--tool_base_url"
,
type
=
str
,
help
=
"工具服务基础地址"
)
parser
.
add_argument
(
"--tool_base_url"
,
type
=
str
,
help
=
"工具服务基础地址"
)
parser
.
add_argument
(
"--api_key"
,
type
=
str
,
help
=
"OpenAI API密钥"
)
parser
.
add_argument
(
"--api_key"
,
type
=
str
,
help
=
"OpenAI API密钥"
)
parser
.
add_argument
(
"--log_level"
,
type
=
str
,
help
=
"日志级别,DEBUG、INFO、WARNING、ERROR、CRITICAL"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -169,7 +170,21 @@ def main():
...
@@ -169,7 +170,21 @@ def main():
api_base
=
get_config
(
"API_BASE"
,
args
)
api_base
=
get_config
(
"API_BASE"
,
args
)
tool_base_url
=
get_config
(
"TOOL_BASE_URL"
,
args
)
tool_base_url
=
get_config
(
"TOOL_BASE_URL"
,
args
)
api_key
=
get_config
(
"API_KEY"
,
args
)
api_key
=
get_config
(
"API_KEY"
,
args
)
log_level
=
get_config
(
"LOG_LEVEL"
,
args
)
if
log_level
is
not
None
:
if
log_level
==
"DEBUG"
:
log_level
=
logging
.
DEBUG
elif
log_level
==
"INFO"
:
log_level
=
logging
.
INFO
elif
log_level
==
"WARNING"
:
log_level
=
logging
.
WARNING
elif
log_level
==
"ERROR"
:
log_level
=
logging
.
ERROR
elif
log_level
==
"CRITICAL"
:
log_level
=
logging
.
CRITICAL
else
:
log_level
=
logging
.
INFO
setup_logging
(
log_level
=
log_level
)
# 初始化 agent
# 初始化 agent
agent_manager
.
initialize
(
agent_manager
.
initialize
(
api_key
=
api_key
,
api_key
=
api_key
,
...
...
This diff is collapsed.
Click to expand it.
src/utils/logger.py
View file @
5f4098e7
...
@@ -30,8 +30,8 @@ def setup_logging(log_level=logging.INFO, log_dir="logs"):
...
@@ -30,8 +30,8 @@ def setup_logging(log_level=logging.INFO, log_dir="logs"):
)
)
# 设置第三方库的日志级别
# 设置第三方库的日志级别
logging
.
getLogger
(
"httpx"
)
.
setLevel
(
logging
.
WARNING
)
#
logging.getLogger("httpx").setLevel(logging.WARNING)
logging
.
getLogger
(
"urllib3"
)
.
setLevel
(
logging
.
WARNING
)
#
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging
.
info
(
f
"日志配置完成,日志文件: {log_file}"
)
logging
.
info
(
f
"日志配置完成,日志文件: {log_file}"
)
...
...
This diff is collapsed.
Click to expand it.
test/run_tool_picker_monitor.py
View file @
5f4098e7
...
@@ -6,7 +6,7 @@ import sys,os
...
@@ -6,7 +6,7 @@ import sys,os
sys
.
path
.
append
(
"../"
)
sys
.
path
.
append
(
"../"
)
from
src.server.tool_picker
import
ToolPicker
from
src.server.tool_picker
import
ToolPicker
,
ToolRunner
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_rate
import
RegionRateTool
,
RankingRateTool
from
src.agent.tool_monitor
import
MonitorPointTool
from
src.agent.tool_monitor
import
MonitorPointTool
...
@@ -21,17 +21,20 @@ def run_examples():
...
@@ -21,17 +21,20 @@ def run_examples():
model_name
=
"Qwen2-7B"
,
model_name
=
"Qwen2-7B"
,
verbose
=
True
verbose
=
True
)
)
base_url
=
"http://172.30.0.37:30007"
# 初始化工具
# 初始化工具
tools
=
[
tools
=
[
RegionRateTool
(),
RegionRateTool
(
base_url
=
base_url
),
RankingRateTool
(),
RankingRateTool
(
base_url
=
base_url
),
MonitorPointTool
(),
MonitorPointTool
(
base_url
=
base_url
),
]
]
tools_dict
=
{
tool
.
name
:
tool
for
tool
in
tools
}
# 初始化 ToolPicker
# 初始化 ToolPicker
picker
=
ToolPicker
(
llm
,
tools
)
picker
=
ToolPicker
(
llm
,
tools
)
# 测试案例和预期结果
# 测试案例和预期结果
test_cases
=
[
test_cases
=
[
{
{
...
@@ -40,8 +43,8 @@ def run_examples():
...
@@ -40,8 +43,8 @@ def run_examples():
"tool"
:
"monitor_points_query"
,
"tool"
:
"monitor_points_query"
,
"params"
:
{
"params"
:
{
"key"
:
"甘肃省"
,
"key"
:
"甘肃省"
,
"three_d_model"
:
False
,
"three_d_model"
:
"无"
,
"ortho_image"
:
False
,
"ortho_image"
:
"无"
,
}
}
}
}
},{
},{
...
@@ -59,7 +62,7 @@ def run_examples():
...
@@ -59,7 +62,7 @@ def run_examples():
"tool"
:
"monitor_points_query"
,
"tool"
:
"monitor_points_query"
,
"params"
:
{
"params"
:
{
"key"
:
"甘肃省"
,
"key"
:
"甘肃省"
,
"three_d_model"
:
True
,
"three_d_model"
:
"有"
,
}
}
}
}
},{
},{
...
@@ -108,7 +111,12 @@ def run_examples():
...
@@ -108,7 +111,12 @@ def run_examples():
actual_value
,
actual_value
,
"✓"
if
expected_value
==
actual_value
else
"✗"
"✓"
if
expected_value
==
actual_value
else
"✗"
)
)
# run tool
tool
=
tools_dict
[
actual_tool
]
result
=
tool
.
invoke
(
result
[
"params"
])
print
(
result
)
except
Exception
as
e
:
except
Exception
as
e
:
table
.
add_row
(
"错误"
,
""
,
str
(
e
),
"✗"
)
table
.
add_row
(
"错误"
,
""
,
str
(
e
),
"✗"
)
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment