# test_tool_picker.py — scenario tests for ToolPicker's tool selection.
import sys,os
sys.path.append("../")

import pytest
from unittest.mock import Mock

from langchain_openai import ChatOpenAI

from src.server.tool_picker import ToolPicker
from src.agent.tool_rate import RegionRateTool,RankingRateTool
from src.agent.tool_monitor import MonitorPointTool

@pytest.fixture
def mock_llm():
    """Provide the chat model that ToolPicker uses to choose a tool.

    Despite its name, this fixture does NOT return a mock. It builds a real
    ``ChatOpenAI`` client for an OpenAI-compatible endpoint, so every test
    that requests it makes live calls to that endpoint. The name is kept
    because the tests request the fixture as ``mock_llm``.

    Connection settings can be overridden through environment variables.
    The defaults are the values this file originally hard-coded:

    * ``TOOL_PICKER_LLM_API_KEY``  (default ``"xxxxxx"``)
    * ``TOOL_PICKER_LLM_API_BASE`` (default ``"http://192.168.10.14:8000/v1"``)
    * ``TOOL_PICKER_LLM_MODEL``    (default ``"Qwen2-7B"``)

    Returns:
        A configured ``ChatOpenAI`` instance with ``verbose=True``.
    """
    # A Mock with a single canned reply cannot serve these scenarios: they
    # expect three different tools depending on the query.
    return ChatOpenAI(
        openai_api_key=os.environ.get("TOOL_PICKER_LLM_API_KEY", "xxxxxx"),
        openai_api_base=os.environ.get(
            "TOOL_PICKER_LLM_API_BASE", "http://192.168.10.14:8000/v1"
        ),
        model_name=os.environ.get("TOOL_PICKER_LLM_MODEL", "Qwen2-7B"),
        verbose=True,
    )

# Parametrized scenarios. Each row has a natural-language query, the ideal
# tool call for it, and the tool name the picker must select.
# NOTE: several queries are relative to the run date ("今天" = today,
# "今年" = this year), but their expected params are pinned to 2024 dates.
# For that reason only the selected tool is compared below, not the params.
@pytest.mark.parametrize(
    "query, expected_response, expected_tool",
    [
        pytest.param(
            "请分析下今天全国各地区在线率情况",
            {"tool": "region_online_rate",
             "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": ""}},
            "region_online_rate",
            id="region-rate-national-today",
        ),
        pytest.param(
            "请分析下今天甘肃省设备在线率情况",
            {"tool": "region_online_rate",
             "params": {"start_time": "2024-11-13", "end_time": "2024-11-13", "region_name": "甘肃省"}},
            "region_online_rate",
            id="region-rate-gansu-today",
        ),
        pytest.param(
            "查询今年三季度甘肃省设备在线率情况",
            {"tool": "region_online_rate",
             "params": {"start_time": "2024-07-01", "end_time": "2024-09-30", "region_name": "甘肃省"}},
            "region_online_rate",
            id="region-rate-gansu-q3",
        ),
        pytest.param(
            "查询2024年11月13日各地区排名情况",
            {"tool": "online_rate_ranking", "params": {"rate_type": "1"}},
            "online_rate_ranking",
            id="ranking-by-region",
        ),
        pytest.param(
            "查询各厂商在线率排名情况",
            {"tool": "online_rate_ranking", "params": {"rate_type": "2"}},
            "online_rate_ranking",
            id="ranking-by-vendor",
        ),
        pytest.param(
            "甘肃省监控点的状态如何?",
            {"tool": "monitor_points_query", "params": {"key": "甘肃省"}},
            "monitor_points_query",
            id="monitor-points-gansu",
        ),
    ],
)
def test_tool_picker_scenarios(mock_llm, query, expected_response, expected_tool):
    """Check that ToolPicker routes ``query`` to ``expected_tool``.

    Args:
        mock_llm: LLM client from the ``mock_llm`` fixture, passed to ToolPicker.
        query: Natural-language request given to ``ToolPicker.pick``.
        expected_response: Ideal tool call for the query. Its ``"tool"`` entry
            must agree with ``expected_tool``. Its ``"params"`` are
            documentation only (see the note above the decorator).
        expected_tool: Tool name the picker must choose.
    """
    # Guard against an inconsistent scenario row before spending an LLM call.
    assert expected_response["tool"] == expected_tool, (
        f"inconsistent scenario for {query!r}: "
        f"{expected_response['tool']!r} != {expected_tool!r}"
    )

    # Build the candidate tools the picker chooses between.
    test_tools = [RegionRateTool(), RankingRateTool(), MonitorPointTool()]

    picker = ToolPicker(mock_llm, test_tools)
    result = picker.pick(query)

    # The result comes from a model call, so every failure message includes
    # the query and the raw result to make mismatches diagnosable.
    assert isinstance(result, dict), (
        f"pick({query!r}) returned {type(result).__name__}: {result!r}"
    )
    # .get() turns a missing "tool" key into an assertion failure, not a KeyError.
    assert result.get("tool") == expected_tool, (
        f"pick({query!r}) chose {result.get('tool')!r}, "
        f"expected {expected_tool!r}; full result: {result!r}"
    )
    assert "params" in result, (
        f"pick({query!r}) returned no 'params': {result!r}"
    )