test_tool_picker.py 5.33 KB
Newer Older
1 2 3 4 5 6 7 8 9
import sys,os
sys.path.append("../")

import pytest
from unittest.mock import Mock

from langchain_openai import ChatOpenAI

from src.server.tool_picker import ToolPicker
10
from src.agent.tool_rate import RegionRateTool,RankingRateTool
11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
from src.agent.tool_monitor import MonitorPointTool

@pytest.fixture
def mock_llm():
    # llm = Mock(spec=ChatOpenAI)
    # # 模拟 LLM 返回结果
    # llm.invoke.return_value = {"content": """{"tool": "RegionRateTool", "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": ""}}"""}
    # return llm

    llm = ChatOpenAI(
        openai_api_key="xxxxxx",
        openai_api_base="http://192.168.10.14:8000/v1",
        model_name="Qwen2-7B",
        verbose=True
    )
    return llm

# 使用参数化测试不同的场景
@pytest.mark.parametrize("query, expected_response, expected_tool", [
    (
        "请分析下今天全国各地区在线率情况",
32
        {"tool": "region_online_rate", "params": {"start_time": "2024-11-19", "end_time": "2024-11-19", "region_name": "", "month_required": False}},
33 34 35 36
        "region_online_rate"
    ),
    (
        "请分析下今天甘肃省设备在线率情况",
37
        {"tool": "region_online_rate", "params": {"start_time": "2024-11-19", "end_time": "2024-11-19", "region_name": "甘肃省","month_required":False }},
38 39 40 41
        "region_online_rate"
    ),
    (
        "查询今年三季度甘肃省设备在线率情况",
42
        {"tool": "region_online_rate", "params": {"start_time": "2024-07-01", "end_time": "2024-09-30", "region_name": "甘肃省","month_required":False }},
43 44 45 46
        "region_online_rate"
    ),
    (
        "查询2024年11月13日各地区排名情况",
47
        {"tool": "online_rate_ranking", "params": {"rate_type": 1}},
48 49 50 51
        "online_rate_ranking"
    ),
    (
        "查询各厂商在线率排名情况",
52
        {"tool": "online_rate_ranking", "params": {"rate_type": 2}},
53 54 55 56 57 58 59
        "online_rate_ranking"
    ),
    (
        "甘肃省监控点的状态如何?",
        {"tool": "monitor_points_query", "params": {"key": "甘肃省"}},
        "monitor_points_query"
    ),
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
    (
        "查询2024年甘肃省各月在线率",
        {"tool": "month_online_rate", "params": {"start_time": "2024-01-01", "end_time": "2024-12-31", "region_name": "甘肃省", "month_required": True}},
        "region_online_rate"
    ),
    (
        "2024年10月15日,成都市武侯区的设备在线率是多少?",
        {"tool": "region_online_rate", "params": {"start_time": "2024-10-15", "end_time": "2024-10-15", "region_name": "成都市武侯区", "month_required": False}},
        "region_online_rate"
    ),
    (
        "2024年,成都市武侯区的设备在线率是多少?",
        {"tool": "region_online_rate", "params": {"start_time": "2024-01-01",  "region_name": "成都市武侯区", "month_required": False}},
        "region_online_rate"
    ),
    (
        "2023年甘肃省每月的设备在线率分别是多少?",
        {"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "甘肃省", "month_required": True}},
        "region_online_rate"
    ),
    (
        "2023年甘肃省按月统计设备在线率?",
        {"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "region_name": "甘肃省", "month_required": True}},
        "region_online_rate"
    ),
    (
        "2023年全国每个月的设备在线率",
        {"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "region_name": "", "month_required": True}},
        "region_online_rate"
    ),
    (
        "2023年1月-2023年12月期间西藏实验点在线率是多少?",
        {"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": False}},
        "region_online_rate"
    ),
    (
        "2023年1月-2023年12月期间西藏实验点各月在线率是多少?",
        {"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True}},
        "region_online_rate"
    ),
    (
        "2022年各个月设备在线率统计;",
        {"tool": "region_online_rate", "params": {"start_time": "2022-01-01", "end_time": "2022-12-31", "region_name": "", "month_required": True}},
        "region_online_rate"
    ),
    (
        "2023年1月-2023年12月期间西藏实验点每个月在线率是多少?",
        {"tool": "region_online_rate", "params": {"start_time": "2023-01-01", "end_time": "2023-12-31", "region_name": "西藏", "month_required": True}},
        "region_online_rate"
    )
110 111 112 113 114 115 116 117
])

def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool):
    
    # 创建测试用的工具
    test_tools = [
        RegionRateTool(),
        RankingRateTool(),
118
        MonitorPointTool(),
119 120 121 122
    ]
    
    picker = ToolPicker(mock_llm, test_tools)
    result = picker.pick(query)
123
    # print(f"query: {query},  result: {result}")
124 125 126 127
    # 验证结果
    assert isinstance(result, dict)
    assert result["tool"] == expected_tool
    assert "params" in result
128 129 130 131
    for key, value in expected_response["params"].items():
        print(f"key: {key}, value: {value} , result_value: {result['params'][key]}")
        if key == "month_required" and value == True:
            assert result["params"][key] == value