|
import math
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from predictions.convert_results_to_md import results_data
|
|
|
|
|
|
def extract_plot_task_data(all_tasks_data, task):
    """Reshape one task's results into per-metric series for a grouped bar chart.

    Args:
        all_tasks_data: Mapping of task name -> {model name -> {metric name -> value}}.
        task: Key of the task to extract.

    Returns:
        ``(models, task_metrics)``:
        ``models`` is the list of model names in the task's insertion order.
        ``task_metrics`` maps each metric name, in order of first appearance,
        to a list of values aligned index-for-index with ``models``.
        When a model does not report a metric, ``math.nan`` is placed at its
        position, so every series has ``len(models)`` entries and each bar
        stays attached to the correct model.

    Raises:
        KeyError: if ``task`` is not a key of ``all_tasks_data``.
    """
    task_data = all_tasks_data[task]
    models = list(task_data)

    # Union of metric names across all models, preserving first-seen order.
    metric_names = dict.fromkeys(
        metric for model_metrics in task_data.values() for metric in model_metrics
    )

    # Build each series against the full model list (not only the models that
    # happen to report the metric) so positions never shift between models.
    task_metrics = {
        metric: [task_data[model].get(metric, math.nan) for model in models]
        for metric in metric_names
    }

    return models, task_metrics
|
|
|
|
|
|
def extract_data_by_tasks(datas=None):
    """Group numeric evaluation metrics by task, then by model.

    Args:
        datas: Iterable of per-model result records. Each record must have a
            ``"model_name"`` key and a ``"tasks"`` list. Each item of that list
            maps a task name to a dict whose values are metric dicts
            (metric name -> value). When omitted, ``results_data()`` is called
            to obtain the records.

    Returns:
        dict mapping task name -> {model name -> {metric name -> number}}.
        Only ``int``/``float`` metric values are kept; other values are dropped.
    """
    if datas is None:
        datas = results_data()

    all_tasks_data = {}

    for data in datas:
        model_name = data["model_name"]
        for task in data["tasks"]:
            for task_name, metrics_values in task.items():
                for metric_values in metrics_values.values():
                    # NOTE(review): each inner metric dict replaces the previous
                    # one for this model/task, so only the last survives;
                    # presumably there is a single entry -- confirm against
                    # results_data().
                    all_tasks_data.setdefault(task_name, {})[model_name] = {
                        metric_key: metric_value
                        for metric_key, metric_value in metric_values.items()
                        if isinstance(metric_value, (int, float))
                    }

    print(all_tasks_data)

    return all_tasks_data
|
|
|
|
|
|
all_tasks = extract_data_by_tasks()

# Debug dump of a single task. Guarded so that a results set without a
# "piaf" task does not abort the whole plotting run with a KeyError.
if "piaf" in all_tasks:
    labels, metrics = extract_plot_task_data(all_tasks, "piaf")
    print(metrics)

num_plots = len(all_tasks)
cols = 3
# Keep at least one row so plt.subplots() is valid even with no tasks.
rows = max(1, math.ceil(num_plots / cols))

# squeeze=False always returns a 2-D array of Axes, so ravel() works for any
# grid shape (including 1x1).
fig, axs = plt.subplots(rows, cols, figsize=(cols * 5, rows * 4), squeeze=False)
axs = axs.ravel()

for ax, dataset_name in zip(axs, all_tasks):
    labels, metrics = extract_plot_task_data(all_tasks, dataset_name)
    x = np.arange(len(labels))

    # Share 80% of each model slot among its metric bars; max(..., 1) avoids a
    # ZeroDivisionError for a task that has no numeric metrics.
    n_metrics = max(len(metrics), 1)
    bar_width = 0.8 / n_metrics

    for j, (key, values) in enumerate(metrics.items()):
        # Shift each metric's bars so the group is centered on the x tick.
        offset = (j - n_metrics / 2) * bar_width + bar_width / 2
        ax.bar(x + offset, values, bar_width, label=key)

    # Show only the last path component of model names ("org/model" -> "model").
    labels = [s.split("/")[-1] for s in labels]

    ax.set_ylim(top=1)
    ax.set_xlabel("Models")
    ax.set_ylabel("Score")
    ax.set_title(dataset_name)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=-45)
    if metrics:
        ax.legend(fontsize=7)

# Remove the grid cells left over after the last task.
for unused_ax in axs[num_plots:]:
    fig.delaxes(unused_ax)

plt.tight_layout()
plt.show()
|
|
|