from typing import Type from pydantic import BaseModel, Field from langchain_core.tools import BaseTool class ChartArgs(BaseModel): title: str = Field(..., description="图表名称,用于显示在图表上方的字符串") chart_type: str = Field(..., description="图表类型,如 line, bar, scatter, pie 等") x: list = Field(..., description="x 轴数据,列表形式") y: list = Field(..., description="y 轴数据,列表形式") x_label: str = Field(..., description="x 轴标签") y_label: str = Field(..., description="y 轴标签") class Chart(BaseTool): name = "chart" description = "组装生成图表的中间数据。此工具生成的数据需要保存到 action_cache 中。键值为 'chart_data'。" args_schema: Type[BaseModel] = ChartArgs def _run( self, title: str,chart_type: str, x: list, y: list, x_label: str, y_label: str ) -> str: """Use the tool.""" result = { "title": title, "chart_type": chart_type, "x": x, "y": y, "x_label": x_label, "y_label": y_label, } return result # 生成图表 def chart_image(chart_data): """ 生成图表 Args: chart_data: dict 图表数据 { "title": str, 图表名称 "chart_type": str, 图表类型,如 line, bar, scatter, pie 等 "x": list, x 轴数据,列表形式 "y": list, y 轴数据,列表形式 "x_label": str, x 轴标签 "y_label": str, y 轴标签 } Returns: PIL Image 图表图片 """ import matplotlib.pyplot as plt plt.figure(figsize=(10, 6)) match chart_data["chart_type"]: case "line": plt.plot(chart_data["x"], chart_data["y"]) case "bar": plt.bar(chart_data["x"], chart_data["y"]) case "scatter": plt.scatter(chart_data["x"], chart_data["y"]) case "pie": plt.pie(chart_data["y"], labels=chart_data["x"], autopct="%1.1f%%") case _: raise ValueError("Invalid chart type") plt.xlabel(chart_data["x_label"]) plt.ylabel(chart_data["y_label"]) plt.title(chart_data["name"]) # plt.show() from io import BytesIO buf = BytesIO() plt.savefig(buf, format="png") from PIL import Image image = Image.open(buf) # image.show() return image