# dinosaur_project/src/train_utils.py
'''
This module contains utility functions for training and evaluating a model.
'''
# Handling files
import os
# Datetime
from datetime import datetime
# Plotting confusion matrix
from plot_utils import plot_confusion_matrix, plot_training_progress
# Torch
import torch
from torchmetrics.classification import F1Score
# Progress bar
from tqdm import tqdm
def evaluate(model, loss_func, val_loader, device, cm=False):
    '''
    Evaluate a model's performance through loss, accuracy and weighted F1 score

    Args:
        model : the model used
        loss_func (torch.nn.Module) : the loss function used
        val_loader (torch.utils.data.DataLoader) : the loader used to load data
        device (str) : the device used
        cm (bool) : whether to plot the confusion matrix and the table of
                    top-k misclassified classes, default: False

    Return:
        val_loss (float) : total loss, summed over all samples
        accuracy (float) : accuracy
        f1_score (float) : weighted F1 score
    '''
    # Set up
    model.to(device)
    val_loss = 0.0
    f1 = F1Score(
        task="multiclass",
        num_classes=len(val_loader.dataset.classes),
        average="weighted"
    ).to(device)
    all_preds = []
    all_labels = []
    # Evaluate
    model.eval()
    with torch.no_grad():
        for imgs, labels, _ in tqdm(val_loader, desc="Evaluating"):
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs).logits
            loss = loss_func(outputs, labels)
            preds = torch.argmax(outputs, dim=1)
            # Weight the batch-mean loss by the batch size so that dividing
            # by the dataset size later yields a true per-sample average
            val_loss += loss.item() * imgs.size(0)
            all_preds.append(preds)
            all_labels.append(labels)
    # Concatenate predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    accuracy = (all_preds == all_labels).sum().item() / len(all_labels)
    f1_score = f1(all_preds, all_labels).item()
    # Plot confusion matrix if required
    if cm:
        current_time = datetime.now().strftime("%Y%m%dT%H%M%S")
        plot_confusion_matrix(
            y_true=all_labels.cpu(),
            y_pred=all_preds.cpu(),
            display_labels=val_loader.dataset.classes,
            save_path=f"/dinosaur_project/test_results/{current_time}_evaluation_result.png"
        )
    return val_loss, accuracy, f1_score
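
# A minimal usage sketch for evaluate(); "my_model" and "my_val_loader" are
# hypothetical placeholders, not objects defined in this project:
#
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     val_loss, accuracy, f1 = evaluate(
#         my_model, torch.nn.CrossEntropyLoss(), my_val_loader, device, cm=True
#     )
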
def train_epoch(
    model, loss_func, optimizer, train_loader, val_loader,
    device, scheduler=None, mix_augment=None
):
    '''
    Train a model for one epoch

    Args:
        model : the model used
        loss_func (torch.nn.Module) : the loss function used
        optimizer (torch.optim.Optimizer) : the optimizer used
        train_loader (torch.utils.data.DataLoader) : the loader used to load
                                                     training data
        val_loader (torch.utils.data.DataLoader) : the loader used to load
                                                   validation data
        device (str) : the device used
        scheduler (torch.optim.lr_scheduler.LRScheduler) : learning rate
                                                           scheduler,
                                                           default: None
        mix_augment : mixup/cutmix augmentation, default: None

    Return:
        avg_train_loss (float) : average training loss per sample
        avg_val_loss (float) : average validation loss per sample
        accuracy (float) : accuracy
        f1_score (float) : weighted F1 score
        lr (float) : learning rate used for this epoch
    '''
    # Set up
    model.to(device)
    train_loss = 0.0
    # Train
    model.train()
    for imgs, labels, _ in tqdm(train_loader, desc="Training"):
        imgs = imgs.to(device)
        labels = labels.to(device)
        # Use mixup/cutmix augmentation if required
        if mix_augment:
            imgs, labels = mix_augment(imgs, labels)
        optimizer.zero_grad()
        outputs = model(imgs).logits
        loss = loss_func(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by the batch size so the per-sample
        # average below stays correct even when the last batch is smaller
        train_loss += loss.item() * imgs.size(0)
    # Evaluate
    val_loss, accuracy, f1_score = evaluate(
        model, loss_func, val_loader, device, cm=False
    )
    # Current learning rate
    lr = optimizer.param_groups[0]["lr"]
    # Update scheduler if required
    if scheduler:
        scheduler.step()
    # Calculate average train and validation loss per sample
    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_val_loss = val_loss / len(val_loader.dataset)
    print(
        f"Average Train Loss: {avg_train_loss:.4f}",
        f"Average Validation Loss: {avg_val_loss:.4f}",
        f"Accuracy: {accuracy:.4f}",
        f"Weighted F1: {f1_score:.4f}",
        sep=" | "
    )
    return avg_train_loss, avg_val_loss, accuracy, f1_score, lr
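
# One way to build a mix_augment callable for train_epoch(), assuming
# torchvision >= 0.16 (which ships CutMix/MixUp in transforms.v2); the
# project may construct it differently. Note that MixUp/CutMix return soft
# labels, which torch.nn.CrossEntropyLoss accepts since PyTorch 1.10:
#
#     from torchvision.transforms import v2
#     num_classes = len(train_loader.dataset.classes)
#     mix_augment = v2.RandomChoice([
#         v2.MixUp(num_classes=num_classes),
#         v2.CutMix(num_classes=num_classes),
#     ])
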
def train(
    model, n_epochs, loss_func, optimizer, train_loader, val_loader,
    device, early_stopping_patience=3, scheduler=None, mix_augment=None,
    model_dir="/dinosaur_project/model",
    train_plot_dir="/dinosaur_project/train_process"
):
    '''
    Train a model for a number of epochs

    Args:
        model, loss_func, optimizer, train_loader, : same as in the
        val_loader, device, scheduler, mix_augment   train_epoch function
        n_epochs (int) : number of epochs to train
        early_stopping_patience (int) : number of epochs without improvement
                                        to wait before triggering early
                                        stopping, default: 3
        model_dir (str) : directory to save model checkpoints,
                          default: /dinosaur_project/model
        train_plot_dir (str) : directory to save the training progress plot,
                               default: /dinosaur_project/train_process

    Return:
        None
    '''
    # Set up
    current_time = datetime.now().strftime("%Y%m%dT%H%M%S")
    best_model_path = os.path.join(model_dir, f"{current_time}_best_model")
    train_process_plot_path = os.path.join(
        train_plot_dir, f"{current_time}_train_process.png"
    )
    avg_train_losses = []
    avg_val_losses = []
    accuracy_scores = []
    f1_scores = []
    learning_rates = []
    best_f1 = 0.0
    best_f1_epoch = 1
    early_stopping_cnt = 0
    # Train epochs
    for i in range(n_epochs):
        print(f"Epoch {i+1}:")
        avg_train_loss, avg_val_loss, accuracy, f1_score, lr = train_epoch(
            model, loss_func, optimizer, train_loader,
            val_loader, device, scheduler, mix_augment
        )
        avg_train_losses.append(avg_train_loss)
        avg_val_losses.append(avg_val_loss)
        accuracy_scores.append(accuracy)
        f1_scores.append(f1_score)
        learning_rates.append(lr)
        # Check early stopping and save the best model
        if f1_score <= best_f1:
            early_stopping_cnt += 1
        else:
            best_f1 = f1_score
            best_f1_epoch = i + 1
            early_stopping_cnt = 0
            model.save_pretrained(best_model_path)
        if early_stopping_cnt == early_stopping_patience:
            print(
                f"Early stopping triggered. Best weighted F1: {best_f1:.4f},",
                f"achieved on epoch {best_f1_epoch}"
            )
            break
    # Plot training progress
    plot_training_progress(
        avg_training_losses=avg_train_losses,
        avg_val_losses=avg_val_losses,
        accuracy_scores=accuracy_scores,
        f1_scores=f1_scores,
        lr_changes=learning_rates,
        show=False,
        save_path=train_process_plot_path
    )
    print(
        f"Best model is saved to {best_model_path}\n",
        f"Training progress plot is saved to {train_process_plot_path}"
    )