# This config file implements the training workflow. It can be combined with multi_gpu_train.yaml to use DDP for
# multi-GPU runs. Many definitions in this file are duplicated across other files for compatibility with MONAI
# Label, e.g. network_def, but ideally these would be in a common.yaml file used in conjunction with this one
# or the other config files for testing or inference.
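#
# Example invocations (a sketch, assuming a standard bundle layout with this file at configs/train.yaml and
# that MONAI's bundle CLI is used to launch the `run` section defined at the bottom of this file):
#   python -m monai.bundle run --config_file configs/train.yaml
# and, combined with the DDP config for a two-GPU run:
#   torchrun --nproc_per_node=2 -m monai.bundle run --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']"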
imports:
- $import os
- $import datetime
- $import torch
- $import glob
# pull out some constants from MONAI
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED
both_keys: ['@image', '@label']
# multi-GPU values, `rank` will be overridden by the separate multi-GPU config (multi_gpu_train.yaml)
rank: 0 # without multi-GPU support, treat the process as rank 0
is_not_rank0: '$@rank > 0' # true if not main process, used to disable handlers for other ranks
# hyperparameters for you to modify on the command line
val_interval: 1 # how often to perform validation after an epoch
ckpt_interval: 1 # how often to save a checkpoint after an epoch
rand_prob: 0.5 # probability a random transform is applied
batch_size: 5 # number of images per batch
num_epochs: 20 # number of epochs to train for
num_substeps: 1 # how many times to repeat training with the same batch
num_workers: 4 # number of workers to generate batches with
learning_rate: 0.001 # initial learning rate
num_classes: 4 # number of classes in training data which network should predict
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
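# Any of the hyperparameters above can be overridden at invocation time, e.g. (hypothetical values,
# reusing the bundle CLI invocation sketched in the header):
#   python -m monai.bundle run --config_file configs/train.yaml --batch_size 10 --num_epochs 50 --learning_rate 0.0001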
# define various paths
bundle_root: . # root directory of the bundle
ckpt_path: $@bundle_root + '/models/model.pt' # checkpoint to load before starting
dataset_dir: $@bundle_root + '/train_data' # where data is coming from
results_dir: $@bundle_root + '/results' # where results are being stored to
# a new output directory is chosen using a timestamp for every invocation
output_dir: '$datetime.datetime.now().strftime(@results_dir + ''/output_%y%m%d_%H%M%S'')'
# network definition, this could be parameterised by pre-defined values or on the command line
network_def:
_target_: UNet
spatial_dims: 3
in_channels: 1
out_channels: '@num_classes'
channels: [8, 16, 32, 64]
strides: [2, 2, 2]
num_res_units: 2
network: $@network_def.to(@device)
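# Individual fields of the network definition can likewise be overridden on the command line using the
# bundle `#` id separator, e.g. (a sketch with placeholder values):
#   python -m monai.bundle run --config_file configs/train.yaml --network_def#channels "[16,32,64,128]"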
# dataset definitions, these assume a directory filled with img##.nii.gz files and matching lbl##.nii.gz label files
imgs: '$sorted(glob.glob(@dataset_dir+''/img*.nii.gz''))'
lbls: '$[i.replace(''img'',''lbl'') for i in @imgs]'
all_pairs: '$[{@image: i, @label: l} for i, l in zip(@imgs, @lbls)]'
partitions: '$monai.data.partition_dataset(@all_pairs, (4, 1), shuffle=True, seed=0)'
train_sub: '$@partitions[0]' # train partition
val_sub: '$@partitions[1]' # validation partition
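# For example, with the layout below the (4, 1) ratio passed to partition_dataset gives an 80%/20%
# train/validation split of the image/label pairs:
#   train_data/img01.nii.gz  train_data/lbl01.nii.gz
#   train_data/img02.nii.gz  train_data/lbl02.nii.gz
#   ...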
# these transforms are used for training and validation transform sequences
base_transforms:
- _target_: LoadImaged
keys: '@both_keys'
image_only: true
- _target_: EnsureChannelFirstd
keys: '@both_keys'
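# If your data needs resampling to a common orientation and voxel spacing, transforms like these could be
# appended to base_transforms (a commented-out sketch only, not part of this template; the pixdim values
# are placeholders to adapt to your data):
# - _target_: Orientationd
#   keys: '@both_keys'
#   axcodes: RAS
# - _target_: Spacingd
#   keys: '@both_keys'
#   pixdim: [1.0, 1.0, 1.0]
#   mode: [bilinear, nearest]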
# these are the random and regularising transforms used only for training
train_transforms:
- _target_: RandAxisFlipd
keys: '@both_keys'
prob: '@rand_prob'
- _target_: RandRotate90d
keys: '@both_keys'
prob: '@rand_prob'
- _target_: RandGaussianNoised
keys: '@image'
prob: '@rand_prob'
std: 0.05
- _target_: ScaleIntensityd
keys: '@image'
# these are used for validation data so no randomness
val_transforms:
- _target_: ScaleIntensityd
keys: '@image'
# define the Compose objects for training and validation
preprocessing:
_target_: Compose
transforms: $@base_transforms + @train_transforms
val_preprocessing:
_target_: Compose
transforms: $@base_transforms + @val_transforms
# define the datasets for training and validation
train_dataset:
_target_: Dataset
data: '@train_sub'
transform: '@preprocessing'
val_dataset:
_target_: Dataset
data: '@val_sub'
transform: '@val_preprocessing'
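# If the training data fits in memory, CacheDataset can be swapped in to cache the results of the
# deterministic transforms and speed up epochs (a commented-out sketch; cache_rate is a placeholder):
# train_dataset:
#   _target_: CacheDataset
#   data: '@train_sub'
#   transform: '@preprocessing'
#   cache_rate: 1.0
#   num_workers: '@num_workers'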
# define the dataloaders for training and validation
train_dataloader:
  _target_: ThreadDataLoader # generate data asynchronously from training
dataset: '@train_dataset'
batch_size: '@batch_size'
repeats: '@num_substeps'
num_workers: '@num_workers'
val_dataloader:
  _target_: DataLoader # the faster validation transforms probably won't benefit from threading
dataset: '@val_dataset'
batch_size: '@batch_size'
num_workers: '@num_workers'
# Simple Dice loss configured for multi-class segmentation; for binary (single-channel) segmentation
# use sigmoid==True instead of softmax==True and set to_onehot_y==False (see the commented example below)
lossfn:
_target_: DiceLoss
include_background: true # if your segmentations are relatively small it might help for this to be false
to_onehot_y: true # convert ground truth to one-hot for training
softmax: true # softmax applied to prediction
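# A commented-out sketch of the binary variant described above (the network's out_channels/num_classes
# and the postprocessing would also need changing to match a single-channel prediction):
# lossfn:
#   _target_: DiceLoss
#   include_background: true
#   to_onehot_y: false
#   sigmoid: true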
# hyperparameters could be added for other arguments of this class
optimizer:
_target_: torch.optim.Adam
params: $@network.parameters()
lr: '@learning_rate'
# replace with another inferer type if the training process for your network requires it (see the commented example below)
inferer:
_target_: SimpleInferer
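# For example, a patch-based network could use a sliding-window inferer instead (a commented-out sketch;
# roi_size and sw_batch_size are placeholders to adapt to your data and memory budget):
# inferer:
#   _target_: SlidingWindowInferer
#   roi_size: [96, 96, 96]
#   sw_batch_size: 4
#   overlap: 0.25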
# transforms applied to the network output to make it suitable for the validation metrics
postprocessing:
_target_: Compose
transforms:
- _target_: Activationsd
keys: '@pred'
softmax: true
- _target_: AsDiscreted
keys: ['@pred', '@label']
argmax: [true, false]
to_onehot: '@num_classes'
# validation handlers to gather statistics, log them to a file, and save the best checkpoint
val_handlers:
- _target_: StatsHandler
name: null # use engine.logger as the Logger object to log to
output_transform: '$lambda x: None'
- _target_: LogfileHandler # log outputs from the validation engine
output_dir: '@output_dir'
- _target_: CheckpointSaver
_disabled_: '@is_not_rank0' # only need rank 0 to save
save_dir: '@output_dir'
save_dict:
model: '@network'
save_interval: 0 # don't save iterations, just when the metric improves
save_final: false
epoch_level: false
save_key_metric: true
key_metric_name: val_mean_dice # save the checkpoint when this value improves
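# A TensorBoard logger could be appended to this list if desired (a commented-out sketch using MONAI's
# TensorBoardStatsHandler, writing to the same output directory):
# - _target_: TensorBoardStatsHandler
#   log_dir: '@output_dir'
#   output_transform: '$lambda x: None'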
# engine for running validation, ties together objects defined above and has metric definitions
evaluator:
_target_: SupervisedEvaluator
device: '@device'
val_data_loader: '@val_dataloader'
network: '@network'
postprocessing: '@postprocessing'
key_val_metric:
val_mean_dice:
_target_: MeanDice
include_background: false
output_transform: $monai.handlers.from_engine([@pred, @label])
  additional_metrics: # these extra metrics are logged but don't drive checkpoint saving
    val_mean_iou:
      _target_: MeanIoUHandler
      include_background: false
      output_transform: $monai.handlers.from_engine([@pred, @label])
    val_mae: # MAE isn't a great metric for segmentation tasks, included here just as a demo
      _target_: MeanAbsoluteError
      output_transform: $monai.handlers.from_engine([@pred, @label])
val_handlers: '@val_handlers'
# gathers the loss and validation values for each iteration, referred to by CheckpointSaver so defined separately
metriclogger:
_target_: MetricLogger
evaluator: '@evaluator'
handlers:
- '@metriclogger'
- _target_: CheckpointLoader
_disabled_: $not os.path.exists(@ckpt_path)
load_path: '@ckpt_path'
load_dict:
model: '@network'
- _target_: ValidationHandler # run validation at the set interval, bridge between trainer and evaluator objects
validator: '@evaluator'
epoch_level: true
interval: '@val_interval'
- _target_: CheckpointSaver
_disabled_: '@is_not_rank0' # only need rank 0 to save
save_dir: '@output_dir'
save_dict: # every epoch checkpoint saves the network and the metric logger in a dictionary
model: '@network'
logger: '@metriclogger'
save_interval: '@ckpt_interval'
save_final: true
epoch_level: true
- _target_: StatsHandler
name: null # use engine.logger as the Logger object to log to
tag_name: train_loss
output_transform: $monai.handlers.from_engine(['loss'], first=True) # log loss value
- _target_: LogfileHandler # log outputs from the training engine
output_dir: '@output_dir'
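# A learning rate scheduler could be added by defining a scheduler object at the top level, e.g.
# (a commented-out sketch; the StepLR settings are placeholder values):
# lr_scheduler:
#   _target_: torch.optim.lr_scheduler.StepLR
#   optimizer: '@optimizer'
#   step_size: 5
#   gamma: 0.1
# and then appending MONAI's LrScheduleHandler to the handlers list above:
# - _target_: LrScheduleHandler
#   lr_scheduler: '@lr_scheduler'
#   print_lr: true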
# engine for training, ties the values defined above together into the main training process
trainer:
_target_: SupervisedTrainer
max_epochs: '@num_epochs'
device: '@device'
train_data_loader: '@train_dataloader'
network: '@network'
inferer: '@inferer' # unnecessary since SimpleInferer is the default if this isn't provided
loss_function: '@lossfn'
optimizer: '@optimizer'
# postprocessing: '@postprocessing' # uncomment if you have train metrics that need post-processing
key_train_metric: null
train_handlers: '@handlers'
run:
- $@trainer.run()