Spaces:
Sleeping
Sleeping
| ''' | |
| This module contains utility functions for plotting | |
| ''' | |
| # Handling files | |
| import os | |
| # Random | |
| import random | |
| # Handling images and visualization | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| # Confusion matrix | |
| from sklearn.metrics import confusion_matrix | |
def plot_samples(
    data_path,
    sample_classes=("Tyrannosaurus Rex", "Pteranodon", "Triceratops"),
    img_per_class=3,
    show=False,
    save_path=None
):
    '''
    Plot random samples of dinosaur species

    The figure is a grid with one column per class and one row per sample.
    The class name is used as the title of the first image of each column.

    Args:
        data_path (str)                  : path to dataset, every class is read
                                           from the sub-directory named after it
        sample_classes (sequence of str) : classes (i.e. dinosaur species) we
                                           want to display.
                                           Default: "Tyrannosaurus Rex",
                                           "Pteranodon", "Triceratops"
        img_per_class (int)              : number of samples to show from each
                                           class, default value is 3
        show (bool)                      : decide to show the plot or not,
                                           default: False
        save_path (str)                  : path to save the plot if provided,
                                           default: None

    Returns:
        None

    Raises:
        ValueError: if a class directory holds fewer than img_per_class entries
    '''
    # Draw the samples before touching matplotlib, so that a wrong path or an
    # under-populated class does not leave an empty figure behind
    samples_per_class = []
    for cls in sample_classes:
        imgs = os.listdir(os.path.join(data_path, cls))
        if len(imgs) < img_per_class:
            raise ValueError(
                f"Class '{cls}' has only {len(imgs)} image(s), "
                f"cannot sample {img_per_class}"
            )
        samples_per_class.append(random.sample(imgs, img_per_class))

    # Set up
    num_sample_classes = len(sample_classes)
    plt.figure(figsize=(12, 6))

    # Grid and spines are switched off only while the subplots are created
    with plt.rc_context(
        rc={
            "axes.grid": False,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "axes.spines.left": False,
            "axes.spines.bottom": False
        }
    ):
        for col, (cls, samples) in enumerate(
            zip(sample_classes, samples_per_class)
        ):
            for row, img in enumerate(samples):
                # Subplot indices are 1-based and run row by row
                plt_idx = row * num_sample_classes + col + 1
                plt.subplot(img_per_class, num_sample_classes, plt_idx)

                # imshow converts the image to an array right away, so the
                # file handle can be released as soon as it returns
                with Image.open(os.path.join(data_path, cls, img)) as picture:
                    plt.imshow(picture)
                plt.axis('off')

                if row == 0:
                    plt.title(cls)

    plt.tight_layout()
    plt.suptitle("Sample Images", fontsize=16)
    # Leave room at the top for the suptitle
    plt.subplots_adjust(top=0.88)

    if save_path:
        plt.savefig(save_path)
    if show:
        plt.show()
def plot_class_balance(data_path, show=False, save_path=None):
    '''
    Plot class balance from a given path

    Every sub-directory of data_path is treated as one class and the number
    of entries inside it as the number of images of that class. Bars are
    sorted by frequency in descending order.

    Args:
        data_path (str): path to dataset
        show (bool)    : decide to show the plot or not, default: False
        save_path (str): path to save the plot if provided, default: None

    Returns:
        None

    Raises:
        ValueError: if data_path contains no class directories
    '''
    # Only sub-directories are classes; a stray file in data_path would
    # otherwise make the os.listdir call below fail
    classes = [
        entry for entry in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, entry))
    ]
    if not classes:
        raise ValueError(f"No class directories found in '{data_path}'")

    img_count = [
        len(os.listdir(os.path.join(data_path, cls))) for cls in classes
    ]

    # Total is computed once; when every class is empty the frequencies are
    # reported as 0% instead of dividing by zero
    total = sum(img_count)
    class_frequency = [
        (cnt/total)*100 if total else 0.0 for cnt in img_count
    ]

    # Sort descending
    sorted_data = sorted(
        zip(classes, class_frequency),
        key=lambda x: x[1],
        reverse=True
    )
    sorted_classes, sorted_frequency = zip(*sorted_data)

    # Plotting (the figure is created only once the data is known to be valid)
    plt.figure(figsize=(12, 6))
    plt.bar(x=sorted_classes, height=sorted_frequency)
    plt.title("Class Frequency", fontsize=16)
    plt.xlabel("Class")
    plt.ylabel("Frequency (%)")
    plt.xticks(rotation=45, ha="right")

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
def plot_confusion_matrix(
    y_true, y_pred, display_labels, top_k=10, figsize=(18, 24),
    normalize="true", show=False, save_path=None
):
    '''
    Plot confusion matrix and a table of top_k misclassified pairs

    The figure has two rows: the heatmap of the full confusion matrix on top
    and, below it, a table with the off-diagonal cells that have the highest
    values.

    Args:
        y_true (lst)        : true labels
        y_pred (lst)        : predictions from model
        display_labels (lst): labels to display, looked up by the row/column
                              index of the confusion matrix
        top_k (int)         : maximum number of misclassified pairs listed
                              in the table, default: 10
        figsize (tuple)     : figure size, default: (18, 24)
        normalize (str)     : option to normalize confusion matrix
                              (same in sklearn.metrics.confusion_matrix),
                              but only accepts 2 value: "true" (normalize
                              by row) and None (no normalization),
                              default: "true"
        show (bool)         : decide to show the plot or not, default: False
        save_path (str)     : path to save the plot if provided, default: None

    Returns:
        None
    '''
    # Full confusion matrix
    cm = confusion_matrix(y_true, y_pred, normalize=normalize)

    # Collect the off-diagonal cells (i != j) that hold any confusion
    confusions = [
        (i, j, value)
        for i, row in enumerate(cm)
        for j, value in enumerate(row)
        if i != j and value > 0
    ]

    # Sorting and find top-k confused pairs
    top_confusions = sorted(
        confusions, key=lambda x: x[2], reverse=True
    )[:top_k]

    value_label = "Proportion" if normalize else "Count"

    # Set up plots: the heatmap gets four times the height of the table
    fig, axes = plt.subplots(
        nrows=2, figsize=figsize, gridspec_kw={"height_ratios": [4, 1]}
    )
    matrix_ax, table_ax = axes

    # Plot confusion matrix
    sns.heatmap(
        cm, cmap="Blues", linewidths=0.5, linecolor="gray",
        xticklabels=display_labels, yticklabels=display_labels,
        cbar_kws={"label": value_label}, ax=matrix_ax
    )
    matrix_ax.set_title(
        "Confusion Matrix", fontsize=16, fontweight="bold", pad=20
    )
    matrix_ax.set_xlabel("Predicted Label", fontsize=14)
    matrix_ax.set_ylabel("True Label", fontsize=14)
    matrix_ax.tick_params(axis="x", labelsize=14)
    matrix_ax.tick_params(axis="y", labelsize=14)
    plt.setp(matrix_ax.get_xticklabels(), rotation=90, ha="right")

    # Table of top_k misclassified pairs
    columns = ["Ground Truth", "Predicted", value_label]
    data = [
        [
            display_labels[i],
            display_labels[j],
            f"{v:.2f}" if normalize else int(v)
        ]
        for i, j, v in top_confusions
    ]
    table_ax.axis("off")
    if data:
        table = table_ax.table(
            cellText=data, colLabels=columns, loc="center", cellLoc="center",
            colColours=["#d3d3d3"] * len(columns), bbox=[0, 0, 1, 1]
        )
        table.auto_set_font_size(False)
        table.set_fontsize(14)
        table.scale(1, 2)
    else:
        # A table cannot be built from an empty cellText (no confusion at
        # all, or top_k == 0), so a placeholder text is drawn instead
        table_ax.text(
            0.5, 0.5, "No misclassified pairs to display",
            ha="center", va="center", fontsize=14,
            transform=table_ax.transAxes
        )
    table_ax.set_title(
        f"Top-{top_k} Misclassified Pairs",
        fontsize=14, fontweight="bold", pad=10
    )

    fig.tight_layout(h_pad=5)

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
def plot_training_progress(
    avg_training_losses,
    avg_val_losses,
    accuracy_scores,
    f1_scores,
    lr_changes,
    show=False,
    save_path=None
):
    '''
    Plot training process over epochs as 3 stacked subplots:
        - average train and validation loss
        - accuracy and weighted F1 score on validation data
        - learning rates

    The epoch axis starts at 1 and its length is taken from
    avg_training_losses.

    Args:
        avg_training_losses (lst): average training loss
        avg_val_losses (lst)     : average validation loss
        accuracy_scores (lst)    : accuracy on validation data
        f1_scores (lst)          : weighted F1 score on validation data
        lr_changes (lst)         : learning rates
        show (bool)              : decide to show the plot or not,
                                   default: False
        save_path (str)          : path to save the plot if provided,
                                   default: None

    Returns:
        None
    '''
    epochs = list(range(1, len(avg_training_losses) + 1))
    fig, (loss_ax, score_ax, lr_ax) = plt.subplots(
        nrows=3, ncols=1, figsize=(12, 15)
    )

    # Top panel: average training vs validation loss
    loss_ax.plot(epochs, avg_training_losses, label="Train", color="blue")
    loss_ax.plot(epochs, avg_val_losses, label="Validation", color="red")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Average Loss")
    loss_ax.set_title("Average Training vs Validation Loss")
    loss_ax.legend(loc="upper right")
    loss_ax.grid(True)

    # Middle panel: accuracy vs weighted F1 score on validation data
    score_ax.plot(epochs, accuracy_scores, label="Accuracy", color="blue")
    score_ax.plot(epochs, f1_scores, label="Weighted F1 Score", color="red")
    score_ax.set_xlabel("Epoch")
    score_ax.set_ylabel("")
    score_ax.set_title("Validation Accuracy vs Weighted F1 Score")
    score_ax.legend(loc="lower right")
    score_ax.grid(True)

    # Bottom panel: learning rate
    lr_ax.plot(epochs, lr_changes)
    lr_ax.set_xlabel("Epoch")
    lr_ax.set_ylabel("Learning Rate")
    lr_ax.set_title("Learning Rate Changes")
    lr_ax.grid(True)

    fig.suptitle("Training Process", fontsize=16)
    fig.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()