"""Plot how per-language scores relate across evaluation tasks.

Reads ``../results.json`` (records with ``task``, ``metric``, ``bcp_47`` and
``score`` fields), averages the scores per (task, metric, language), and
writes two figures to the current working directory:

* ``task_correlation_matrix.png`` -- lower-triangle heatmap of the pairwise
  task correlations (``DataFrame.corr`` with its default settings).
* ``task_scatter_matrix.png`` -- grid of pairwise task scatter plots with a
  handful of languages highlighted, and per-task histograms on the diagonal.
"""

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.special import logit

df = pd.read_json("../results.json")
# Rows for the "chrf" metric are excluded before averaging.
df = df[df["metric"] != "chrf"]
df = df.groupby(["task", "metric", "bcp_47"]).agg({"score": "mean"}).reset_index()


def transform_classification_scores(row):
    """Return the row's score, logit-transformed for classification rows.

    Args:
        row: A mapping-like record (e.g. a DataFrame row) with ``"task"``
            and ``"score"`` entries.

    Returns:
        ``logit(score)`` with the score first clipped to [0.001, 0.999] when
        ``row["task"] == "classification"``; the unchanged score otherwise.
    """
    if row["task"] == "classification":
        # Clip so that logit(p) = log(p / (1 - p)) stays finite at 0 and 1.
        score = np.clip(row["score"], 0.001, 0.999)
        return logit(score)
    return row["score"]


# Logit-transform classification scores to reduce skewness.
df["score"] = df.apply(transform_classification_scores, axis=1)

# Pivot to one row per language and one column per task.
pivot_df = df.pivot_table(
    values="score", index="bcp_47", columns="task", aggfunc="mean"
)

# Column order for the plots; tasks not listed here (e.g. 'truthfulqa') are
# dropped, and listed tasks missing from the data are skipped.
ordered_tasks = [
    "translation_from",
    "translation_to",
    "classification",
    "mmlu",
    "arc",
    "mgsm",
]
pivot_df = pivot_df[[task for task in ordered_tasks if task in pivot_df.columns]]

# Calculate correlation matrix
correlation_matrix = pivot_df.corr()

# Create the correlation plot
plt.figure(figsize=(8, 6))

# Mask the upper triangle including the diagonal so only the lower triangle
# is drawn.
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))

sns.heatmap(
    correlation_matrix,
    annot=True,
    cmap="Blues",
    center=0,
    square=True,
    mask=mask,
    cbar_kws={"shrink": 0.8},
    fmt=".3f",
)

plt.xlabel("Tasks", fontsize=12)
plt.ylabel("Tasks", fontsize=12)
plt.xticks(rotation=45, ha="right")
plt.yticks(rotation=0)
plt.tight_layout()

# Save the plot
plt.savefig("task_correlation_matrix.png", dpi=300, bbox_inches="tight")
plt.show()

# Print correlation values for reference
print("Correlation Matrix:")
print("Note: Classification scores have been logit-transformed to reduce skewness")
print(correlation_matrix.round(3))

# Also create a scatter plot matrix for pairwise relationships with
# highlighted languages.
highlighted_languages = ["en", "zh", "hi", "es", "ar"]

# Marker colours; built once instead of on every lookup.
_HIGHLIGHT_COLORS = {
    "en": "red",
    "zh": "blue",
    "hi": "green",
    "es": "orange",
    "ar": "purple",
}
_OTHER_COLOR = "lightgray"


def get_color_and_label(lang_code):
    """Return the ``(color, legend_label)`` pair for a language code.

    Languages listed in ``highlighted_languages`` get their own colour and
    are labelled with their code; every other language maps to
    ``("lightgray", "Other")``.
    """
    if lang_code in highlighted_languages:
        return _HIGHLIGHT_COLORS[lang_code], lang_code
    return _OTHER_COLOR, "Other"


# Create custom scatter plot matrix
tasks = pivot_df.columns.tolist()
n_tasks = len(tasks)

# squeeze=False keeps `axes` two-dimensional even for a 1x1 grid, so the
# axes[i, j] indexing below also works when only one task is present.
fig, axes = plt.subplots(n_tasks, n_tasks, figsize=(15, 12), squeeze=False)
fig.suptitle("Pairwise Task Performance", fontsize=16, fontweight="bold")

# Legend proxies: one marker per highlighted language plus one for the rest.
legend_elements = [
    plt.Line2D(
        [0],
        [0],
        marker="o",
        color="w",
        markerfacecolor=get_color_and_label(lang)[0],
        markersize=8,
        label=lang,
    )
    for lang in highlighted_languages
]
legend_elements.append(
    plt.Line2D(
        [0],
        [0],
        marker="o",
        color="w",
        markerfacecolor=_OTHER_COLOR,
        markersize=8,
        label="Other",
    )
)

for i, task_y in enumerate(tasks):
    for j, task_x in enumerate(tasks):
        ax = axes[i, j]

        if i == j:
            # Diagonal: histogram of the task's scores across languages.
            ax.hist(
                pivot_df[task_y].dropna(),
                bins=20,
                alpha=0.7,
                color="skyblue",
                edgecolor="black",
            )
            ax.set_title(f"{task_y}", fontsize=10)
        else:
            # Off-diagonal: scatter of the languages that have both scores.
            pair = pivot_df[[task_x, task_y]].dropna()
            is_highlighted = pair.index.isin(highlighted_languages)
            background = pair[~is_highlighted]
            foreground = pair[is_highlighted]

            # Two vectorised calls instead of one scatter call per language;
            # highlighted points are drawn last so they sit on top of the
            # gray ones.
            if not background.empty:
                ax.scatter(
                    background[task_x],
                    background[task_y],
                    c=_OTHER_COLOR,
                    alpha=0.3,
                    s=20,
                )
            if not foreground.empty:
                ax.scatter(
                    foreground[task_x],
                    foreground[task_y],
                    c=[get_color_and_label(lang)[0] for lang in foreground.index],
                    alpha=0.8,
                    s=50,
                )

        # Axis labels only on the bottom row and the left column.
        if i == n_tasks - 1:
            ax.set_xlabel(task_x, fontsize=10)
        if j == 0:
            ax.set_ylabel(task_y, fontsize=10)

        # Remove tick labels except for edges
        if i != n_tasks - 1:
            ax.set_xticklabels([])
        if j != 0:
            ax.set_yticklabels([])

# Add legend
fig.legend(
    handles=legend_elements,
    loc="lower center",
    bbox_to_anchor=(0.5, -0.05),
    ncol=len(legend_elements),
    frameon=False,
    fontsize=10,
    handletextpad=0.5,
    columnspacing=1.0,
)

plt.tight_layout()
plt.savefig("task_scatter_matrix.png", dpi=300, bbox_inches="tight")
plt.show()