from typing import Type
from pydantic import BaseModel, Field
from langchain_core.tools import BaseTool

class ChartArgs(BaseModel):
    # Argument schema for the ``chart`` tool; every field is required (``...``).
    # NOTE: documented with ``#`` comments on purpose -- pydantic copies a class
    # docstring into the model's JSON schema, and the field order and
    # ``description`` strings below are likewise part of that schema.
    # Chart title, displayed above the chart.
    title: str = Field(..., description="图表名称,用于显示在图表上方的字符串")
    # Chart kind, e.g. line, bar, scatter, pie.
    chart_type: str = Field(..., description="图表类型,如 line, bar, scatter, pie 等")
    # x-axis data, as a list.
    x: list = Field(..., description="x 轴数据,列表形式")
    # y-axis data, as a list.
    y: list = Field(..., description="y 轴数据,列表形式")
    # x-axis label.
    x_label: str = Field(..., description="x 轴标签")
    # y-axis label.
    y_label: str = Field(..., description="y 轴标签")

class Chart(BaseTool):
    """LangChain tool that packages chart parameters into a plain dict.

    The tool does no rendering itself: ``_run`` returns a dict with the keys
    ``title``, ``chart_type``, ``x``, ``y``, ``x_label`` and ``y_label``.
    """

    # ``name`` and ``description`` override annotated pydantic fields on
    # ``BaseTool``; they must be annotated here too, otherwise pydantic v2
    # refuses to build the class ("non-annotated attribute" error).
    name: str = "chart"
    description: str = "组装生成图表的中间数据。此工具生成的数据需要保存到 action_cache 中。键值为 'chart_data'。"
    args_schema: Type[BaseModel] = ChartArgs

    def _run(
        self,
        title: str,
        chart_type: str,
        x: list,
        y: list,
        x_label: str,
        y_label: str,
    ) -> dict:
        """Use the tool: return the chart arguments as a dict.

        Args:
            title: Chart title.
            chart_type: Chart kind, e.g. "line", "bar", "scatter", "pie".
            x: x-axis data.
            y: y-axis data.
            x_label: x-axis label.
            y_label: y-axis label.

        Returns:
            A dict holding each argument under the key of the same name.
        """
        return {
            "title": title,
            "chart_type": chart_type,
            "x": x,
            "y": y,
            "x_label": x_label,
            "y_label": y_label,
        }


# 生成图表
def chart_image(chart_data):
    """
    Render chart data into a PNG image.

    Args:
        chart_data: dict of chart data, e.g. the dict returned by the
            ``chart`` tool:
            {
                "title": str, chart title (the key "name" is accepted as a
                         fallback when "title" is absent)
                "chart_type": str, one of "line", "bar", "scatter", "pie"
                "x": list, x-axis data (slice labels for "pie")
                "y": list, y-axis data (slice values for "pie")
                "x_label": str, x-axis label
                "y_label": str, y-axis label
            }

    Returns:
        PIL Image: the rendered chart.

    Raises:
        ValueError: if ``chart_type`` is not one of the supported kinds.
        KeyError: if a required key is missing from ``chart_data``.
    """
    chart_type = chart_data["chart_type"]

    # One drawing function per supported chart kind, each drawing onto an Axes.
    plotters = {
        "line": lambda ax, x, y: ax.plot(x, y),
        "bar": lambda ax, x, y: ax.bar(x, y),
        "scatter": lambda ax, x, y: ax.scatter(x, y),
        "pie": lambda ax, x, y: ax.pie(y, labels=x, autopct="%1.1f%%"),
    }
    # Validate before importing matplotlib / allocating a figure so bad input
    # fails fast. Non-string (including unhashable) values are rejected too.
    plot = plotters.get(chart_type) if isinstance(chart_type, str) else None
    if plot is None:
        raise ValueError("Invalid chart type")

    from io import BytesIO

    import matplotlib.pyplot as plt
    from PIL import Image

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        plot(ax, chart_data["x"], chart_data["y"])
        ax.set_xlabel(chart_data["x_label"])
        ax.set_ylabel(chart_data["y_label"])
        if "title" in chart_data:
            title = chart_data["title"]
        else:
            # Fallback for dicts using "name", the key this function used to read.
            title = chart_data["name"]
        ax.set_title(title)

        buf = BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every open figure alive; close it even on error so
        # repeated calls don't leak figures.
        plt.close(fig)

    # Rewind so the image decoder reads the PNG from the start.
    buf.seek(0)
    return Image.open(buf)