"""Integration tests for ToolPicker's tool selection.

These tests talk to a real OpenAI-compatible chat endpoint (see the
``mock_llm`` fixture); despite the fixture name, nothing is mocked, so the
endpoint must be reachable for the tests to pass.
"""
import os
import sys

# Make the project root importable so ``src.*`` resolves. The original
# CWD-relative entry is kept for backward compatibility; the second entry is
# resolved relative to this file so the tests also work when pytest is launched
# from another directory.
# NOTE(review): assumes this file lives one directory below the project root.
sys.path.append("../")
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

import pytest
from unittest.mock import Mock
from langchain_openai import ChatOpenAI

from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool, RankingRateTool
from src.agent.tool_monitor import MonitorPointTool


@pytest.fixture
def mock_llm():
    """Return the chat model that ToolPicker is exercised against.

    Despite its name, this is a real ``ChatOpenAI`` client pointed at an
    OpenAI-compatible server. Endpoint, model and API key can be overridden
    with the ``TOOL_PICKER_LLM_BASE``, ``TOOL_PICKER_LLM_MODEL`` and
    ``TOOL_PICKER_LLM_KEY`` environment variables. The defaults point at a
    Qwen2-7B server on a private-network address.

    TODO: provide a genuine offline stub (e.g. ``Mock(spec=ChatOpenAI)`` with a
    canned ``invoke`` result). The shape of the response ToolPicker expects is
    not visible from this file, so confirm it before wiring one up.
    """
    return ChatOpenAI(
        openai_api_key=os.environ.get("TOOL_PICKER_LLM_KEY", "xxxxxx"),
        openai_api_base=os.environ.get(
            "TOOL_PICKER_LLM_BASE", "http://192.168.10.14:8000/v1"
        ),
        model_name=os.environ.get("TOOL_PICKER_LLM_MODEL", "Qwen2-7B"),
        verbose=True,
    )


# Parametrized scenarios: (user query, ideal tool call, expected tool name).
# The dates in the ideal calls assume "today" is 2024-11-13, the day these
# cases were written.
@pytest.mark.parametrize("query, expected_response, expected_tool", [
    (
        "请分析下今天全国各地区在线率情况",
        {"tool": "region_online_rate", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": ""}},
        "region_online_rate",
    ),
    (
        "请分析下今天甘肃省设备在线率情况",
        {"tool": "region_online_rate", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": "甘肃省"}},
        "region_online_rate",
    ),
    (
        "查询今年三季度甘肃省设备在线率情况",
        {"tool": "region_online_rate", "params": {"start_time": "2024-07-01", "end_time": "2024-09-30", "region_name": "甘肃省"}},
        "region_online_rate",
    ),
    (
        "查询2024年11月13日各地区排名情况",
        {"tool": "online_rate_ranking", "params": {"rate_type": "1"}},
        "online_rate_ranking",
    ),
    (
        "查询各厂商在线率排名情况",
        {"tool": "online_rate_ranking", "params": {"rate_type": "2"}},
        "online_rate_ranking",
    ),
    (
        "甘肃省监控点的状态如何?",
        {"tool": "monitor_points_query", "params": {"key": "甘肃省"}},
        "monitor_points_query",
    ),
])
def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool):
    """Check that ToolPicker routes *query* to *expected_tool*.

    ``expected_response`` records the full call the LLM would ideally produce.
    Only its ``tool`` name is asserted against the picker's result: the date
    parameters are tied to the day the cases were written, and the model's
    output may vary between runs.
    """
    # Guard the parametrize table itself so the two expectations can't drift.
    assert expected_response["tool"] == expected_tool

    # Build the candidate tools the picker chooses from.
    test_tools = [RegionRateTool(), RankingRateTool(), MonitorPointTool()]
    picker = ToolPicker(mock_llm, test_tools)
    result = picker.pick(query)

    # Verify the result; include query and result so LLM misroutes are easy to diagnose.
    assert isinstance(result, dict), f"query={query!r} returned {result!r}"
    assert result["tool"] == expected_tool, f"query={query!r} picked {result!r}"
    assert "params" in result, f"query={query!r} returned {result!r}"