diff --git a/src/agent/tool_chart.py b/src/agent/tool_chart.py index 74d5b65..c1f3a10 100644 --- a/src/agent/tool_chart.py +++ b/src/agent/tool_chart.py @@ -16,15 +16,16 @@ class Chart(BaseTool): args_schema: Type[BaseModel] = ChartArgs def _run( - self, name: str, x: list, y: list, x_label: str, y_label: str + self, title: str,chart_type: str, x: list, y: list, x_label: str, y_label: str ) -> str: """Use the tool.""" result = { - "name": name, + "title": title, + "chart_type": chart_type, "x": x, "y": y, "x_label": x_label, - "y_label": y_label + "y_label": y_label, } return result diff --git a/src/agent/tool_monitor.py b/src/agent/tool_monitor.py new file mode 100644 index 0000000..9c77b16 --- /dev/null +++ b/src/agent/tool_monitor.py @@ -0,0 +1,90 @@ +from typing import Dict, Any, Optional +from pydantic import BaseModel, Field +from langchain_core.tools import BaseTool +from src.server.http_tools import MonitorClient + +class MonitorPointArgs(BaseModel): + """监测点查询参数""" + key: str = Field(..., description="行政区划名称(省/市/区县级别均可)") + +class MonitorPointTool(BaseTool): + """查询监测点信息的工具""" + name = "monitor_points_query" + description = """查询指定行政区划的监测点信息。 + 可以查询任意省/市/区县级别的监测点数据。 + 输入参数为行政区划名称,如:湖南省、长沙市、岳麓区等。 + 返回该区域内的监测点列表,包含位置、经纬度等详细信息。 + """ + args_schema: type[BaseModel] = MonitorPointArgs + client: Any = Field(None, exclude=True) + + def __init__(self, base_url: str = "http://localhost:5001", **data): + """ + 初始化监测点查询工具 + + Args: + base_url: API服务器地址 + **data: 其他参数 + """ + super().__init__(**data) + self.client = MonitorClient(base_url=base_url) + + def _run(self, key: str) -> Dict[str, Any]: + """ + 执行监测点查询 + + Args: + key: 行政区划名称 + + Returns: + Dict: 包含查询结果的字典 + """ + try: + print(f"进入 monitor_points_query 工具, 查询监测点信息: {key}") + response = self.client.query_points_sync(key) + + # 格式化返回结果 + if not response.resultdata: + return { + "status": "success", + "message": f"未在{key}找到监测点信息", + "points": [] + } + + # 提取关键信息并格式化 + points_info = [] + for point in response.resultdata: + points_info.append({ + "名称": point.MONITORPOINTNAME, + "位置": point.LOCATION, + "经度": point.LONGITUDE, + "纬度": point.LATITUDE, + "海拔": point.ELEVATION, + "建设单位": point.BUILDUNIT, + "监测单位": point.MONITORUNIT + }) + + return { + "status": "success", + "message": f"在{key}找到{len(points_info)}个监测点", + "points": points_info + } + + except Exception as e: + return { + "status": "error", + "message": f"查询失败: {str(e)}", + "points": [] + } + + def _arun(self, key: str) -> Dict[str, Any]: + """ + 异步执行监测点查询 + + Args: + key: 行政区划名称 + + Returns: + Dict: 包含查询结果的字典 + """ + raise NotImplementedError("异步查询暂未实现") \ No newline at end of file diff --git a/src/mock/server.py b/src/mock/server.py new file mode 100644 index 0000000..56f3e78 --- /dev/null +++ b/src/mock/server.py @@ -0,0 +1,155 @@ +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import random +from typing import List, Optional, Dict + +app = FastAPI(title="Monitor Points API") + +# 多级行政区划数据结构 +area_data = { + "湖南省": { + "长沙市": ["芙蓉区", "天心区", "岳麓区", "开福区", "雨花区"], + "株洲市": ["天元区", "荷塘区", "芦淞区", "石峰区", "醴陵市"], + "湘潭市": ["雨湖区", "岳塘区", "湘乡市", "韶山市"], + }, + "湖北省": { + "武汉市": ["江岸区", "江汉区", "硚口区", "汉阳区", "武昌区"], + "宜昌市": ["西陵区", "伍家岗区", "点军区", "猇亭区"], + }, + "广东省": { + "广州市": ["越秀区", "海珠区", "荔湾区", "天河区", "白云区"], + "深圳市": ["福田区", "罗湖区", "南山区", "宝安区", "龙岗区"], + } +} + +# 用于生成随机地名的组件 +towns = ["茶山镇", "李畋镇", "泗汾镇", "东富镇", "新华镇", "龙泉镇"] +villages = ["大西垅村", "石桥村", "龙门村", "青山村", "凤凰村", "红星村"] + +# 定义请求和响应模型 +class QueryRequest(BaseModel): + key: str + +class MonitorPoint(BaseModel): + MONITORPOINTCODE: str + MONITORPOINTNAME: str + LOCATION: str + LATITUDE: str + LONGITUDE: str + ELEVATION: str + BUILDUNIT: str + MONITORUNIT: str + YWUNIT: str + SGDW: Optional[str] = None + MANUFACTURER: str = "" + +class QueryResponse(BaseModel): + type: int = 1 + resultcode: int + message: str = "" + resultdata: List[MonitorPoint] + +def find_area_info(key: str) -> tuple[Optional[str], Optional[str], Optional[str]]: + """ + 查找行政区划信息 + 返回: (省, 市, 区/县) + """ + # 移除可能包含的"省"、"市"、"区"等后缀 + search_key = key.replace("省", "").replace("市", "").replace("区", "").replace("县", "") + + # 在省级查找 + for province, cities in area_data.items(): + if search_key in province: + return province, None, None + # 在市级查找 + for city, districts in cities.items(): + if search_key in city: + return province, city, None + # 在区级查找 + for district in districts: + if search_key in district: + return province, city, district + + return None, None, None + +def generate_random_point(province: str, city: Optional[str] = None, district: Optional[str] = None) -> dict: + """生成随机的检测点数据,根据行政区划信息生成相应的位置""" + # 确定行政区划信息 + if not city: + city = random.choice(list(area_data[province].keys())) + if not district: + district = random.choice(area_data[province][city]) + + town = random.choice(towns) + village = random.choice(villages) + + # 根据不同省份设置不同的经纬度范围 + latitude_ranges = { + "湖南省": (24.6, 30.2), + "湖北省": (29.0, 33.2), + "广东省": (20.2, 25.5) + } + longitude_ranges = { + "湖南省": (108.8, 114.2), + "湖北省": (108.3, 116.1), + "广东省": (109.7, 117.3) + } + + lat_range = latitude_ranges.get(province, (20.0, 35.0)) + lng_range = longitude_ranges.get(province, (108.0, 118.0)) + + return { + "MONITORPOINTCODE": f"{random.randint(430000, 439999)}000{random.randint(100, 999)}", + "MONITORPOINTNAME": f"{city}{district}{town}{village}地质灾害隐患点", + "LOCATION": f"{province}{city}{district}{town}{village}", + "LATITUDE": f"{random.uniform(lat_range[0], lat_range[1]):.8f}", + "LONGITUDE": f"{random.uniform(lng_range[0], lng_range[1]):.8f}", + "ELEVATION": f"{random.uniform(30.0, 100.0):.4f}", + "BUILDUNIT": "省地质工程勘察院有限公司", + "MONITORUNIT": "省地质工程勘察院有限公司", + "YWUNIT": "致力工程科技有限公司", + "SGDW": None, + "MANUFACTURER": "" + } + +@app.post("/api/monitor/points", response_model=QueryResponse) +async def query_points(request: QueryRequest): + """ + 检测点查询接口 + + - **key**: 行政区划名称(省/市/区县级别均可) + """ + print(f"进入 query_points 接口, 查询监测点信息: {request.key}") + try: + province, city, district = find_area_info(request.key) + + if not province: + return QueryResponse( + type=1, + resultcode=1, + message="未找到匹配的行政区划", + resultdata=[] + ) + + # 生成1-5个随机检测点数据 + num_points = random.randint(1, 5) + points = [generate_random_point(province, city, district) for _ in range(num_points)] + response = QueryResponse( + type=1, + resultcode=1, + message="", + resultdata=points + ) + print(f"查询监测点信息成功, 返回结果: {response}") + return response + + except Exception as e: + print(f"查询监测点信息失败, 错误信息: {str(e)}") + raise HTTPException( + status_code=500, + detail=str(e) + ) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=5001) diff --git a/src/server/agent_rate.py b/src/server/agent_rate.py index f1eaef3..ee5cf23 100644 --- a/src/server/agent_rate.py +++ b/src/server/agent_rate.py @@ -16,6 +16,7 @@ from langchain import hub from src.agent.tool_rate import RegionRateTool,RankingRateTool,NationalTrendTool from src.agent.fake_data_rate import MockDBConnection +from src.agent.tool_monitor import MonitorPointTool def create_rate_agent(llm, tools: List[BaseTool],prompt: PromptTemplate = None, tools_renderer: ToolsRenderer = render_text_description_and_args, @@ -64,6 +65,7 @@ class RateAgent: yield step +# 适配 structured_chat_agent 的 prompt ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理设备在线率分析的AI助手。你可以通过调用专门的工具来分析和展示不同维度的在线率数据。 你可以处理以下三类核心任务: @@ -88,13 +90,13 @@ ONLINE_RATE_SYSTEM_PROMPT = """你是一个专门处理设备在线率分析的A 1. 理解用户意图,将用户问题映射到合适的分析类型 2. 确保必要参数完整,如果缺少参数要主动询问 3. 调用相应的分析工具获取数据 -4. 生成清晰的分析报告,包括数据解读和可视化图表 +4. 生成清晰的分析报告,包括数据解读和markdown 格式的数据表格 5. 对异常情况(如数据缺失、参数错误)提供友好的解释和建议 注意事项: - 时间格式统一使用:YYYY-MM-DD - 地区名称需要包含行政级别(如:福建省、厦门市) -- 数据展示优先使用图表,并配合文字说明 +- 数据展示优先使用markdown 格式的数据表格,并配合文字说明 - 百分比数据保留两位小数 您可以使用以下工具: @@ -180,8 +182,8 @@ PROMPT_AGENT_SYS_VARS = [ "tool_names", "tools"] class RateAgentV2: def __init__(self, llm, tools: List[BaseTool],prompt: PromptTemplate = None, verbose: bool = False,**args): prompt = ChatPromptTemplate.from_messages([ - SystemMessagePromptTemplate.from_template(ONLINE_RATE_SYSTEM_PROMPT), - MessagesPlaceholder(variable_name="chat_history", optional=True), + SystemMessagePromptTemplate.from_template(ONLINE_RATE_SYSTEM_PROMPT), + MessagesPlaceholder(variable_name="chat_history", optional=True), HumanMessagePromptTemplate.from_template(PROMPT_AGENT_HUMAN) ]) @@ -202,18 +204,10 @@ def new_rate_agent(llm, verbose: bool = False,**args): tools = [ RegionRateTool(db_connection=conn), RankingRateTool(db_connection=conn), - NationalTrendTool(db_connection=conn) + NationalTrendTool(db_connection=conn), + MonitorPointTool() ] - # prompt = ChatPromptTemplate.from_messages([ - # SystemMessagePromptTemplate.from_template(ONLINE_RATE_SYSTEM_PROMPT), - # MessagesPlaceholder(variable_name="chat_history", optional=True), - # HumanMessagePromptTemplate.from_template(PROMPT_AGENT_HUMAN) - # ]) - - # prompt = prompt.partial(tools=render_text_description_and_args(tools), tool_names=", ".join([t.name for t in tools])) - - # 使用 LangChain 的工具调用代理 agent = RateAgentV2(llm=llm, tools=tools, verbose=verbose, **args) return agent diff --git a/src/server/http_tools.py b/src/server/http_tools.py new file mode 100644 index 0000000..9a688d9 --- /dev/null +++ b/src/server/http_tools.py @@ -0,0 +1,115 @@ +import httpx +from typing import List, Optional, Dict +from pydantic import BaseModel +import asyncio +from urllib.parse import urljoin + +class MonitorPoint(BaseModel): + MONITORPOINTCODE: str + MONITORPOINTNAME: str + LOCATION: str + LATITUDE: str + LONGITUDE: str + ELEVATION: str + BUILDUNIT: str + MONITORUNIT: str + YWUNIT: str + SGDW: Optional[str] = None + MANUFACTURER: str = "" + +class QueryResponse(BaseModel): + type: int + resultcode: int + message: str + resultdata: List[MonitorPoint] + +class MonitorClient: + def __init__(self, base_url: str = "http://localhost:5001"): + """ + 初始化监测点查询客户端 + + Args: + base_url: API服务器基础URL + """ + self.base_url = base_url.rstrip('/') + self.timeout = 30.0 + + async def query_points(self, key: str) -> QueryResponse: + """ + 异步查询监测点信息 + + Args: + key: 行政区划关键字(省/市/区县级别均可) + + Returns: + QueryResponse: 查询响应对象 + + Raises: + httpx.HTTPError: 当HTTP请求失败时 + ValueError: 当响应数据格式不正确时 + """ + async with httpx.AsyncClient(timeout=self.timeout) as client: + url = urljoin(self.base_url, "/api/monitor/points") + response = await client.post(url, json={"key": key}) + response.raise_for_status() + return QueryResponse(**response.json()) + + def query_points_sync(self, key: str) -> QueryResponse: + """ + 同步查询监测点信息 + + Args: + key: 行政区划关键字(省/市/区县级别均可) + + Returns: + QueryResponse: 查询响应对象 + + Raises: + httpx.HTTPError: 当HTTP请求失败时 + ValueError: 当响应数据格式不正确时 + """ + with httpx.Client(timeout=self.timeout) as client: + url = urljoin(self.base_url, "/api/monitor/points") + response = client.post(url, json={"key": key}) + response.raise_for_status() + return QueryResponse(**response.json()) + +# 使用示例 +async def example_async_usage(): + client = MonitorClient() + try: + # 异步查询示例 + response = await client.query_points("湖南") + if response.resultcode == 1 and response.resultdata: + for point in response.resultdata: + print(f"监测点名称: {point.MONITORPOINTNAME}") + print(f"位置: {point.LOCATION}") + print(f"经纬度: {point.LONGITUDE}, {point.LATITUDE}") + print("---") + except httpx.HTTPError as e: + print(f"HTTP请求错误: {e}") + except Exception as e: + print(f"发生错误: {e}") + +def example_sync_usage(): + client = MonitorClient() + try: + # 同步查询示例 + response = client.query_points_sync("长沙") + if response.resultcode == 1 and response.resultdata: + for point in response.resultdata: + print(f"监测点名称: {point.MONITORPOINTNAME}") + print(f"位置: {point.LOCATION}") + print(f"经纬度: {point.LONGITUDE}, {point.LATITUDE}") + print("---") + except httpx.HTTPError as e: + print(f"HTTP请求错误: {e}") + except Exception as e: + print(f"发生错误: {e}") + +if __name__ == "__main__": + # 异步调用示例 + asyncio.run(example_async_usage()) + + # 同步调用示例 + example_sync_usage()