'''
Utility functions for plotting dataset samples, class balance,
confusion matrices, and training progress.
'''
# Handling files
import os
# Random sampling
import random
# Handling images and visualization
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
# Confusion matrix
from sklearn.metrics import confusion_matrix


def plot_samples(
    data_path,
    sample_classes=["Tyrannosaurus Rex", "Pteranodon", "Triceratops"],
    img_per_class=3,
    show=False,
    save_path=None
):
    '''
    Plot random sample images of dinosaur species.
    Args:
        data_path (str)             : path to the dataset root
                                      (one subfolder per class)
        sample_classes (list of str): classes (i.e. dinosaur species) to
                                      display. Default: "Tyrannosaurus Rex",
                                      "Pteranodon", "Triceratops"
        img_per_class (int)         : number of samples to show from each
                                      class, default: 3
        show (bool)                 : whether to display the plot,
                                      default: False
        save_path (str)             : path to save the plot if provided,
                                      default: None
    Returns:
        None
    '''
    # Set up
    plt.figure(figsize=(12, 6))
    num_sample_classes = len(sample_classes)
    # Hide grid and spines for the image grid
    with plt.rc_context(
        rc={
            "axes.grid": False,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "axes.spines.left": False,
            "axes.spines.bottom": False
        }
    ):
        for y, cls in enumerate(sample_classes):
            # Get sample image paths (guard against classes holding fewer
            # than img_per_class images)
            imgs = os.listdir(os.path.join(data_path, cls))
            samples = random.sample(imgs, min(img_per_class, len(imgs)))
            # Plotting: one column per class, one row per sample
            for i, img in enumerate(samples):
                plt_idx = i * num_sample_classes + y + 1
                plt.subplot(img_per_class, num_sample_classes, plt_idx)
                plt.imshow(Image.open(os.path.join(data_path, cls, img)))
                plt.axis('off')
                if i == 0:
                    plt.title(cls)
    plt.tight_layout()
    plt.suptitle("Sample Images", fontsize=16)
    plt.subplots_adjust(top=0.88)
    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
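

# Example usage for plot_samples (a sketch, kept as a comment so nothing
# runs on import; the path "data/dinosaurs" and its one-subfolder-per-class
# layout are assumptions for illustration, not documented repo structure):
#
#     plot_samples(
#         "data/dinosaurs",
#         sample_classes=["Tyrannosaurus Rex", "Triceratops"],
#         img_per_class=2,
#         save_path="sample_grid.png",
#     )

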
def plot_class_balance(data_path, show=False, save_path=None):
    '''
    Plot the class balance of the dataset at a given path.
    Args:
        data_path (str): path to the dataset root (one subfolder per class)
        show (bool)    : whether to display the plot, default: False
        save_path (str): path to save the plot if provided, default: None
    Returns:
        None
    '''
    # Set up
    plt.figure(figsize=(12, 6))
    classes = os.listdir(data_path)
    img_count = [
        len(os.listdir(os.path.join(data_path, cls))) for cls in classes
    ]
    # Convert counts to percentages (compute the total once)
    total = sum(img_count)
    class_frequency = [cnt / total * 100 for cnt in img_count]
    # Sort descending
    sorted_data = sorted(
        zip(classes, class_frequency),
        key=lambda x: x[1],
        reverse=True
    )
    sorted_classes, sorted_frequency = zip(*sorted_data)
    # Plotting
    plt.bar(x=sorted_classes, height=sorted_frequency)
    plt.title("Class Frequency", fontsize=16)
    plt.xlabel("Class")
    plt.ylabel("Frequency (%)")
    plt.xticks(rotation=45, ha="right")
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
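

# Example usage for plot_class_balance (a sketch; assumes the same
# hypothetical one-subfolder-per-class dataset layout as above):
#
#     plot_class_balance("data/dinosaurs", save_path="class_balance.png")

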
def plot_confusion_matrix(
    y_true, y_pred, display_labels, top_k=10, figsize=(18, 24),
    normalize="true", show=False, save_path=None
):
    '''
    Plot a confusion matrix and a table of the top_k most confused pairs.
    Args:
        y_true (list)        : true labels
        y_pred (list)        : predictions from the model
        display_labels (list): class labels to display
        top_k (int)          : number of most confused (true, predicted)
                               pairs to list in the table, default: 10
        figsize (tuple)      : figure size, default: (18, 24)
        normalize (str)      : normalization option (as in
                               sklearn.metrics.confusion_matrix), but only
                               two values are accepted: "true" (normalize
                               by row) and None (no normalization),
                               default: "true"
        show (bool)          : whether to display the plot, default: False
        save_path (str)      : path to save the plot if provided,
                               default: None
    Returns:
        None
    '''
    # Full confusion matrix
    cm = confusion_matrix(y_true, y_pred, normalize=normalize)
    # Find off-diagonal (i, j) indices (i != j) with non-zero confusion
    confusions = []
    for i in range(len(cm)):
        for j in range(len(cm)):
            if i != j and cm[i][j] > 0:
                confusions.append((i, j, cm[i][j]))
    # Sort and keep the top_k most confused pairs
    top_confusions = sorted(
        confusions, key=lambda x: x[2], reverse=True
    )[:top_k]
    # Set up plots
    fig, axes = plt.subplots(
        nrows=2, figsize=figsize, gridspec_kw={"height_ratios": [4, 1]}
    )
    # Plot confusion matrix
    sns.heatmap(
        cm, cmap="Blues", linewidths=0.5, linecolor="gray",
        xticklabels=display_labels, yticklabels=display_labels,
        cbar_kws={"label": "Proportion" if normalize else "Count"},
        ax=axes[0]
    )
    axes[0].set_title(
        "Confusion Matrix", fontsize=16, fontweight="bold", pad=20
    )
    axes[0].set_xlabel("Predicted Label", fontsize=14)
    axes[0].set_ylabel("True Label", fontsize=14)
    axes[0].tick_params(axis="x", labelsize=14)
    axes[0].tick_params(axis="y", labelsize=14)
    plt.setp(axes[0].get_xticklabels(), rotation=90, ha="right")
    # Table of the top_k misclassified pairs
    columns = [
        "Ground Truth", "Predicted", "Proportion" if normalize else "Count"
    ]
    data = [
        [
            display_labels[i], display_labels[j],
            f"{v:.2f}" if normalize else int(v)
        ]
        for i, j, v in top_confusions
    ]
    axes[1].axis("off")
    table = axes[1].table(
        cellText=data, colLabels=columns, loc="center", cellLoc="center",
        colColours=["#d3d3d3"] * len(columns), bbox=[0, 0, 1, 1]
    )
    table.auto_set_font_size(False)
    table.set_fontsize(14)
    table.scale(1, 2)
    axes[1].set_title(
        f"Top-{top_k} Misclassified Pairs", fontsize=14,
        fontweight="bold", pad=10
    )
    plt.tight_layout(h_pad=5)
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
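

# Example usage for plot_confusion_matrix (a sketch with made-up labels;
# in practice y_true and y_pred would come from an evaluation loop):
#
#     labels = ["Tyrannosaurus Rex", "Pteranodon", "Triceratops"]
#     y_true = [0, 1, 2, 2, 1, 0]
#     y_pred = [0, 2, 2, 1, 1, 0]
#     plot_confusion_matrix(
#         y_true, y_pred, display_labels=labels, top_k=3,
#         figsize=(8, 10), save_path="confusion_matrix.png",
#     )

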
def plot_training_progress(
    avg_training_losses,
    avg_val_losses,
    accuracy_scores,
    f1_scores,
    lr_changes,
    show=False,
    save_path=None
):
    '''
    Plot training progress over epochs. Three subplots are created:
        - average training and validation loss
        - accuracy and weighted F1 score on validation data
        - learning rate
    Args:
        avg_training_losses (list): average training loss per epoch
        avg_val_losses (list)     : average validation loss per epoch
        accuracy_scores (list)    : accuracy on validation data per epoch
        f1_scores (list)          : weighted F1 score on validation data
                                    per epoch
        lr_changes (list)         : learning rate per epoch
        show (bool)               : whether to display the plot,
                                    default: False
        save_path (str)           : path to save the plot if provided,
                                    default: None
    Returns:
        None
    '''
    fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(12, 15))
    # x-axis values: epoch numbers starting at 1
    epochs = [i + 1 for i in range(len(avg_training_losses))]
    # Average training vs validation loss
    axes[0].plot(epochs, avg_training_losses, label="Train", color="blue")
    axes[0].plot(epochs, avg_val_losses, label="Validation", color="red")
    axes[0].set(
        xlabel="Epoch",
        ylabel="Average Loss",
        title="Average Training vs Validation Loss"
    )
    axes[0].legend(loc="upper right")
    axes[0].grid(True)
    # Accuracy vs weighted F1 score on validation data
    axes[1].plot(epochs, accuracy_scores, label="Accuracy", color="blue")
    axes[1].plot(epochs, f1_scores, label="Weighted F1 Score", color="red")
    axes[1].set(
        xlabel="Epoch",
        ylabel="Score",
        title="Validation Accuracy vs Weighted F1 Score"
    )
    axes[1].legend(loc="lower right")
    axes[1].grid(True)
    # Learning rate
    axes[2].plot(epochs, lr_changes)
    axes[2].set(
        xlabel="Epoch",
        ylabel="Learning Rate",
        title="Learning Rate Changes"
    )
    axes[2].grid(True)
    plt.suptitle("Training Progress", fontsize=16)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
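

# Minimal smoke test for plot_training_progress (a sketch with synthetic,
# made-up metrics; the numbers are for illustration only and are not real
# training results). Guarded so it never runs on import:
if __name__ == "__main__":
    plot_training_progress(
        avg_training_losses=[1.00, 0.70, 0.50, 0.40, 0.35],
        avg_val_losses=[1.10, 0.80, 0.60, 0.55, 0.50],
        accuracy_scores=[0.40, 0.55, 0.65, 0.70, 0.72],
        f1_scores=[0.38, 0.53, 0.64, 0.69, 0.71],
        lr_changes=[1e-3, 1e-3, 5e-4, 5e-4, 1e-4],
        show=True,
    )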