'''
This module contains utility functions for training models
'''
# Handling files
import os
# Datetime
from datetime import datetime

# Torch
import torch
from torchmetrics.classification import F1Score
# Progress bar
from tqdm import tqdm

# Plotting the confusion matrix and training progress
from plot_utils import plot_confusion_matrix, plot_training_progress


def evaluate(model, loss_func, val_loader, device, cm=False):
    '''
    Evaluate a model's performance through loss, accuracy, and weighted F1 score

    Args:
        model : the model used
        loss_func (torch.nn.Module) : the loss function used
        val_loader (torch.utils.data.DataLoader): the loader used to load data
        device (str) : the device used
        cm (bool) : whether to plot the confusion matrix
                    and the top-k misclassified classes
                    table, default: False
    Return:
        val_loss (float): validation loss, summed over samples
        accuracy (float): accuracy
        f1_score (float): weighted F1 score
    '''
    # Set up
    model.to(device)
    val_loss = 0.0
    f1 = F1Score(
        task="multiclass",
        num_classes=len(val_loader.dataset.classes),
        average="weighted"
    ).to(device)
    all_preds = []
    all_labels = []

    # Evaluate
    model.eval()
    with torch.no_grad():
        for imgs, labels, _ in tqdm(val_loader, desc="Evaluating"):
            imgs = imgs.to(device)
            labels = labels.to(device)
            # The model is expected to return an output object with a
            # `logits` attribute (e.g., a Hugging Face image classifier)
            outputs = model(imgs).logits
            loss = loss_func(outputs, labels)
            preds = torch.argmax(outputs, dim=1)
            # Weight by batch size so a later division by the dataset size
            # yields the true mean (assumes the loss uses mean reduction)
            val_loss += loss.item() * imgs.size(0)
            all_preds.append(preds)
            all_labels.append(labels)

    # Concatenate predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    accuracy = (all_preds == all_labels).sum().item() / len(all_labels)
    f1_score = f1(all_preds, all_labels).item()

    # Plot confusion matrix if required
    if cm:
        current_time = datetime.now().strftime("%Y%m%dT%H%M%S")
        plot_confusion_matrix(
            y_true=all_labels.cpu(),
            y_pred=all_preds.cpu(),
            display_labels=val_loader.dataset.classes,
            save_path=f"/dinosaur_project/test_results/{current_time}_evaluation_result.png"
        )
    return val_loss, accuracy, f1_score
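
# A minimal usage sketch (illustrative only): `trained_model` and
# `test_loader` are assumed to exist and to match the expectations
# documented above, i.e. a model returning `.logits` and a loader over a
# dataset that yields (image, label, extra) triples with a `.classes` list.
#
#     test_loss, test_acc, test_f1 = evaluate(
#         trained_model,
#         torch.nn.CrossEntropyLoss(),
#         test_loader,
#         device="cuda",
#         cm=True
#     )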


def train_epoch(
    model, loss_func, optimizer, train_loader, val_loader,
    device, scheduler=None, mix_augment=None
):
    '''
    Train a model for one epoch

    Args:
        model : the model used
        loss_func (torch.nn.Module) : the loss function used
        optimizer (torch.optim.Optimizer) : the optimizer used
        train_loader (torch.utils.data.DataLoader) : the loader used to load
                                                     training data
        val_loader (torch.utils.data.DataLoader) : the loader used to load
                                                   validation data
        device (str) : the device used
        scheduler (torch.optim.lr_scheduler.LRScheduler): learning rate scheduler,
                                                          default: None
        mix_augment (callable) : mixup/cutmix augmentation applied per batch,
                                 default: None
    Return:
        avg_train_loss (float): average training loss
        avg_val_loss (float) : average validation loss
        accuracy (float) : accuracy
        f1_score (float) : weighted F1 score
        lr (float) : learning rate used for this epoch
    '''
    # Set up
    model.to(device)
    train_loss = 0.0

    # Train
    model.train()
    for imgs, labels, _ in tqdm(train_loader, desc="Training"):
        imgs = imgs.to(device)
        labels = labels.to(device)
        # Use mixup/cutmix augmentation if required
        if mix_augment:
            imgs, labels = mix_augment(imgs, labels)
        optimizer.zero_grad()
        outputs = model(imgs).logits
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the division by the dataset size below
        # yields the true mean (assumes the loss uses mean reduction)
        train_loss += loss.item() * imgs.size(0)

    # Evaluate
    val_loss, accuracy, f1_score = evaluate(
        model, loss_func, val_loader, device, cm=False
    )

    # Record the learning rate used for this epoch
    lr = optimizer.param_groups[0]["lr"]
    # Step the scheduler once per epoch if one is provided
    if scheduler:
        scheduler.step()

    # Calculate average train and validation loss per sample
    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(
        f"Average Train Loss: {avg_train_loss:.4f}",
        f"Average Validation Loss: {avg_val_loss:.4f}",
        f"Accuracy: {accuracy:.4f}",
        f"Weighted F1: {f1_score:.4f}",
        sep=" | "
    )
    return avg_train_loss, avg_val_loss, accuracy, f1_score, lr
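

# A minimal sketch of a mixup-style callable that fits the `mix_augment`
# hook above (illustrative, not part of the original pipeline). It returns
# mixed images together with soft (probability) labels, so the loss function
# passed to train_epoch must then accept class probabilities as targets
# (torch.nn.CrossEntropyLoss does). `num_classes` and `alpha` would need to
# be bound before use, e.g. with functools.partial.
def example_mixup(imgs, labels, num_classes, alpha=0.2):
    '''
    Blend each image in the batch with a randomly chosen partner

    Args:
        imgs (torch.Tensor) : batch of images, shape (B, C, H, W)
        labels (torch.Tensor): integer class labels, shape (B,)
        num_classes (int) : number of classes for one-hot encoding
        alpha (float) : Beta distribution parameter, default: 0.2
    Return:
        mixed images and soft labels, both torch.Tensor
    '''
    # Sample a single mixing coefficient for the whole batch
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    # Randomly pair each sample with another sample in the batch
    index = torch.randperm(imgs.size(0), device=imgs.device)
    mixed_imgs = lam * imgs + (1 - lam) * imgs[index]
    one_hot = torch.nn.functional.one_hot(labels, num_classes).float()
    mixed_labels = lam * one_hot + (1 - lam) * one_hot[index]
    return mixed_imgs, mixed_labels

# In recent torchvision releases, torchvision.transforms.v2.MixUp and
# v2.CutMix provide the same (imgs, labels) -> (imgs, soft_labels)
# interface out of the box.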


def train(
    model, n_epochs, loss_func, optimizer, train_loader, val_loader,
    device, early_stopping_patience=3, scheduler=None, mix_augment=None,
    model_dir="/dinosaur_project/model", train_plot_dir="/dinosaur_project/train_process"
):
    '''
    Train a model for a number of epochs

    Args:
        model, loss_func, optimizer, train_loader,
        val_loader, device, scheduler, mix_augment : same as in the
                                                     train_epoch function
        n_epochs (int) : number of epochs to train
        early_stopping_patience (int) : number of consecutive epochs without
                                        improvement before early stopping is
                                        triggered, default: 3
        model_dir (str) : directory to save model checkpoints,
                          default: /dinosaur_project/model
        train_plot_dir (str) : directory to save the training process
                               plot, default: /dinosaur_project/train_process
    Return:
        None
    '''
    # Set up
    current_time = datetime.now().strftime("%Y%m%dT%H%M%S")
    best_model_path = os.path.join(model_dir, f"{current_time}_best_model")
    train_process_plot_path = os.path.join(
        train_plot_dir, f"{current_time}_train_process.png"
    )
    avg_train_losses = []
    avg_val_losses = []
    accuracy_scores = []
    f1_scores = []
    learning_rates = []
    best_f1 = 0.0
    best_f1_epoch = 1
    early_stopping_cnt = 0

    # Train epochs
    for i in range(n_epochs):
        print(f"Epoch {i+1}:")
        avg_train_loss, avg_val_loss, accuracy, f1_score, lr = train_epoch(
            model, loss_func, optimizer, train_loader,
            val_loader, device, scheduler, mix_augment
        )
        avg_train_losses.append(avg_train_loss)
        avg_val_losses.append(avg_val_loss)
        accuracy_scores.append(accuracy)
        f1_scores.append(f1_score)
        learning_rates.append(lr)

        # Check early stopping and save the best model
        if f1_score <= best_f1:
            early_stopping_cnt += 1
        else:
            best_f1 = f1_score
            best_f1_epoch = i + 1
            early_stopping_cnt = 0
            # Save the best model so far (Hugging Face-style checkpoint)
            model.save_pretrained(best_model_path)
        if early_stopping_cnt == early_stopping_patience:
            print(
                f"Early stopping triggered. Best weighted F1: {best_f1:.4f},",
                f"achieved on epoch {best_f1_epoch}"
            )
            break

    # Plot training process
    plot_training_progress(
        avg_training_losses=avg_train_losses,
        avg_val_losses=avg_val_losses,
        accuracy_scores=accuracy_scores,
        f1_scores=f1_scores,
        lr_changes=learning_rates,
        show=False,
        save_path=train_process_plot_path
    )
    print(
        f"Best model is saved to {best_model_path}\n"
        f"Training process plot is saved to {train_process_plot_path}"
    )
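

if __name__ == "__main__":
    # A minimal smoke-test sketch (illustrative only). The toy dataset, the
    # checkpoint name, and the directory paths below are placeholders, not
    # values from this project: the real loaders are assumed to wrap a
    # dataset that yields (image, label, extra) triples and exposes a
    # `.classes` list, and the model is assumed to be a Hugging Face image
    # classifier (forward returns `.logits`, supports `save_pretrained`).
    from torch.utils.data import DataLoader, Dataset
    from transformers import AutoModelForImageClassification

    class ToyDataset(Dataset):
        '''Random images laid out the way the training loops above expect'''
        classes = ["class_a", "class_b"]

        def __len__(self):
            return 16

        def __getitem__(self, idx):
            img = torch.randn(3, 224, 224)
            label = idx % len(self.classes)
            # The third element is unpacked but unused by the loops above
            return img, label, idx

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k",  # placeholder checkpoint
        num_labels=len(ToyDataset.classes)
    )
    loader = DataLoader(ToyDataset(), batch_size=4)
    # Make sure the placeholder output directories exist before training
    os.makedirs("./model", exist_ok=True)
    os.makedirs("./train_process", exist_ok=True)
    train(
        model=model,
        n_epochs=1,
        loss_func=torch.nn.CrossEntropyLoss(),
        optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4),
        train_loader=loader,
        val_loader=loader,
        device=device,
        model_dir="./model",              # placeholder path
        train_plot_dir="./train_process"  # placeholder path
    )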