GraphiqueAcademia / data_graph_tool.py
AxDutta's picture
Update data_graph_tool.py
eb578ca verified
raw
history blame contribute delete
584 Bytes
from smolagents import Tool
import pandas as pd
import matplotlib.pyplot as plt
class DataGraphTool(Tool):
    """smolagents tool that renders the contents of a CSV file as a chart.

    The chart is saved as ``graph_output.png`` in the current working
    directory, and that path is returned to the agent.
    """

    name = "data_graph"
    description = "Create graphs from tabular data for better explainability."
    inputs = {
        "csv_path": {"type": "string", "description": "Path to a CSV file containing data."}
    }
    output_type = "string"

    def forward(self, csv_path: str) -> str:
        """Plot the data in ``csv_path`` with ``DataFrame.plot`` and save it.

        Args:
            csv_path: Path to the CSV file, loaded with ``pandas.read_csv``.

        Returns:
            The path of the saved PNG image, ``"graph_output.png"``. It is
            relative to the current working directory and is overwritten
            on every call.

        Raises:
            Whatever ``pandas.read_csv`` raises for a missing or unreadable
            file, and whatever ``DataFrame.plot`` raises when there is
            nothing numeric to plot (a ``TypeError`` in current pandas).
        """
        # Load before creating any figure, so a bad path cannot leave a figure open.
        df = pd.read_csv(csv_path)
        plot_path = "graph_output.png"
        # Create the figure explicitly and give its axes to pandas. Calling
        # plt.figure() and then df.plot() without ``ax`` makes pandas open a
        # second figure, so the requested size was silently ignored and an
        # empty figure was leaked.
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            df.plot(ax=ax)
            fig.savefig(plot_path)
        finally:
            # pyplot keeps every figure alive until it is closed. Release this
            # one even if plotting fails, so repeated tool calls do not
            # accumulate figures.
            plt.close(fig)
        return plot_path