Commit df832ca9 by tinywell

工具调试

parent 23a22fdc
......@@ -7,7 +7,8 @@ from urllib.parse import urljoin
# 泛型类型定义
T = TypeVar('T')
const_base_url = "http://172.30.0.37:30007"
# const_base_url = "http://172.30.0.37:30007"
const_base_url = "http://localhost:5001"
const_url_point = "/cigem/getMonitorPointAll"
const_url_rate = "/cigem/getAvgOnlineRate"
const_url_rate_ranking = "/cigem/getOnlineRateRank"
......
......@@ -62,12 +62,12 @@ class BaseRateTool(BaseTool):
class RegionRateArgs(BaseModel):
"""地区在线率查询参数"""
region_name: str = Field(..., description="地区名称")
region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串")
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
class RegionRateTool(BaseRateTool):
"""查询特定地区设备在线率的工具"""
"""查询全国或者特定地区设备在线率的工具"""
name = "region_online_rate"
description = "查询指定地区在指定时间段内的设备在线率"
args_schema: Type[BaseModel] = RegionRateArgs
......@@ -77,27 +77,33 @@ class RegionRateTool(BaseRateTool):
super().__init__(**data)
self.client = RateClient(base_url=base_url)
def _run(self, region_name: str, start_time: str, end_time: str) -> Dict[str, Any]:
return self.get_region_online_rate(region_name, start_time, end_time)
def _run(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]:
return self.get_region_online_rate(start_time, end_time, region_name)
def get_region_online_rate(self, region_name: str, start_time: str, end_time: str) -> Dict[str, Any]:
def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]:
# 查询数据
code = code_tool.find_code(region_name)
if not code:
print(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}")
code = ""
if region_name != "":
codes = code_tool.find_code(region_name)
if codes is None or len(codes) == 0:
return {
'code': 400,
'message': f'未找到匹配的区域代码: {region_name}'
}
df = self.client.query_rates_sync({
'startDate': start_time,
'endDate': end_time,
'areaCode': code[0][1]
})
code = codes[0][1]
df = self.client.query_rates_sync(code, start_time, end_time)
print(f"地区在线率接口调用结果: {df}")
# 准备数据
if df.type != 1 or len(df.resultdata) == 0:
return {
'code': 400,
'message': f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
}
print(f"地区在线率查询结果: {df.resultdata}")
data = {
'region': region_name,
'region_code': code[0][1],
'region_code': code,
'rate_data': df.resultdata,
'date_range': {
'start': start_time,
......@@ -136,6 +142,12 @@ class RankingRateTool(BaseRateTool):
df = self.client.query_rates_ranking_sync(rank_type=1)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
if df.type != 1 or len(df.resultdata) == 0:
return {
'code': 400,
'message': f'未找到省份在线率排名数据,请检查是否有相关数据权限'
}
print(f"省份在线率排名数据: {df.resultdata}")
data = {
'rankings': df.resultdata,
'total_provinces': len(df.resultdata),
......@@ -150,9 +162,14 @@ class RankingRateTool(BaseRateTool):
def _get_manufacturer_ranking(self) -> Dict[str, Any]:
"""获取厂商在线率排名"""
df = self.client.query_rates_ranking_sync(rank_type=2)
print("厂商数据:", df.resultdata)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据
if df.type != 1 or len(df.resultdata) == 0:
return {
'code': 400,
'message': f'未找到厂商在线率排名数据,请检查是否有相关数据权限'
}
print(f"厂商在线率排名数据: {df.resultdata}")
data = {
'rankings': df.resultdata,
'total_manufacturers': len(df.resultdata),
......@@ -163,29 +180,3 @@ class RankingRateTool(BaseRateTool):
}
return data
\ No newline at end of file
class NationalTrendArgs(BaseModel):
"""全国趋势查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
class NationalTrendTool(BaseRateTool):
"""查询全国在线率趋势的工具"""
name = "national_online_trend"
description = "查询全国范围内设备在线率的变化趋势"
args_schema: Type[BaseModel] = NationalTrendArgs
def _run(self, start_time: str, end_time: str) -> Dict[str, Any]:
return self.get_national_trend(start_time, end_time)
def get_national_trend(self, start_time: str, end_time: str) -> Dict[str, Any]:
"""
接口三:获取全国在线率趋势
"""
# 准备数据
data = {
}
return self.format_response(data, fig)
\ No newline at end of file
......@@ -18,13 +18,37 @@ app.add_middleware(
allow_headers=["*"], # 允许所有HTTP头
)
global base_llm
global base_llm, tool_base_url
base_llm = None
tool_base_url = None
class AgentManager:
def __init__(self):
self.llm = None
self.agent = None
def initialize(self, api_key: str, api_base: str, model_name: str, tool_base_url: str):
self.llm = ChatOpenAI(
openai_api_key=api_key,
openai_api_base=api_base,
model_name=model_name,
verbose=True
)
self.agent = new_rate_agent(self.llm,verbose=True,tool_base_url=tool_base_url)
def get_llm(self):
return self.llm
def get_agent(self):
return self.agent
agent_manager = AgentManager()
@app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
agent = new_rate_agent(base_llm,verbose=True)
agent = agent_manager.get_agent()
try:
res = agent.exec(prompt_args={"input": chat_request.query})
except Exception as e:
......@@ -38,6 +62,8 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
'data': res
}
if __name__ == "__main__":
# 参数解析
parser = argparse.ArgumentParser(description="启动API服务")
......@@ -45,13 +71,14 @@ if __name__ == "__main__":
parser.add_argument("--host", type=str, default='0.0.0.0', help="API服务地址")
parser.add_argument("--llm", type=str, default='Qwen2-7B', help="API服务地址")
parser.add_argument("--api_base", type=str, default='http://192.168.10.14:8000/v1', help="API服务地址")
parser.add_argument("--tool_base_url", type=str, default='http://localhost:5001', help="API服务地址")
args = parser.parse_args()
base_llm = ChatOpenAI(
openai_api_key='xxxxxxxxxxxxx',
openai_api_base=args.api_base,
agent_manager.initialize(
api_key='xxxxxxxxxxxxx',
api_base=args.api_base,
model_name=args.llm,
verbose=True
tool_base_url=args.tool_base_url
)
uvicorn.run(app, host=args.host, port=args.port)
......@@ -14,7 +14,7 @@ from langchain.agents.format_scratchpad.openai_tools import (
from langchain import hub
from src.agent.tool_rate import RegionRateTool,RankingRateTool,NationalTrendTool
from src.agent.tool_rate import RegionRateTool,RankingRateTool
from src.agent.tool_monitor import MonitorPointTool
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
......@@ -78,7 +78,7 @@ ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及
注意事项:
- 时间格式统一使用:YYYY-MM-DD
- 地区名称需要包含行政级别(如:福建省、厦门市)
- 数据展示优先使用markdown 格式的数据表格,并配合文字说明
- 数据展示优先使用 markdown 格式的数据表格,并配合文字说明
- 百分比数据保留两位小数
您可以使用以下工具:
......@@ -119,7 +119,7 @@ Action:
你的回复格式为 Action:```$JSON_BLOB```然后 Observation。
"""
PROMPT_AGENT_HUMAN = """{input}\n\n {agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复。你的主要目标是帮助用户快速理解和分析设备在线率数据,提供准确、直观的分析结果)"""
PROMPT_AGENT_HUMAN = """{input}\n\n {agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复。工具返回的数据必须使用表格展示,包含在最终输出中)"""
PROMPT_AGENT_SYS_VARS = [ "tool_names", "tools"]
class RateAgentV2:
......@@ -143,11 +143,15 @@ class RateAgentV2:
def new_rate_agent(llm, verbose: bool = False,**args):
if args['tool_base_url']:
tool_base_url = args['tool_base_url']
else:
tool_base_url = const_base_url
tools = [
RegionRateTool(),
RankingRateTool(),
NationalTrendTool(),
MonitorPointTool()
RegionRateTool(base_url=tool_base_url),
RankingRateTool(base_url=tool_base_url),
MonitorPointTool(base_url=tool_base_url)
]
# 使用 LangChain 的工具调用代理
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment