Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
L
LAE
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
文靖昊
LAE
Commits
da717251
Commit
da717251
authored
Nov 15, 2024
by
tinywell
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
工具数据单独处理为 markdown;增加日志便于排查
parent
45d5cb3a
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
278 additions
and
107 deletions
+278
-107
__init__.py
src/__init__.py
+4
-0
tool_monitor.py
src/agent/tool_monitor.py
+50
-19
tool_rate.py
src/agent/tool_rate.py
+133
-72
agent_rate.py
src/server/agent_rate.py
+2
-3
tool_picker.py
src/server/tool_picker.py
+40
-13
logger.py
src/utils/logger.py
+49
-0
No files found.
src/__init__.py
View file @
da717251
from
.utils.logger
import
setup_logging
# 设置日志配置
setup_logging
()
src/agent/tool_monitor.py
View file @
da717251
from
typing
import
Dict
,
Any
,
Optional
,
List
from
typing
import
Dict
,
Any
,
Optional
,
List
from
pydantic
import
BaseModel
,
Field
from
langchain_core.tools
import
BaseTool
from
.http_tools
import
MonitorClient
,
RateClient
from
typing
import
Type
from
..utils.logger
import
get_logger
class
MonitorPointResponse
():
"""监测点查询结果"""
status
:
str
=
Field
(
...
,
description
=
"状态"
)
...
...
@@ -23,6 +25,7 @@ class MonitorPointTool(BaseTool):
"""
args_schema
:
Type
[
BaseModel
]
=
MonitorPointArgs
client
:
Any
=
Field
(
None
,
exclude
=
True
)
logger
:
logging
.
Logger
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
base_url
:
str
=
"http://localhost:5001"
,
**
data
):
"""
...
...
@@ -34,6 +37,7 @@ class MonitorPointTool(BaseTool):
"""
super
()
.
__init__
(
**
data
)
self
.
client
=
MonitorClient
(
base_url
=
base_url
)
self
.
logger
=
get_logger
(
"MonitorPointTool"
)
def
_run
(
self
,
key
:
str
)
->
Dict
[
str
,
Any
]:
"""
...
...
@@ -46,39 +50,52 @@ class MonitorPointTool(BaseTool):
Dict: 包含查询结果的字典
"""
try
:
print
(
f
"进入 monitor_points_query 工具, 查询监测点信息
: {key}"
)
self
.
logger
.
info
(
f
"开始查询监测点信息,区域
: {key}"
)
response
=
self
.
client
.
query_points_sync
(
key
)
print
(
f
"查询结果: {response}"
)
self
.
logger
.
debug
(
f
"API响应: {response}"
)
if
response
.
type
!=
1
or
len
(
response
.
resultdata
)
==
0
:
error_msg
=
f
"查询失败: {response.message},请检查是否有相关数据权限"
self
.
logger
.
warning
(
error_msg
)
return
{
'code'
:
400
,
'message'
:
f
"查询失败: {response.message},请检查是否有相关数据权限"
'message'
:
error_msg
}
# 提取关键信息并格式化
points_info
=
[]
for
point
in
response
.
resultdata
:
points_info
.
append
({
"名称"
:
point
[
"MONITORPOINTNAME"
]
if
point
[
"MONITORPOINTNAME"
]
else
""
,
"位置"
:
point
[
"LOCATION"
]
if
point
[
"LOCATION"
]
else
""
,
"经度"
:
point
[
"LONGITUDE"
]
if
point
[
"LONGITUDE"
]
else
""
,
"纬度"
:
point
[
"LATITUDE"
]
if
point
[
"LATITUDE"
]
else
""
,
"海拔"
:
point
[
"ELEVATION"
]
if
point
[
"ELEVATION"
]
else
""
,
"建设单位"
:
point
[
"BUILDUNIT"
]
if
point
[
"BUILDUNIT"
]
else
""
,
"监测单位"
:
point
[
"MONITORUNIT"
]
if
point
[
"MONITORUNIT"
]
else
""
,
"监测类型"
:
point
[
"MONITORTYPE"
]
if
point
[
"MONITORTYPE"
]
else
""
})
point_data
=
{
"名称"
:
f
"{point['MONITORPOINTNAME']}"
if
point
[
"MONITORPOINTNAME"
]
else
""
,
"位置"
:
f
"{point['LOCATION']}"
if
point
[
"LOCATION"
]
else
""
,
"经度"
:
f
"{point['LONGITUDE']}"
if
point
[
"LONGITUDE"
]
else
""
,
"纬度"
:
f
"{point['LATITUDE']}"
if
point
[
"LATITUDE"
]
else
""
,
"海拔"
:
f
"{point['ELEVATION']}"
if
point
[
"ELEVATION"
]
else
""
,
"建设单位"
:
f
"{point['BUILDUNIT']}"
if
point
[
"BUILDUNIT"
]
else
""
,
"监测单位"
:
f
"{point['MONITORUNIT']}"
if
point
[
"MONITORUNIT"
]
else
""
,
"监测类型"
:
f
"{point['MONITORTYPE']}"
if
point
[
"MONITORTYPE"
]
else
""
}
points_info
.
append
(
point_data
)
self
.
logger
.
debug
(
f
"处理监测点数据: {point_data['名称']} {point_data}"
)
self
.
logger
.
info
(
f
"成功获取 {len(points_info)} 个监测点数据"
)
markdown
=
self
.
to_markdown
(
points_info
)
re
turn
{
re
sult
=
{
'code'
:
200
,
'message'
:
f
"在{key}找到{len(points_info)}个监测点"
,
'points'
:
points_info
'points'
:
points_info
,
'markdown'
:
markdown
}
self
.
logger
.
info
(
"数据处理完成,返回结果"
)
return
result
except
Exception
as
e
:
error_msg
=
f
"查询失败: {str(e)}"
self
.
logger
.
error
(
error_msg
,
exc_info
=
True
)
return
{
'code'
:
400
,
'message'
:
f
"查询失败: {str(e)}"
'message'
:
error_msg
}
def
_arun
(
self
,
key
:
str
)
->
Dict
[
str
,
Any
]:
...
...
@@ -91,4 +108,18 @@ class MonitorPointTool(BaseTool):
Returns:
Dict: 包含查询结果的字典
"""
raise
NotImplementedError
(
"异步查询暂未实现"
)
\ No newline at end of file
self
.
logger
.
warning
(
"异步查询方法未实现"
)
raise
NotImplementedError
(
"异步查询暂未实现"
)
def
to_markdown
(
self
,
data
:
List
[
Dict
[
str
,
Any
]])
->
str
:
"""将数据转换为 markdown 表格"""
self
.
logger
.
debug
(
"开始生成 markdown 表格"
)
markdown
=
"""
| 序号 | 名称 | 位置 | 经度 | 纬度 | 海拔 | 建设单位 | 监测单位 | 监测类型 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
"""
for
index
,
row
in
enumerate
(
data
):
markdown
+=
f
"| {index+1} | {row['名称']} | {row['位置']} | {row['经度']} | {row['纬度']} | {row['海拔']} | {row['建设单位']} | {row['监测单位']} | {row['监测类型']} |
\n
"
self
.
logger
.
debug
(
"markdown 表格生成完成"
)
return
markdown
\ No newline at end of file
src/agent/tool_rate.py
View file @
da717251
...
...
@@ -6,27 +6,31 @@ from langchain_core.tools import BaseTool
from
.http_tools
import
RateClient
,
const_base_url
from
.code
import
AreaCodeTool
from
..utils.logger
import
get_logger
code_tool
=
AreaCodeTool
()
class
BaseRateTool
(
BaseTool
):
"""设备在线率分析基础工具类"""
logger
:
logging
.
Logger
=
Field
(
None
,
exclude
=
True
)
def
__init__
(
self
,
**
data
):
super
()
.
__init__
(
**
data
)
self
.
logger
=
get_logger
(
self
.
__class__
.
__name__
)
def
format_response
(
self
,
data
:
Dict
[
str
,
Any
],
chart
:
Any
)
->
Dict
[
str
,
Any
]:
"""格式化返回结果"""
self
.
logger
.
debug
(
"格式化返回结果"
)
return
{
'data'
:
data
,
# 'chart': chart,
'summary'
:
self
.
_generate_summary
(
data
)
}
def
_generate_summary
(
self
,
data
:
Dict
[
str
,
Any
])
->
str
:
"""生成数据分析总结文本"""
# 保持原有的summary生成逻辑
self
.
logger
.
debug
(
"生成数据分析总结"
)
if
'trend_data'
in
data
:
# 全国趋势数据
return
(
summary
=
(
f
"在{data['date_range']['start']}至{data['date_range']['end']}期间,"
f
"全国平均在线率为{data['statistics']['average_rate']:.2
%
},"
f
"最高达到{data['statistics']['max_rate']:.2
%
},"
...
...
@@ -35,27 +39,29 @@ class BaseRateTool(BaseTool):
)
elif
'rankings'
in
data
:
# 排名数据
if
'total_provinces'
in
data
:
# 省份排名
return
(
summary
=
(
f
"共分析了{data['total_provinces']}个省份的在线率数据,"
f
"平均在线率为{data['average_rate']:.2
%
}。"
f
"{data['best_province']['name']}的表现最好,"
f
"在线率达到{data['best_province']['rate']:.2
%
}。"
)
else
:
# 厂商排名
return
(
summary
=
(
f
"共分析了{data['total_manufacturers']}个厂商的在线率数据,"
f
"平均在线率为{data['average_rate']:.2
%
}。"
f
"{data['best_manufacturer']['name']}的表现最好,"
f
"在线率达到{data['best_manufacturer']['rate']:.2
%
}。"
)
else
:
# 地区在线率数据
return
(
summary
=
(
f
"{data['region']}在{data['date_range']['start']}至{data['date_range']['end']}期间,"
f
"平均在线率为{data['average_rate']:.2
%
},"
f
"最高达到{data['max_rate']:.2
%
},"
f
"最低为{data['min_rate']:.2
%
}。"
f
"平均设备数约{int(data['total_devices']):,}台。"
)
self
.
logger
.
debug
(
f
"生成的总结: {summary}"
)
return
summary
class
RegionRateArgs
(
BaseModel
):
"""地区在线率查询参数"""
...
...
@@ -73,48 +79,72 @@ class RegionRateTool(BaseRateTool):
def
__init__
(
self
,
base_url
:
str
=
const_base_url
,
**
data
):
super
()
.
__init__
(
**
data
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
self
.
logger
.
info
(
f
"初始化 RegionRateTool,base_url: {base_url}"
)
def
_run
(
self
,
start_time
:
str
,
end_time
:
str
,
region_name
:
str
=
""
)
->
Dict
[
str
,
Any
]:
return
self
.
get_region_online_rate
(
start_time
,
end_time
,
region_name
)
def
get_region_online_rate
(
self
,
start_time
:
str
,
end_time
:
str
,
region_name
:
str
=
""
)
->
Dict
[
str
,
Any
]:
# 查询数据
agent_start
=
time
.
time
()
print
(
f
"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}"
)
self
.
logger
.
info
(
f
"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}"
)
code
=
""
if
region_name
!=
""
:
self
.
logger
.
debug
(
f
"查找区域代码: {region_name}"
)
codes
=
code_tool
.
find_code
(
region_name
)
if
codes
is
None
or
len
(
codes
)
==
0
:
return
{
'code'
:
400
,
'message'
:
f
'未找到匹配的区域代码: {region_name}'
}
error_msg
=
f
'未找到匹配的区域代码: {region_name}'
self
.
logger
.
warning
(
error_msg
)
return
{
'code'
:
400
,
'message'
:
error_msg
}
code
=
codes
[
0
][
1
]
print
(
code
)
start
=
time
.
time
()
df
=
self
.
client
.
query_rates_sync
(
code
,
start_time
,
end_time
)
end
=
time
.
time
()
print
(
f
"query_rates_sync client spent time:{end-start}"
)
print
(
f
"地区在线率接口调用结果: {df}"
)
# 准备数据
if
df
.
type
!=
1
or
df
.
resultdata
is
None
or
len
(
df
.
resultdata
)
==
0
:
return
{
'code'
:
400
,
'message'
:
f
'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
}
print
(
f
"地区在线率查询结果: {df.resultdata}"
)
data
=
{
'region'
:
region_name
,
'region_code'
:
code
,
'rate_data'
:
df
.
resultdata
,
'date_range'
:
{
'start'
:
start_time
,
'end'
:
end_time
self
.
logger
.
debug
(
f
"找到区域代码: {code}"
)
try
:
start
=
time
.
time
()
df
=
self
.
client
.
query_rates_sync
(
code
,
start_time
,
end_time
)
query_time
=
time
.
time
()
-
start
self
.
logger
.
debug
(
f
"API调用耗时: {query_time:.2f}秒"
)
if
df
.
type
!=
1
or
df
.
resultdata
is
None
or
len
(
df
.
resultdata
)
==
0
:
error_msg
=
f
'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
self
.
logger
.
warning
(
error_msg
)
return
{
'code'
:
400
,
'message'
:
error_msg
}
self
.
logger
.
debug
(
f
"查询结果: {df.resultdata}"
)
markdown
=
self
.
to_markdown
(
df
.
resultdata
)
data
=
{
'region'
:
region_name
,
'region_code'
:
code
,
'rate_data'
:
df
.
resultdata
,
'markdown'
:
markdown
,
'date_range'
:
{
'start'
:
start_time
,
'end'
:
end_time
}
}
}
end
=
time
.
time
()
print
(
f
"once agent spent time:{end - agent_start}"
)
return
data
total_time
=
time
.
time
()
-
agent_start
self
.
logger
.
info
(
f
"查询完成,总耗时: {total_time:.2f}秒"
)
return
data
except
Exception
as
e
:
self
.
logger
.
error
(
f
"查询失败: {str(e)}"
,
exc_info
=
True
)
raise
def
to_markdown
(
self
,
data
:
List
[
Dict
[
str
,
Any
]])
->
str
:
"""将数据转换为 markdown 表格"""
self
.
logger
.
debug
(
"开始生成 markdown 表格"
)
markdown
=
"""
| 序号 | 日期 | 在线率 |
| --- | --- | --- |
"""
for
index
,
row
in
enumerate
(
data
):
markdown
+=
f
"| {index+1} | {row['name']} | {row['rate']} |
\n
"
self
.
logger
.
debug
(
"markdown 表格生成完成"
)
return
markdown
class
RankingRateArgs
(
BaseModel
):
"""排名查询参数"""
...
...
@@ -130,6 +160,7 @@ class RankingRateTool(BaseRateTool):
def
__init__
(
self
,
base_url
:
str
=
const_base_url
,
**
data
):
super
()
.
__init__
(
**
data
)
self
.
client
=
RateClient
(
base_url
=
base_url
)
self
.
logger
.
info
(
f
"初始化 RankingRateTool,base_url: {base_url}"
)
def
_run
(
self
,
rate_type
:
int
)
->
Dict
[
str
,
Any
]:
return
self
.
get_ranking_data
(
rate_type
)
...
...
@@ -142,44 +173,75 @@ class RankingRateTool(BaseRateTool):
def
_get_province_ranking
(
self
)
->
Dict
[
str
,
Any
]:
"""获取省份在线率排名"""
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
1
)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
if
df
.
type
!=
1
or
df
.
resultdata
is
None
or
len
(
df
.
resultdata
)
==
0
:
return
{
'code'
:
400
,
'message'
:
f
'未找到省份在线率排名数据,请检查是否有相关数据权限'
self
.
logger
.
info
(
"开始查询省份在线率排名"
)
try
:
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
1
)
if
df
.
type
!=
1
or
df
.
resultdata
is
None
or
len
(
df
.
resultdata
)
==
0
:
error_msg
=
'未找到省份在线率排名数据,请检查是否有相关数据权限'
self
.
logger
.
warning
(
error_msg
)
return
{
'code'
:
400
,
'message'
:
error_msg
}
self
.
logger
.
debug
(
f
"省份排名数据: {df.resultdata}"
)
markdown
=
self
.
to_markdown
(
df
.
resultdata
)
data
=
{
'rankings'
:
df
.
resultdata
,
'total_provinces'
:
len
(
df
.
resultdata
),
'best_province'
:
{
'name'
:
df
.
resultdata
[
0
][
'name'
],
'rate'
:
df
.
resultdata
[
0
][
'onlineRate'
]
},
'markdown'
:
markdown
}
print
(
f
"省份在线率排名数据: {df.resultdata}"
)
data
=
{
'rankings'
:
df
.
resultdata
,
'total_provinces'
:
len
(
df
.
resultdata
),
'best_province'
:
{
'name'
:
df
.
resultdata
[
0
][
'name'
],
'rate'
:
df
.
resultdata
[
0
][
'onlineRate'
]
}
}
return
data
self
.
logger
.
info
(
f
"查询完成,共 {len(df.resultdata)} 个省份"
)
return
data
except
Exception
as
e
:
self
.
logger
.
error
(
f
"查询省份排名失败: {str(e)}"
,
exc_info
=
True
)
raise
def
_get_manufacturer_ranking
(
self
)
->
Dict
[
str
,
Any
]:
"""获取厂商在线率排名"""
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
2
)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
if
df
.
type
!=
1
or
df
.
resultdata
is
None
or
len
(
df
.
resultdata
)
==
0
:
return
{
'code'
:
400
,
'message'
:
f
'未找到厂商在线率排名数据,请检查是否有相关数据权限'
}
print
(
f
"厂商在线率排名数据: {df.resultdata}"
)
data
=
{
'rankings'
:
df
.
resultdata
,
'total_manufacturers'
:
len
(
df
.
resultdata
),
'best_manufacturer'
:
{
'name'
:
df
.
resultdata
[
0
][
'name'
],
'rate'
:
df
.
resultdata
[
0
][
'onlineRate'
]
self
.
logger
.
info
(
"开始查询厂商在线率排名"
)
try
:
df
=
self
.
client
.
query_rates_ranking_sync
(
rank_type
=
2
)
if
df
.
type
!=
1
or
df
.
resultdata
is
None
or
len
(
df
.
resultdata
)
==
0
:
error_msg
=
'未找到厂商在线率排名数据,请检查是否有相关数据权限'
self
.
logger
.
warning
(
error_msg
)
return
{
'code'
:
400
,
'message'
:
error_msg
}
self
.
logger
.
debug
(
f
"厂商排名数据: {df.resultdata}"
)
markdown
=
self
.
to_markdown
(
df
.
resultdata
)
data
=
{
'rankings'
:
df
.
resultdata
,
'total_manufacturers'
:
len
(
df
.
resultdata
),
'best_manufacturer'
:
{
'name'
:
df
.
resultdata
[
0
][
'name'
],
'rate'
:
df
.
resultdata
[
0
][
'onlineRate'
]
},
'markdown'
:
markdown
}
}
self
.
logger
.
info
(
f
"查询完成,共 {len(df.resultdata)} 个厂商"
)
return
data
except
Exception
as
e
:
self
.
logger
.
error
(
f
"查询厂商排名失败: {str(e)}"
,
exc_info
=
True
)
raise
def
to_markdown
(
self
,
data
:
List
[
Dict
[
str
,
Any
]])
->
str
:
"""将数据转换为 markdown 表格"""
self
.
logger
.
debug
(
"开始生成 markdown 表格"
)
markdown
=
"""
| 序号 | 名称 | 全称 | 在线率 |
| --- | --- | --- | --- |
"""
for
index
,
row
in
enumerate
(
data
):
markdown
+=
f
"| {index+1} | {row['name']} | {row['fullname']} | {row['onlineRate']} |
\n
"
return
data
\ No newline at end of file
self
.
logger
.
debug
(
"markdown 表格生成完成"
)
return
markdown
src/server/agent_rate.py
View file @
da717251
...
...
@@ -185,10 +185,9 @@ class RateAgentV3:
def
run
(
self
,
input
:
str
):
picker_result
=
self
.
picker
.
pick
(
input
)
res
=
self
.
runner
.
run
(
input
,
picker_result
[
"tool"
],
picker_result
[
"params"
])
output
=
f
"相关数据如下:
\n
{res['table']}
\n\n
{res['output']}"
return
{
"input"
:
input
,
"output"
:
res
.
content
,
"tool"
:
picker_result
[
"tool"
],
"params"
:
picker_result
[
"params"
]
"output"
:
output
}
src/server/tool_picker.py
View file @
da717251
...
...
@@ -5,7 +5,7 @@ from langchain_core.prompts import PromptTemplate,ChatPromptTemplate,SystemMessa
from
langchain.tools.render
import
render_text_description_and_args
from
langchain_core.output_parsers
import
JsonOutputParser
as
JSONOutputParser
from
langchain_core.tools
import
BaseTool
from
..utils.logger
import
get_logger
PICKER_SYSTEM_PROMPT
=
"""
你是一个智能工具选择助手,你需要根据用户的问题选择最合适的工具,并提取出工具所需的参数。
...
...
@@ -38,6 +38,8 @@ class ToolPicker:
def
__init__
(
self
,
llm
,
tools
:
List
):
self
.
tools
=
tools
self
.
llm
=
llm
self
.
logger
=
get_logger
(
"ToolPicker"
)
date_now
=
datetime
.
now
()
.
strftime
(
"
%
Y-
%
m-
%
d"
)
picker_human
=
f
"今天是{date_now}
\n\n
{PICKER_HUMAN_PROMPT}"
prompt
=
ChatPromptTemplate
.
from_messages
([
...
...
@@ -50,9 +52,14 @@ class ToolPicker:
self
.
chain
=
prompt
|
self
.
llm
|
JSONOutputParser
()
def
pick
(
self
,
input
:
str
):
print
(
input
)
return
self
.
chain
.
invoke
({
"input"
:
input
})
self
.
logger
.
info
(
f
"Received input: {input}"
)
try
:
result
=
self
.
chain
.
invoke
({
"input"
:
input
})
self
.
logger
.
info
(
f
"Selected tool: {result['tool']} with params: {result['params']}"
)
return
result
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error picking tool: {str(e)}"
)
raise
RUNNER_SYSTEM_PROMPT
=
"""
你是一个擅长根据工具执行结果回答用户问题的助手。
...
...
@@ -62,10 +69,6 @@ RUNNER_SYSTEM_PROMPT = """
2. 根据用户的问题,解读工具执行结果,进行简要的分析说明
3. 返回用户问题的答案
请遵循以下规则:
- 工具执行结果中的数据必须使用 markdown 表格展示
- 确保数据的完整性, 不要遗漏数据
- 表格中的数据只能来源于工具执行结果
"""
...
...
@@ -79,6 +82,7 @@ class ToolRunner:
def
__init__
(
self
,
llm
,
tools
:
Dict
[
str
,
BaseTool
]):
self
.
tools
=
tools
self
.
llm
=
llm
self
.
logger
=
get_logger
(
"ToolRunner"
)
prompt
=
ChatPromptTemplate
.
from_messages
([
SystemMessagePromptTemplate
.
from_template
(
RUNNER_SYSTEM_PROMPT
),
...
...
@@ -86,11 +90,34 @@ class ToolRunner:
])
self
.
chain
=
prompt
|
self
.
llm
def
run
(
self
,
input
:
str
,
tool_name
:
str
,
params
:
Dict
):
self
.
logger
.
info
(
f
"Running tool '{tool_name}' with params: {params}"
)
if
tool_name
not
in
self
.
tools
:
self
.
logger
.
error
(
f
"Tool {tool_name} not found"
)
raise
ValueError
(
f
"Tool {tool_name} not found"
)
tool
=
self
.
tools
[
tool_name
]
result
=
tool
.
invoke
(
params
)
return
self
.
chain
.
invoke
({
"input"
:
input
,
"result"
:
result
})
try
:
tool
=
self
.
tools
[
tool_name
]
self
.
logger
.
info
(
f
"Invoking tool {tool_name}"
)
result
=
tool
.
invoke
(
params
)
self
.
logger
.
debug
(
f
"Tool result: {result}"
)
table
=
""
if
"markdown"
in
result
:
table
=
result
[
"markdown"
]
del
result
[
"markdown"
]
self
.
logger
.
info
(
"Getting LLM interpretation"
)
llm_result
=
self
.
chain
.
invoke
({
"input"
:
input
,
"result"
:
result
})
response
=
{
"output"
:
llm_result
.
content
,
"table"
:
table
}
self
.
logger
.
info
(
"Tool execution completed successfully"
)
return
response
except
Exception
as
e
:
self
.
logger
.
error
(
f
"Error running tool: {str(e)}"
)
raise
src/utils/logger.py
0 → 100644
View file @
da717251
import
logging
import
os
from
datetime
import
datetime
def
setup_logging
(
log_level
=
logging
.
INFO
,
log_dir
=
"logs"
):
"""
设置统一的日志配置
Args:
log_level: 日志级别,默认为 INFO
log_dir: 日志文件目录,默认为 logs
"""
# 创建日志目录
if
not
os
.
path
.
exists
(
log_dir
):
os
.
makedirs
(
log_dir
)
# 生成日志文件名,包含日期
log_file
=
os
.
path
.
join
(
log_dir
,
f
"app_{datetime.now().strftime('
%
Y
%
m
%
d')}.log"
)
# 配置根日志记录器
logging
.
basicConfig
(
level
=
log_level
,
format
=
'
%(asctime)
s -
%(name)
s -
%(levelname)
s -
%(message)
s'
,
handlers
=
[
# 输出到控制台
logging
.
StreamHandler
(),
# 输出到文件
logging
.
FileHandler
(
log_file
,
encoding
=
'utf-8'
)
]
)
# 设置第三方库的日志级别
logging
.
getLogger
(
"httpx"
)
.
setLevel
(
logging
.
WARNING
)
logging
.
getLogger
(
"urllib3"
)
.
setLevel
(
logging
.
WARNING
)
logging
.
info
(
f
"日志配置完成,日志文件: {log_file}"
)
def
get_logger
(
name
):
"""
获取指定名称的日志记录器
Args:
name: 日志记录器名称
Returns:
logging.Logger: 日志记录器实例
"""
return
logging
.
getLogger
(
name
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment