from smolagents import Tool | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
class DataGraphTool(Tool):
    """Tool that renders the contents of a CSV file as a PNG graph."""

    name = "data_graph"
    description = "Create graphs from tabular data for better explainability."
    inputs = {
        "csv_path": {"type": "string", "description": "Path to a CSV file containing data."}
    }
    output_type = "string"

    def forward(self, csv_path: str) -> str:
        """Plot the CSV at *csv_path* and save the figure to disk.

        The file is loaded with ``pandas.read_csv`` and drawn with the
        default ``DataFrame.plot()`` call on a 10x6 inch figure.

        Args:
            csv_path: Path to the CSV file to read.

        Returns:
            The path of the written image, always ``"graph_output.png"``
            (relative to the current working directory).

        Raises:
            Whatever ``pandas.read_csv``, ``DataFrame.plot`` or
            ``Figure.savefig`` raise (e.g. a missing file or a frame that
            pandas cannot plot); errors are not caught here.
        """
        df = pd.read_csv(csv_path)
        plot_path = "graph_output.png"

        # Create the figure/axes explicitly and hand the axes to pandas.
        # Calling ``df.plot()`` without ``ax`` makes pandas open a *new*
        # figure, so a preceding ``plt.figure(figsize=...)`` would be left
        # empty and its size would never reach the saved image.
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            df.plot(ax=ax)
            # Save through the figure handle rather than pyplot's global
            # "current figure" so we always write the figure drawn above.
            fig.savefig(plot_path)
        finally:
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated tool calls do not accumulate open figures.
            plt.close(fig)
        return plot_path