Spaces:
Runtime error
Runtime error
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import numpy as np | |
from scipy.special import logit | |
# Load the raw results. The path is relative to the current working
# directory, not to this file's location.
df = pd.read_json("../results.json")

# Leave rows whose metric is "chrf" out of the analysis.
df = df[df["metric"] != "chrf"]

# Collapse to one mean score per (task, metric, language tag) combination.
df = df.groupby(["task", "metric", "bcp_47"]).agg({"score": "mean"}).reset_index()
# Apply logit transformation to classification scores to reduce skewness
def transform_classification_scores(row, clip_min=0.001, clip_max=0.999):
    """Return the score of *row*, logit-transformed for classification rows.

    Args:
        row: Mapping-like object with "task" and "score" entries (a DataFrame
            row when called through ``df.apply(..., axis=1)``).
        clip_min: Lower bound applied to the score before the logit.
        clip_max: Upper bound applied to the score before the logit.

    Returns:
        ``logit(clip(score, clip_min, clip_max))`` when ``row["task"]`` is
        "classification"; otherwise the score unchanged.
    """
    if row["task"] != "classification":
        return row["score"]

    # Clip first: logit(p) = log(p / (1 - p)) is -inf at 0 and +inf at 1.
    score = np.clip(row["score"], clip_min, clip_max)
    return logit(score)
# Replace each score with its transformed value (only classification rows change).
df["score"] = df.apply(transform_classification_scores, axis=1)

# Create a pivot table with tasks as columns and languages as rows
pivot_df = df.pivot_table(
    values="score", index="bcp_47", columns="task", aggfunc="mean"
)

# Sort and filter tasks
ordered_tasks = [
    "translation_from",
    "translation_to",
    "classification",
    "mmlu",
    "arc",
    "mgsm",
]

# Keep only the tasks listed above, in that order. Any other task column
# (e.g. 'truthfulqa') is dropped, and listed tasks missing from the data are
# skipped instead of raising a KeyError.
pivot_df = pivot_df[[task for task in ordered_tasks if task in pivot_df.columns]]

# Calculate correlation matrix (pandas default method; missing values are
# excluded pairwise, so each pair of tasks may be based on different languages)
correlation_matrix = pivot_df.corr()
# Create the correlation plot
plt.figure(figsize=(8, 6))

# Mask the upper triangle, diagonal included, so that only the lower triangle
# of the symmetric matrix is drawn.
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))

# Create a heatmap
# NOTE(review): center=0 is combined with the sequential "Blues" colormap;
# confirm that a diverging colormap is not what was intended here.
sns.heatmap(
    correlation_matrix,
    annot=True,
    cmap="Blues",
    center=0,
    square=True,
    mask=mask,
    cbar_kws={"shrink": 0.8},
    fmt=".3f",
)
plt.xlabel("Tasks", fontsize=12)
plt.ylabel("Tasks", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()

# Save the plot (written to the current working directory), then display it
plt.savefig("task_correlation_matrix.png", dpi=300, bbox_inches="tight")
plt.show()

# Print correlation values for reference
print("Correlation Matrix:")
print("Note: Classification scores have been logit-transformed to reduce skewness")
print(correlation_matrix.round(3))
# Also create a scatter plot matrix for pairwise relationships with highlighted languages

# Single source of truth: highlighted language code -> marker color.
# The list of highlighted languages is derived from it, so the two can never
# get out of sync (a listed language without a color would raise KeyError).
_HIGHLIGHT_COLORS = {
    "en": "red",
    "zh": "blue",
    "hi": "green",
    "es": "orange",
    "ar": "purple",
}
highlighted_languages = list(_HIGHLIGHT_COLORS)


# Create color mapping
def get_color_and_label(lang_code):
    """Return the ``(color, label)`` pair used to draw *lang_code*.

    Highlighted languages get their own color and are labelled with their
    code; every other language is drawn in "lightgray" and labelled "Other".
    """
    if lang_code in _HIGHLIGHT_COLORS:
        return _HIGHLIGHT_COLORS[lang_code], lang_code
    return "lightgray", "Other"
# Create custom scatter plot matrix
tasks = pivot_df.columns.tolist()
n_tasks = len(tasks)

# squeeze=False keeps `axes` two-dimensional even when only one task column
# is present, so axes[i, j] below is always valid.
fig, axes = plt.subplots(n_tasks, n_tasks, figsize=(15, 12), squeeze=False)
fig.suptitle("Pairwise Task Performance", fontsize=16, fontweight="bold")

# Create legend elements: one marker per highlighted language, then "Other"
legend_entries = [
    (lang, get_color_and_label(lang)[0]) for lang in highlighted_languages
]
legend_entries.append(("Other", "lightgray"))
legend_elements = [
    plt.Line2D(
        [0],
        [0],
        marker="o",
        color="w",
        markerfacecolor=color,
        markersize=8,
        label=label,
    )
    for label, color in legend_entries
]

# Per-language styling is identical in every panel, so compute it once.
is_highlighted = pivot_df.index.isin(highlighted_languages)
point_colors = np.array(
    [get_color_and_label(lang)[0] for lang in pivot_df.index], dtype=object
)

for i, task_y in enumerate(tasks):
    for j, task_x in enumerate(tasks):
        ax = axes[i, j]
        if i == j:
            # Diagonal: histogram
            task_data = pivot_df[task_y].dropna()
            ax.hist(task_data, bins=20, alpha=0.7, color="skyblue", edgecolor="black")
            ax.set_title(f"{task_y}", fontsize=10)
        else:
            # Off-diagonal: scatter plot of the languages that have a score
            # for both tasks.
            has_both = (pivot_df[task_x].notna() & pivot_df[task_y].notna()).to_numpy()
            # Draw the gray "Other" points first and the highlighted ones
            # last, so the highlighted languages are never hidden underneath.
            for group, alpha, size in (
                (has_both & ~is_highlighted, 0.3, 20),
                (has_both & is_highlighted, 0.8, 50),
            ):
                if not group.any():
                    continue
                ax.scatter(
                    pivot_df.loc[group, task_x],
                    pivot_df.loc[group, task_y],
                    c=point_colors[group].tolist(),
                    alpha=alpha,
                    s=size,
                )
        # Set labels
        if i == n_tasks - 1:
            ax.set_xlabel(task_x, fontsize=10)
        if j == 0:
            ax.set_ylabel(task_y, fontsize=10)
        # Remove tick labels except for edges
        if i != n_tasks - 1:
            ax.set_xticklabels([])
        if j != 0:
            ax.set_yticklabels([])

# Add legend (anchored just below the figure; bbox_inches="tight" below keeps
# it inside the saved image)
fig.legend(
    handles=legend_elements,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.05),
    ncol=len(legend_elements),
    frameon=False,
    fontsize=10,
    handletextpad=0.5,
    columnspacing=1.0,
)
plt.tight_layout()
plt.savefig("task_scatter_matrix.png", dpi=300, bbox_inches="tight")
plt.show()