File size: 9,014 Bytes
25ce9a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
'''

This module contains utility functions for plotting

'''

# Handling files
import os

# Random
import random

# Handling images and visualization
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

# Confusion matrix
from sklearn.metrics import confusion_matrix


def plot_samples(
    data_path,
    sample_classes=("Tyrannosaurus Rex", "Pteranodon", "Triceratops"),
    img_per_class=3,
    show=False,
    save_path=None
):
    '''
    Plot random samples of dinosaur species

    The figure is a grid with one column per class and one row per sample;
    the class name is used as the title of the first image in each column.

    Args:
        data_path (str)                 : path to dataset; each class is read
                                          from the sub-directory
                                          <data_path>/<class name>
        sample_classes (sequence of str): classes (i.e. dinosaur species) we
                                          want to display. Default:
                                          "Tyrannosaurus Rex", "Pteranodon",
                                          "Triceratops"
        img_per_class (int)             : number of samples to show from each
                                          class, default value is 3
        show (bool)                     : decide to show the plot or not,
                                          default: False
        save_path (str)                 : path to save the plot if provided,
                                          default: None

    Returns:
        None

    Raises:
        ValueError: if a class directory holds fewer than img_per_class entries
    '''

    # Set up
    plt.figure(figsize=(12, 6))
    num_sample_classes = len(sample_classes)

    with plt.rc_context(
        rc={
            "axes.grid": False,
            "axes.spines.top": False,
            "axes.spines.right": False,
            "axes.spines.left": False,
            "axes.spines.bottom": False
        }
    ):
        for y, cls in enumerate(sample_classes):
            # Get sample image paths
            class_dir = os.path.join(data_path, cls)
            imgs = os.listdir(class_dir)

            # Fail with a message that names the offending class instead of
            # random.sample's generic "Sample larger than population" error
            if img_per_class > len(imgs):
                raise ValueError(
                    f"Class '{cls}' has only {len(imgs)} image(s) in "
                    f"'{class_dir}', cannot sample {img_per_class}"
                )
            samples = random.sample(imgs, img_per_class)

            # Plotting: row i, column y of an
            # (img_per_class x num_sample_classes) grid, 1-based index
            for i, img in enumerate(samples):
                plt_idx = i * num_sample_classes + y + 1
                plt.subplot(img_per_class, num_sample_classes, plt_idx)

                # Context manager releases the file handle once the pixel
                # data has been handed to matplotlib
                with Image.open(os.path.join(class_dir, img)) as picture:
                    plt.imshow(picture)
                plt.axis('off')
                if i == 0:
                    plt.title(cls)

        # Lay out the grid first, then reserve head-room for the suptitle
        plt.tight_layout()
        plt.suptitle("Sample Images", fontsize=16)
        plt.subplots_adjust(top=0.88)

        if save_path:
            plt.savefig(save_path)

    if show:
        plt.show()


def plot_class_balance(data_path, show=False, save_path=None):
    '''
    Plot class balance from a given path

    Every sub-directory of data_path is treated as one class and the number
    of entries inside it as that class's image count. Classes are drawn as
    bars sorted by descending share of the total.

    Args:
        data_path (str): path to dataset
        show (bool)    : decide to show the plot or not, default: False
        save_path (str): path to save the plot if provided, default: None

    Returns:
        None

    Raises:
        ValueError: if data_path contains no class sub-directories
    '''

    # Only sub-directories are classes; stray files (e.g. ".DS_Store")
    # would otherwise make the per-class os.listdir call fail
    classes = [
        entry for entry in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, entry))
    ]
    if not classes:
        raise ValueError(f"No class directories found in '{data_path}'")

    img_count = [
        len(os.listdir(os.path.join(data_path, cls))) for cls in classes
    ]

    # Total is computed once; an all-empty dataset yields 0% for every class
    # instead of a ZeroDivisionError
    total_images = sum(img_count)
    class_frequency = [
        (cnt / total_images) * 100 if total_images else 0.0
        for cnt in img_count
    ]

    # Sort descending
    sorted_data = sorted(
        zip(classes, class_frequency),
        key=lambda x: x[1],
        reverse=True
    )
    sorted_classes, sorted_frequency = zip(*sorted_data)

    # Set up (after validation so a failure does not leave an empty figure)
    plt.figure(figsize=(12, 6))

    # Plotting
    plt.bar(x=sorted_classes, height=sorted_frequency)
    plt.title("Class Frequency", fontsize=16)
    plt.xlabel("Class")
    plt.ylabel("Frequency (%)")

    plt.xticks(rotation=45, ha="right")

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

    if show:
        plt.show()


def plot_confusion_matrix(
    y_true, y_pred, display_labels, top_k=10, figsize=(18, 24),
    normalize="true", show=False, save_path=None
):
    '''
    Plot confusion matrix and a table of top_k misclassified pairs

    The figure has two rows: a heatmap of the full confusion matrix and,
    below it, a table listing the off-diagonal (true, predicted) pairs with
    the largest values. If there is no off-diagonal confusion at all, a short
    message is shown in place of the table.

    Args:
        y_true (lst)        : true labels
        y_pred (lst)        : predictions from model
        display_labels (lst): labels to display; display_labels[i] is used
                              for row/column i of the confusion matrix
        top_k (int)         : maximum number of misclassified pairs to list
                              in the table, default: 10
        figsize (tuple)     : figure size, default: (18, 24)
        normalize (str)     : option to normalize confusion matrix
                              (same in sklearn.metrics.confusion_matrix),
                              but only accepts 2 value: "true" (normalize
                              by row) and None (no normalization), default: "true"
        show (bool)         : decide to show the plot or not, default: False
        save_path (str)     : path to save the plot if provided, default: None

    Returns:
        None

    NOTE(review): the matrix is built without an explicit label list, so
    display_labels presumably has to match the classes actually present in
    y_true / y_pred, in the matrix's order - confirm against callers.
    '''

    # Full confusion matrix
    cm = confusion_matrix(y_true, y_pred, normalize=normalize)

    # Collect every off-diagonal cell (i != j) with non-zero confusion
    confusions = [
        (i, j, value)
        for i, row in enumerate(cm)
        for j, value in enumerate(row)
        if i != j and value > 0
    ]

    # Sorting and find top-k confused pairs
    top_confusions = sorted(confusions, key=lambda x: x[2], reverse=True)[:top_k]

    # Cells hold proportions when normalized, raw counts otherwise
    value_label = "Proportion" if normalize else "Count"

    # Set up plots
    fig, axes = plt.subplots(
        nrows=2, figsize=figsize, gridspec_kw={"height_ratios": [4, 1]}
    )

    # Plot confusion matrix
    sns.heatmap(
        cm, cmap="Blues", linewidths=0.5, linecolor="gray",
        xticklabels=display_labels, yticklabels=display_labels,
        cbar_kws={"label": value_label}, ax=axes[0]
    )

    axes[0].set_title("Confusion Matrix", fontsize=16, fontweight="bold", pad=20)
    axes[0].set_xlabel("Predicted Label", fontsize=14)
    axes[0].set_ylabel("True Label", fontsize=14)
    axes[0].tick_params(axis="x", labelsize=14)
    axes[0].tick_params(axis="y", labelsize=14)
    plt.setp(axes[0].get_xticklabels(), rotation=90, ha="right")

    # Table of top_k misclassified pairs
    columns = ["Ground Truth", "Predicted", value_label]
    data = [
        [display_labels[i], display_labels[j], f"{v:.2f}" if normalize else int(v)]
        for i, j, v in top_confusions
    ]

    axes[1].axis("off")
    if data:
        table = axes[1].table(
            cellText=data, colLabels=columns, loc="center", cellLoc="center",
            colColours=["#d3d3d3"] * len(columns), bbox=[0, 0, 1, 1]
        )
        table.auto_set_font_size(False)
        table.set_fontsize(14)
        table.scale(1, 2)
    else:
        # An empty cellText makes Axes.table fail, so show a message instead
        # (reached for perfect predictions or top_k == 0)
        axes[1].text(
            0.5, 0.5, "No misclassified pairs", fontsize=14,
            ha="center", va="center", transform=axes[1].transAxes
        )
    axes[1].set_title(
        f"Top-{top_k} Misclassified Pairs", fontsize=14, fontweight="bold", pad=10
    )

    plt.tight_layout(h_pad=5)

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

    if show:
        plt.show()


def plot_training_progress(
    avg_training_losses,
    avg_val_losses,
    accuracy_scores,
    f1_scores,
    lr_changes,
    show=False,
    save_path=None
):
    '''
    Draw a three-panel summary of a training run, one point per epoch:
    - Top panel   : average training loss against average validation loss
    - Middle panel: accuracy against weighted F1 score on validation data
    - Bottom panel: learning rate

    Args:
        avg_training_losses (lst): average training loss per epoch; its
                                   length determines the number of epochs
        avg_val_losses (lst)     : average validation loss per epoch
        accuracy_scores (lst)    : accuracy on validation data per epoch
        f1_scores (lst)          : weighted F1 score on validation data per epoch
        lr_changes (lst)         : learning rate per epoch
        show (bool)              : decide to show the plot or not, default: False
        save_path (str)          : path to save the plot if provided, default: None

    Returns:
        None
    '''

    # Epochs are numbered from 1 on the x axis
    epochs = list(range(1, len(avg_training_losses) + 1))

    fig, (loss_ax, metric_ax, lr_ax) = plt.subplots(
        nrows=3, ncols=1, figsize=(12, 15)
    )

    # Curves: losses on top, validation metrics in the middle, LR at the bottom
    loss_ax.plot(epochs, avg_training_losses, label="Train", color="blue")
    loss_ax.plot(epochs, avg_val_losses, label="Validation", color="red")
    loss_ax.legend(loc="upper right")

    metric_ax.plot(epochs, accuracy_scores, label="Accuracy", color="blue")
    metric_ax.plot(epochs, f1_scores, label="Weighted F1 Score", color="red")
    metric_ax.legend(loc="lower right")

    lr_ax.plot(epochs, lr_changes)

    # Shared decoration: every panel has an "Epoch" x axis and a grid
    panel_text = (
        (loss_ax, "Average Loss", "Average Training vs Validation Loss"),
        (metric_ax, "", "Validation Accuracy vs Weighted F1 Score"),
        (lr_ax, "Learning Rate", "Learning Rate Changes"),
    )
    for panel, y_text, heading in panel_text:
        panel.set(xlabel="Epoch", ylabel=y_text, title=heading)
        panel.grid(True)

    plt.suptitle("Training Process", fontsize=16)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

    if show:
        plt.show()