# davidpomerenke's picture
# Upload from GitHub Actions: Merge pull request #18 from datenlabor-bmz/pr-17
# a0d1624 verified
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.special import logit
# Load the raw results and discard every row whose metric is "chrf"
df = pd.read_json("../results.json")
df = df.loc[df["metric"] != "chrf"]

# Collapse to one mean score per (task, metric, language) combination
group_keys = ["task", "metric", "bcp_47"]
df = df.groupby(group_keys)["score"].mean().reset_index()
# Apply logit transformation to classification scores to reduce skewness
def transform_classification_scores(row, eps=0.001):
    """Return the row's score, logit-transformed for classification rows.

    Args:
        row: Mapping-like record (for example a DataFrame row) providing
            the "task" and "score" entries.
        eps: Clipping margin. Classification scores are clipped to
            [eps, 1 - eps] before the transform so that scores of exactly
            0 or 1 do not produce infinite values.

    Returns:
        logit(score) = log(score / (1 - score)) of the clipped score when
        row["task"] is "classification"; otherwise the score unchanged.
    """
    if row["task"] != "classification":
        return row["score"]
    clipped = np.clip(row["score"], eps, 1 - eps)
    return logit(clipped)
# Row-wise transform: classification scores become logits, others pass through
df["score"] = df.apply(transform_classification_scores, axis=1)

# Languages (bcp_47) as rows, tasks as columns, mean score in each cell
pivot_df = pd.pivot_table(
    df,
    values="score",
    index="bcp_47",
    columns="task",
    aggfunc="mean",
)

# Preferred column order; tasks outside this list are left out
ordered_tasks = [
    "translation_from",
    "translation_to",
    "classification",
    "mmlu",
    "arc",
    "mgsm",
]
present_tasks = [task for task in ordered_tasks if task in pivot_df.columns]
pivot_df = pivot_df[present_tasks]

# Pairwise correlation between the task columns
correlation_matrix = pivot_df.corr()
# Draw the task-by-task correlation heatmap
plt.figure(figsize=(8, 6))

# Mask the diagonal and everything above it so only the lower triangle shows
mask = np.triu(np.ones(correlation_matrix.shape, dtype=bool))

heatmap_options = {
    "annot": True,
    "cmap": "Blues",
    "center": 0,
    "square": True,
    "mask": mask,
    "cbar_kws": {"shrink": 0.8},
    "fmt": ".3f",
}
sns.heatmap(correlation_matrix, **heatmap_options)

plt.xlabel("Tasks", fontsize=12)
plt.ylabel("Tasks", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()

# Write the figure to disk, then display it
plt.savefig("task_correlation_matrix.png", dpi=300, bbox_inches="tight")
plt.show()

# Echo the numeric correlations for reference
print("Correlation Matrix:")
print("Note: Classification scores have been logit-transformed to reduce skewness")
print(correlation_matrix.round(3))
# Also create a scatter plot matrix for pairwise relationships with highlighted languages
# Single source of truth for the highlighted languages and their colors, so
# the language list and the color lookup cannot drift apart.
HIGHLIGHT_COLORS = {
    "en": "red",
    "zh": "blue",
    "hi": "green",
    "es": "orange",
    "ar": "purple",
}
highlighted_languages = list(HIGHLIGHT_COLORS)


# Create color mapping
def get_color_and_label(lang_code):
    """Return the (color, legend label) pair for a language code.

    Args:
        lang_code: Language code to look up.

    Returns:
        (color, lang_code) when the code is one of the highlighted
        languages, otherwise ("lightgray", "Other").
    """
    color = HIGHLIGHT_COLORS.get(lang_code)
    if color is None:
        return "lightgray", "Other"
    return color, lang_code
# Create custom scatter plot matrix: a histogram per task on the diagonal and
# pairwise scatter plots elsewhere, with the highlighted languages emphasized.
tasks = pivot_df.columns.tolist()
n_tasks = len(tasks)
# squeeze=False keeps `axes` two-dimensional even when only one task column is
# present, so the axes[i, j] indexing below is always valid.
fig, axes = plt.subplots(n_tasks, n_tasks, figsize=(15, 12), squeeze=False)
fig.suptitle("Pairwise Task Performance", fontsize=16, fontweight="bold")

# Create legend elements: one marker per highlighted language plus "Other"
legend_entries = [
    (get_color_and_label(lang)[0], lang) for lang in highlighted_languages
]
legend_entries.append(("lightgray", "Other"))
legend_elements = [
    plt.Line2D(
        [0],
        [0],
        marker="o",
        color="w",
        markerfacecolor=color,
        markersize=8,
        label=label,
    )
    for color, label in legend_entries
]

# Marker styling depends only on the language, so compute it once instead of
# once per point per panel. The alpha is baked into the RGBA color so that a
# single scatter call per panel can mix emphasized and faded points.
is_highlighted = pivot_df.index.isin(highlighted_languages)
point_sizes = np.where(is_highlighted, 50, 20)
point_colors = np.array(
    [
        mcolors.to_rgba(
            get_color_and_label(lang_code)[0],
            alpha=0.8 if highlighted else 0.3,
        )
        for lang_code, highlighted in zip(pivot_df.index, is_highlighted)
    ],
    dtype=float,
).reshape(-1, 4)

for i, task_y in enumerate(tasks):
    for j, task_x in enumerate(tasks):
        ax = axes[i, j]
        if i == j:
            # Diagonal: histogram of the task's scores
            task_data = pivot_df[task_y].dropna()
            ax.hist(task_data, bins=20, alpha=0.7, color="skyblue", edgecolor="black")
            ax.set_title(f"{task_y}", fontsize=10)
        else:
            # Off-diagonal: scatter plot of the languages that have a score
            # for both tasks, drawn in index order with one call per panel
            x_values = pivot_df[task_x].to_numpy()
            y_values = pivot_df[task_y].to_numpy()
            valid = pd.notna(x_values) & pd.notna(y_values)
            if valid.any():
                ax.scatter(
                    x_values[valid],
                    y_values[valid],
                    c=point_colors[valid],
                    s=point_sizes[valid],
                )
        # Set labels on the bottom row and the left column only
        if i == n_tasks - 1:
            ax.set_xlabel(task_x, fontsize=10)
        if j == 0:
            ax.set_ylabel(task_y, fontsize=10)
        # Remove tick labels except for edges
        if i != n_tasks - 1:
            ax.set_xticklabels([])
        if j != 0:
            ax.set_yticklabels([])

# Add legend below the grid
fig.legend(
    handles=legend_elements,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.05),
    ncol=len(legend_elements),
    frameon=False,
    fontsize=10,
    handletextpad=0.5,
    columnspacing=1.0,
)
plt.tight_layout()
plt.savefig("task_scatter_matrix.png", dpi=300, bbox_inches="tight")
plt.show()