Commit df832ca9 by tinywell

工具调试

parent 23a22fdc
...@@ -7,7 +7,8 @@ from urllib.parse import urljoin ...@@ -7,7 +7,8 @@ from urllib.parse import urljoin
# 泛型类型定义 # 泛型类型定义
T = TypeVar('T') T = TypeVar('T')
const_base_url = "http://172.30.0.37:30007" # const_base_url = "http://172.30.0.37:30007"
const_base_url = "http://localhost:5001"
const_url_point = "/cigem/getMonitorPointAll" const_url_point = "/cigem/getMonitorPointAll"
const_url_rate = "/cigem/getAvgOnlineRate" const_url_rate = "/cigem/getAvgOnlineRate"
const_url_rate_ranking = "/cigem/getOnlineRateRank" const_url_rate_ranking = "/cigem/getOnlineRateRank"
......
...@@ -62,12 +62,12 @@ class BaseRateTool(BaseTool): ...@@ -62,12 +62,12 @@ class BaseRateTool(BaseTool):
class RegionRateArgs(BaseModel): class RegionRateArgs(BaseModel):
"""地区在线率查询参数""" """地区在线率查询参数"""
region_name: str = Field(..., description="地区名称") region_name: str = Field(..., description="地区名称,如果要查询全国数据,请输入空字符串")
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)") start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)") end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
class RegionRateTool(BaseRateTool): class RegionRateTool(BaseRateTool):
"""查询特定地区设备在线率的工具""" """查询全国或者特定地区设备在线率的工具"""
name = "region_online_rate" name = "region_online_rate"
description = "查询指定地区在指定时间段内的设备在线率" description = "查询指定地区在指定时间段内的设备在线率"
args_schema: Type[BaseModel] = RegionRateArgs args_schema: Type[BaseModel] = RegionRateArgs
...@@ -77,27 +77,33 @@ class RegionRateTool(BaseRateTool): ...@@ -77,27 +77,33 @@ class RegionRateTool(BaseRateTool):
super().__init__(**data) super().__init__(**data)
self.client = RateClient(base_url=base_url) self.client = RateClient(base_url=base_url)
def _run(self, region_name: str, start_time: str, end_time: str) -> Dict[str, Any]: def _run(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]:
return self.get_region_online_rate(region_name, start_time, end_time) return self.get_region_online_rate(start_time, end_time, region_name)
def get_region_online_rate(self, region_name: str, start_time: str, end_time: str) -> Dict[str, Any]: def get_region_online_rate(self, start_time: str, end_time: str, region_name: str="") -> Dict[str, Any]:
# 查询数据 # 查询数据
code = code_tool.find_code(region_name) print(f"查询地区在线率: {region_name}, 时间范围: {start_time} 至 {end_time}")
if not code: code = ""
if region_name != "":
codes = code_tool.find_code(region_name)
if codes is None or len(codes) == 0:
return { return {
'code': 400, 'code': 400,
'message': f'未找到匹配的区域代码: {region_name}' 'message': f'未找到匹配的区域代码: {region_name}'
} }
df = self.client.query_rates_sync({ code = codes[0][1]
'startDate': start_time, df = self.client.query_rates_sync(code, start_time, end_time)
'endDate': end_time, print(f"地区在线率接口调用结果: {df}")
'areaCode': code[0][1]
})
# 准备数据 # 准备数据
if df.type != 1 or len(df.resultdata) == 0:
return {
'code': 400,
'message': f'未找到{region_name}在{start_time}至{end_time}期间的数据,请检查是否有相关数据权限'
}
print(f"地区在线率查询结果: {df.resultdata}")
data = { data = {
'region': region_name, 'region': region_name,
'region_code': code[0][1], 'region_code': code,
'rate_data': df.resultdata, 'rate_data': df.resultdata,
'date_range': { 'date_range': {
'start': start_time, 'start': start_time,
...@@ -136,6 +142,12 @@ class RankingRateTool(BaseRateTool): ...@@ -136,6 +142,12 @@ class RankingRateTool(BaseRateTool):
df = self.client.query_rates_ranking_sync(rank_type=1) df = self.client.query_rates_ranking_sync(rank_type=1)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10] # df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据 # 准备数据
if df.type != 1 or len(df.resultdata) == 0:
return {
'code': 400,
'message': f'未找到省份在线率排名数据,请检查是否有相关数据权限'
}
print(f"省份在线率排名数据: {df.resultdata}")
data = { data = {
'rankings': df.resultdata, 'rankings': df.resultdata,
'total_provinces': len(df.resultdata), 'total_provinces': len(df.resultdata),
...@@ -150,9 +162,14 @@ class RankingRateTool(BaseRateTool): ...@@ -150,9 +162,14 @@ class RankingRateTool(BaseRateTool):
def _get_manufacturer_ranking(self) -> Dict[str, Any]: def _get_manufacturer_ranking(self) -> Dict[str, Any]:
"""获取厂商在线率排名""" """获取厂商在线率排名"""
df = self.client.query_rates_ranking_sync(rank_type=2) df = self.client.query_rates_ranking_sync(rank_type=2)
print("厂商数据:", df.resultdata)
# df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10] # df.resultdata = df.resultdata if len(df.resultdata) < 10 else df.resultdata[:10]
# 准备数据 # 准备数据
if df.type != 1 or len(df.resultdata) == 0:
return {
'code': 400,
'message': f'未找到厂商在线率排名数据,请检查是否有相关数据权限'
}
print(f"厂商在线率排名数据: {df.resultdata}")
data = { data = {
'rankings': df.resultdata, 'rankings': df.resultdata,
'total_manufacturers': len(df.resultdata), 'total_manufacturers': len(df.resultdata),
...@@ -163,29 +180,3 @@ class RankingRateTool(BaseRateTool): ...@@ -163,29 +180,3 @@ class RankingRateTool(BaseRateTool):
} }
return data return data
\ No newline at end of file
class NationalTrendArgs(BaseModel):
"""全国趋势查询参数"""
start_time: str = Field(..., description="开始时间 (YYYY-MM-DD)")
end_time: str = Field(..., description="结束时间 (YYYY-MM-DD)")
class NationalTrendTool(BaseRateTool):
"""查询全国在线率趋势的工具"""
name = "national_online_trend"
description = "查询全国范围内设备在线率的变化趋势"
args_schema: Type[BaseModel] = NationalTrendArgs
def _run(self, start_time: str, end_time: str) -> Dict[str, Any]:
return self.get_national_trend(start_time, end_time)
def get_national_trend(self, start_time: str, end_time: str) -> Dict[str, Any]:
"""
接口三:获取全国在线率趋势
"""
# 准备数据
data = {
}
return self.format_response(data, fig)
\ No newline at end of file
...@@ -18,13 +18,37 @@ app.add_middleware( ...@@ -18,13 +18,37 @@ app.add_middleware(
allow_headers=["*"], # 允许所有HTTP头 allow_headers=["*"], # 允许所有HTTP头
) )
global base_llm global base_llm, tool_base_url
base_llm = None base_llm = None
tool_base_url = None
class AgentManager:
def __init__(self):
self.llm = None
self.agent = None
def initialize(self, api_key: str, api_base: str, model_name: str, tool_base_url: str):
self.llm = ChatOpenAI(
openai_api_key=api_key,
openai_api_base=api_base,
model_name=model_name,
verbose=True
)
self.agent = new_rate_agent(self.llm,verbose=True,tool_base_url=tool_base_url)
def get_llm(self):
return self.llm
def get_agent(self):
return self.agent
agent_manager = AgentManager()
@app.post('/api/agent/rate') @app.post('/api/agent/rate')
def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)): def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
agent = new_rate_agent(base_llm,verbose=True) agent = agent_manager.get_agent()
try: try:
res = agent.exec(prompt_args={"input": chat_request.query}) res = agent.exec(prompt_args={"input": chat_request.query})
except Exception as e: except Exception as e:
...@@ -38,6 +62,8 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)): ...@@ -38,6 +62,8 @@ def rate(chat_request: GeoAgentRateRequest, token: str = Header(None)):
'data': res 'data': res
} }
if __name__ == "__main__": if __name__ == "__main__":
# 参数解析 # 参数解析
parser = argparse.ArgumentParser(description="启动API服务") parser = argparse.ArgumentParser(description="启动API服务")
...@@ -45,13 +71,14 @@ if __name__ == "__main__": ...@@ -45,13 +71,14 @@ if __name__ == "__main__":
parser.add_argument("--host", type=str, default='0.0.0.0', help="API服务地址") parser.add_argument("--host", type=str, default='0.0.0.0', help="API服务地址")
parser.add_argument("--llm", type=str, default='Qwen2-7B', help="API服务地址") parser.add_argument("--llm", type=str, default='Qwen2-7B', help="API服务地址")
parser.add_argument("--api_base", type=str, default='http://192.168.10.14:8000/v1', help="API服务地址") parser.add_argument("--api_base", type=str, default='http://192.168.10.14:8000/v1', help="API服务地址")
parser.add_argument("--tool_base_url", type=str, default='http://localhost:5001', help="API服务地址")
args = parser.parse_args() args = parser.parse_args()
base_llm = ChatOpenAI( agent_manager.initialize(
openai_api_key='xxxxxxxxxxxxx', api_key='xxxxxxxxxxxxx',
openai_api_base=args.api_base, api_base=args.api_base,
model_name=args.llm, model_name=args.llm,
verbose=True tool_base_url=args.tool_base_url
) )
uvicorn.run(app, host=args.host, port=args.port) uvicorn.run(app, host=args.host, port=args.port)
...@@ -14,7 +14,7 @@ from langchain.agents.format_scratchpad.openai_tools import ( ...@@ -14,7 +14,7 @@ from langchain.agents.format_scratchpad.openai_tools import (
from langchain import hub from langchain import hub
from src.agent.tool_rate import RegionRateTool,RankingRateTool,NationalTrendTool from src.agent.tool_rate import RegionRateTool,RankingRateTool
from src.agent.tool_monitor import MonitorPointTool from src.agent.tool_monitor import MonitorPointTool
# def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None, # def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None,
...@@ -78,7 +78,7 @@ ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及 ...@@ -78,7 +78,7 @@ ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理地质监测点信息及
注意事项: 注意事项:
- 时间格式统一使用:YYYY-MM-DD - 时间格式统一使用:YYYY-MM-DD
- 地区名称需要包含行政级别(如:福建省、厦门市) - 地区名称需要包含行政级别(如:福建省、厦门市)
- 数据展示优先使用markdown 格式的数据表格,并配合文字说明 - 数据展示优先使用 markdown 格式的数据表格,并配合文字说明
- 百分比数据保留两位小数 - 百分比数据保留两位小数
您可以使用以下工具: 您可以使用以下工具:
...@@ -119,7 +119,7 @@ Action: ...@@ -119,7 +119,7 @@ Action:
你的回复格式为 Action:```$JSON_BLOB```然后 Observation。 你的回复格式为 Action:```$JSON_BLOB```然后 Observation。
""" """
PROMPT_AGENT_HUMAN = """{input}\n\n {agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复。你的主要目标是帮助用户快速理解和分析设备在线率数据,提供准确、直观的分析结果)""" PROMPT_AGENT_HUMAN = """{input}\n\n {agent_scratchpad}\n (请注意,无论如何都要以 JSON 对象回复。工具返回的数据必须使用表格展示,包含在最终输出中)"""
PROMPT_AGENT_SYS_VARS = [ "tool_names", "tools"] PROMPT_AGENT_SYS_VARS = [ "tool_names", "tools"]
class RateAgentV2: class RateAgentV2:
...@@ -143,11 +143,15 @@ class RateAgentV2: ...@@ -143,11 +143,15 @@ class RateAgentV2:
def new_rate_agent(llm, verbose: bool = False,**args): def new_rate_agent(llm, verbose: bool = False,**args):
if args['tool_base_url']:
tool_base_url = args['tool_base_url']
else:
tool_base_url = const_base_url
tools = [ tools = [
RegionRateTool(), RegionRateTool(base_url=tool_base_url),
RankingRateTool(), RankingRateTool(base_url=tool_base_url),
NationalTrendTool(), MonitorPointTool(base_url=tool_base_url)
MonitorPointTool()
] ]
# 使用 LangChain 的工具调用代理 # 使用 LangChain 的工具调用代理
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment