# COLE / predictions / plot_results.py
# Author: Yurhu — "Initial snapshot upload" (commit 75ec748, verified)
import math
import matplotlib.pyplot as plt
import numpy as np
from predictions.convert_results_to_md import results_data
def extract_plot_task_data(all_tasks_data, task):
    """Reshape one task's results into per-metric series for a bar chart.

    Args:
        all_tasks_data: Mapping ``task -> {model -> {metric: value}}``, as
            built by ``extract_data_by_tasks``.
        task: Key of the task to extract; a missing key raises ``KeyError``.

    Returns:
        ``(models, task_metrics)`` where ``models`` lists the model names in
        the mapping's iteration order and ``task_metrics`` maps each metric
        name (in first-seen order across models) to a list of values aligned
        index-for-index with ``models``. A model that lacks a metric
        contributes ``math.nan`` at its position, so every series has the
        same length as ``models`` and lines up with the x positions.
    """
    task_data = all_tasks_data[task]
    models = list(task_data)
    # Ordered union of metric names across all models (dicts preserve
    # insertion order, so this keeps first-seen order without duplicates).
    metric_names = dict.fromkeys(
        metric for model_metrics in task_data.values() for metric in model_metrics
    )
    # One series per metric, padded with NaN where a model has no value, so
    # bars are never shifted onto the wrong model.
    task_metrics = {
        metric: [task_data[model].get(metric, math.nan) for model in models]
        for metric in metric_names
    }
    return models, task_metrics
def extract_data_by_tasks(datas=None):
    """Group numeric metric values by task name, then by model name.

    Args:
        datas: Optional iterable of per-model result records. When omitted
            (the default), records are loaded with ``results_data()``. Each
            record is read as ``record["model_name"]`` and
            ``record["tasks"]``; ``tasks`` is iterated as a list of
            ``{task_name: {<key>: {metric: value}}}`` mappings (shape
            inferred from the indexing below — confirm against
            ``results_data``).

    Returns:
        dict mapping ``task_name -> {model_name -> {metric: value}}``,
        keeping only values that are ``int`` or ``float`` (``bool`` also
        passes, being an ``int`` subclass). The result is also printed.
    """
    if datas is None:
        datas = results_data()
    all_tasks_data = {}
    for data in datas:
        model_name = data["model_name"]
        for task in data["tasks"]:
            for task_name, metrics_values in task.items():
                for metric_values in metrics_values.values():
                    # NOTE(review): each inner mapping replaces the previous
                    # one for this (task, model) pair, so only the last one
                    # survives — confirm this is intended rather than a merge.
                    all_tasks_data.setdefault(task_name, {})[model_name] = {
                        metric_key: metric_value
                        for metric_key, metric_value in metric_values.items()
                        if isinstance(metric_value, (int, float))
                    }
    print(all_tasks_data)
    return all_tasks_data
# Build the task -> model -> metric table once, then draw one grouped bar
# chart per task in a grid of subplots.
all_tasks = extract_data_by_tasks()

num_plots = len(all_tasks)
cols = 3  # or 2, 4 — adjust as needed
# At least one row, so plt.subplots() does not fail when there are no tasks.
rows = max(1, math.ceil(num_plots / cols))
# squeeze=False always returns a 2-D array of Axes, so flatten() works for
# any rows/cols combination (including a single subplot).
fig, axs = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4), squeeze=False)
axs = axs.flatten()  # flatten to 1D list for easier indexing

for i, dataset_name in enumerate(all_tasks):
    labels, metrics = extract_plot_task_data(all_tasks, dataset_name)
    x = np.arange(len(labels))
    # Split 80% of each model's slot among its metric bars; max() guards a
    # task whose models report no numeric metrics.
    bar_width = 0.8 / max(len(metrics), 1)
    ax = axs[i]  # Select the i-th subplot
    for j, (key, values) in enumerate(metrics.items()):
        # Centre the group of bars on the model's tick position.
        offset = (j - len(metrics) / 2) * bar_width + bar_width / 2
        ax.bar(x + offset, values, bar_width, label=key)
    # Keep only the text after the last "/" in each model name.
    labels = [s.split("/")[-1] for s in labels]
    ax.set_ylim(top=1)  # NOTE(review): assumes scores lie in [0, 1]
    ax.set_xlabel("Models")
    ax.set_ylabel("Score")
    ax.set_title(dataset_name)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=-45)
    if metrics:  # no labelled bars -> nothing for a legend to show
        ax.legend(fontsize=7)

# Hide unused subplots if any (safe even when there are no tasks at all).
for extra_ax in axs[num_plots:]:
    fig.delaxes(extra_ax)
plt.tight_layout()
plt.show()