'''
This module contains utility functions for plotting
'''
# Handling files
import os
# Random
import random
# Handling images and visualization
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
# Confusion matrix
from sklearn.metrics import confusion_matrix
def plot_samples(
    data_path,
    sample_classes=None,
    img_per_class=3,
    show=False,
    save_path=None
):
    '''
    Plot random samples of dinosaur species
    Args:
        data_path (str) : path to dataset
        sample_classes (list of str): classes (i.e. dinosaur species) we want to display.
                                      Default (None) uses: "Tyrannosaurus Rex",
                                      "Pteranodon", "Triceratops"
        img_per_class (int) : number of samples to show from each class,
                              default value is 3
        show (bool)     : decide to show the plot or not, default: False
        save_path (str) : path to save the plot if provided, default: None
    Returns:
        None
    '''
    # Resolve the default inside the function body: a list literal in the
    # signature would be a shared mutable default argument.
    if sample_classes is None:
        sample_classes = ["Tyrannosaurus Rex", "Pteranodon", "Triceratops"]
    # Set up
    plt.figure(figsize=(12, 6))
    num_sample_classes = len(sample_classes)
    # Temporarily disable grid/spines so images render without axes chrome
    with plt.rc_context(
        rc={
            "axes.grid": False,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "axes.spines.left": False,
            "axes.spines.bottom": False
        }
    ):
        for y, cls in enumerate(sample_classes):
            # Get sample image paths
            imgs = os.listdir(os.path.join(data_path, cls))
            # Cap the sample size so a class with fewer images than
            # img_per_class does not make random.sample raise ValueError
            samples = random.sample(imgs, min(img_per_class, len(imgs)))
            # Plotting: one column per class, one row per sample
            for i, img in enumerate(samples):
                plt_idx = i * num_sample_classes + y + 1
                plt.subplot(img_per_class, num_sample_classes, plt_idx)
                plt.imshow(Image.open(os.path.join(data_path, cls, img)))
                plt.axis('off')
                if i == 0:
                    plt.title(cls)
    plt.tight_layout()
    plt.suptitle("Sample Images", fontsize=16)
    # Leave headroom for the suptitle above the grid
    plt.subplots_adjust(top=0.88)
    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
def plot_class_balance(data_path, show=False, save_path=None):
    '''
    Plot class balance from a given path
    Args:
        data_path (str): path to dataset
        show (bool)    : decide to show the plot or not, default: False
        save_path (str): path to save the plot if provided, default: None
    Returns:
        None
    Raises:
        ValueError: if the dataset contains no images
    '''
    # Set up
    plt.figure(figsize=(12, 6))
    classes = os.listdir(data_path)
    img_count = [
        len(os.listdir(os.path.join(data_path, cls))) for cls in classes
    ]
    # Hoist the total out of the comprehension: summing inside the loop
    # would recompute it once per class (O(n^2))
    total = sum(img_count)
    if total == 0:
        raise ValueError(f"No images found under '{data_path}'")
    class_frequency = [(cnt / total) * 100 for cnt in img_count]
    # Sort descending by frequency
    sorted_data = sorted(
        zip(classes, class_frequency),
        key=lambda x: x[1],
        reverse=True
    )
    sorted_classes, sorted_frequency = zip(*sorted_data)
    # Plotting
    plt.bar(x=sorted_classes, height=sorted_frequency)
    plt.title("Class Frequency", fontsize=16)
    plt.xlabel("Class")
    plt.ylabel("Frequency (%)")
    plt.xticks(rotation=45, ha="right")
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
def plot_confusion_matrix(
    y_true, y_pred, display_labels, top_k=10, figsize=(18, 24),
    normalize="true", show=False, save_path=None
):
    '''
    Plot confusion matrix and a table of top_k misclassified pairs
    Args:
        y_true (lst)        : true labels
        y_pred (lst)        : predictions from model
        display_labels (lst): labels to display
        top_k (int)         : number of classes with highest confusion to include
                              in confusion matrix, default: 10
        figsize (tuple)     : figure size, default: (18, 24)
        normalize (str)     : option to normalize confusion matrix
                              (same in sklearn.metrics.confusion_matrix),
                              but only accepts 2 value: "true" (normalize
                              by row) and None (no normalization), default: "true"
        show (bool)         : decide to show the plot or not, default: False
        save_path (str)     : path to save the plot if provided, default: None
    Returns:
        None
    '''
    # Full confusion matrix
    cm = confusion_matrix(y_true, y_pred, normalize=normalize)
    # Find (i, j) indices (i != j) that have highest confusion
    n_classes = len(cm)
    confusions = [
        (i, j, cm[i, j])
        for i in range(n_classes)
        for j in range(n_classes)
        if i != j and cm[i, j] > 0
    ]
    # Sorting and find top-k confused pairs
    top_confusions = sorted(confusions, key=lambda x: x[2], reverse=True)[:top_k]
    # Set up plots: tall heatmap on top, summary table below
    fig, axes = plt.subplots(
        nrows=2, figsize=figsize, gridspec_kw={"height_ratios": [4, 1]}
    )
    # Plot confusion matrix
    sns.heatmap(
        cm, cmap="Blues", linewidths=0.5, linecolor="gray",
        xticklabels=display_labels, yticklabels=display_labels,
        cbar_kws={"label": "Proportion" if normalize else "Count"}, ax=axes[0]
    )
    axes[0].set_title("Confusion Matrix", fontsize=16, fontweight="bold", pad=20)
    axes[0].set_xlabel("Predicted Label", fontsize=14)
    axes[0].set_ylabel("True Label", fontsize=14)
    axes[0].tick_params(axis="x", labelsize=14)
    axes[0].tick_params(axis="y", labelsize=14)
    plt.setp(axes[0].get_xticklabels(), rotation=90, ha="right")
    # Table of top_k misclassified pairs
    columns = ["Ground Truth", "Predicted", "Proportion" if normalize else "Count"]
    data = [
        [display_labels[i], display_labels[j], f"{v:.2f}" if normalize else int(v)]
        for i, j, v in top_confusions
    ]
    axes[1].axis("off")
    table = axes[1].table(
        cellText=data, colLabels=columns, loc="center", cellLoc="center",
        colColours=["#d3d3d3"] * len(columns), bbox=[0, 0, 1, 1]
    )
    table.auto_set_font_size(False)
    table.set_fontsize(14)
    table.scale(1, 2)
    axes[1].set_title(
        f"Top-{top_k} Misclassified Pairs", fontsize=14, fontweight="bold", pad=10
    )
    plt.tight_layout(h_pad=5)
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
def plot_training_progress(
    avg_training_losses,
    avg_val_losses,
    accuracy_scores,
    f1_scores,
    lr_changes,
    show=False,
    save_path=None
):
    '''
    Plot training process over epochs, specifically, 3 subplots are created:
        - One plot for average train and validation loss
        - One plot for accuracy and weighted F1 score on validation data
        - One plot for learning rates
    Args:
        avg_training_losses (lst): average training loss
        avg_val_losses (lst)     : average validation loss
        accuracy_scores (lst)    : accuracy on validation data
        f1_scores (lst)          : weighted F1 score on validation data
        lr_changes (lst)         : learning rates
        show (bool)              : decide to show the plot or not, default: False
        save_path (str)          : path to save the plot if provided, default: None
    Returns:
        None
    '''
    fig, (loss_ax, metric_ax, lr_ax) = plt.subplots(
        nrows=3, ncols=1, figsize=(12, 15)
    )
    # Epochs are 1-indexed on the x-axis
    epochs = range(1, len(avg_training_losses) + 1)
    # Subplot 1: average training vs validation loss per epoch
    loss_ax.plot(epochs, avg_training_losses, label="Train", color="blue")
    loss_ax.plot(epochs, avg_val_losses, label="Validation", color="red")
    loss_ax.set(
        xlabel="Epoch",
        ylabel="Average Loss",
        title="Average Training vs Validation Loss"
    )
    loss_ax.legend(loc="upper right")
    loss_ax.grid(True)
    # Subplot 2: accuracy vs weighted F1 score on validation data
    metric_ax.plot(epochs, accuracy_scores, label="Accuracy", color="blue")
    metric_ax.plot(epochs, f1_scores, label="Weighted F1 Score", color="red")
    metric_ax.set(
        xlabel="Epoch",
        ylabel="",
        title="Validation Accuracy vs Weighted F1 Score"
    )
    metric_ax.legend(loc="lower right")
    metric_ax.grid(True)
    # Subplot 3: learning-rate schedule across epochs
    lr_ax.plot(epochs, lr_changes)
    lr_ax.set(
        xlabel="Epoch",
        ylabel="Learning Rate",
        title="Learning Rate Changes"
    )
    lr_ax.grid(True)
    plt.suptitle("Training Process", fontsize=16)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()