# This config file implements the training workflow. It can be combined with multi_gpu_train.yaml
# to use DDP for multi-GPU runs.

imports:
- $import os
- $import json
- $import datetime
- $import torch
- $import glob

# pull out some constants from MONAI
image: $monai.utils.CommonKeys.IMAGE
label: $monai.utils.CommonKeys.LABEL
pred: $monai.utils.CommonKeys.PRED

# multi-GPU values, `rank` will be replaced in a separate script implementing multi-GPU changes
rank: 0  # without multi-GPU support, consider the process as rank 0 anyway
is_not_rank0: '$@rank > 0'  # true if not the main process, used to disable handlers for other ranks

# hyperparameters for you to modify on the command line
val_interval: 1  # how often to perform validation after an epoch
ckpt_interval: 1  # how often to save a checkpoint after an epoch
rand_prob: 0.5  # probability a random transform is applied
batch_size: 5  # number of images per batch
num_epochs: 10  # number of epochs to train for
num_substeps: 1  # how many times to repeatedly train with the same batch
num_workers: 4  # number of workers to generate batches with
learning_rate: 0.001  # initial learning rate
num_classes: 4  # number of classes in the training data which the network should predict
device: $torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# define various paths
bundle_root: .  # root directory of the bundle
ckpt_path: $@bundle_root + '/models/model.pt'  # checkpoint to load before starting
dataset_dir: $@bundle_root + '/data/train_data'  # where data is coming from
results_dir: $@bundle_root + '/results'  # where results are stored

# a new output directory is chosen using a timestamp for every invocation
output_dir: '$datetime.datetime.now().strftime(@results_dir + ''/output_%y%m%d_%H%M%S'')'

# network definition, this could be parameterised by pre-defined values or on the command line
network_def:
  _target_: DenseNet121
  spatial_dims: 2
  in_channels: 1
  out_channels: '@num_classes'
network: $@network_def.to(@device)

# dataset values, this assumes a JSON file listing img##.nii.gz files together with their label data,
# e.g. a list of dicts like [{"image": "img00.nii.gz", "label": 0}, ...] (illustrative values only)
data_json: $@bundle_root + '/data/train_samples.json'  # where the training data listing is located
data_fp: "$open(@data_json, 'r', encoding='utf8')"
data_dict: "$json.load(@data_fp)"
partitions: '$monai.data.partition_dataset(@data_dict, (4, 1), shuffle=True, seed=0)'  # 4:1 (80/20) split
train_sub: '$@partitions[0]'  # train partition
val_sub: '$@partitions[1]'  # validation partition

# these transforms are used in both the training and validation transform sequences
base_transforms:
- _target_: LoadImaged
  keys: '@image'
- _target_: EnsureChannelFirstd
  keys: '@image'

# these are the random and regularising transforms used only for training
train_transforms:
- _target_: RandAxisFlipd
  keys: '@image'
  prob: '@rand_prob'
- _target_: RandRotate90d
  keys: '@image'
  prob: '@rand_prob'
- _target_: RandGaussianNoised
  keys: '@image'
  prob: '@rand_prob'
  std: 0.05
- _target_: ScaleIntensityd
  keys: '@image'

# these are used for validation data so no randomness
val_transforms:
- _target_: ScaleIntensityd
  keys: '@image'

# define the Compose objects for training and validation
preprocessing:
  _target_: Compose
  transforms: $@base_transforms + @train_transforms
val_preprocessing:
  _target_: Compose
  transforms: $@base_transforms + @val_transforms

# define the datasets for training and validation
train_dataset:
  _target_: Dataset
  data: '@train_sub'
  transform: '@preprocessing'
val_dataset:
  _target_: Dataset
  data: '@val_sub'
  transform: '@val_preprocessing'
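
# Note: if the data fits in memory, a caching dataset avoids re-running the deterministic
# transforms every epoch. A minimal sketch of a drop-in replacement for train_dataset, using
# MONAI's CacheDataset with an illustrative cache_rate value (not part of this bundle):
# train_dataset:
#   _target_: CacheDataset
#   data: '@train_sub'
#   transform: '@preprocessing'
#   cache_rate: 1.0  # fraction of items to cache, 1.0 caches everything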

# define the dataloaders for training and validation
train_dataloader:
  _target_: ThreadDataLoader  # generate batches asynchronously from training
  dataset: '@train_dataset'
  batch_size: '@batch_size'
  repeats: '@num_substeps'
  num_workers: '@num_workers'
val_dataloader:
  _target_: DataLoader  # fast transforms probably won't benefit from threading
  dataset: '@val_dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'

# simple cross-entropy loss configured for multi-class classification
lossfn:
  _target_: torch.nn.CrossEntropyLoss
  reduction: sum  # hyperparameters could be added for other arguments of this class

optimizer:
  _target_: torch.optim.Adam
  params: $@network.parameters()
  lr: '@learning_rate'

# should be replaced with other inferer types if the training process is different for your network
inferer:
  _target_: SimpleInferer

# transforms applied to the network output to make it suitable for validation metrics
postprocessing:
  _target_: Compose
  transforms:
  - _target_: Activationsd
    keys: '@pred'
    softmax: true
  - _target_: AsDiscreted
    keys: ['@pred', '@label']
    argmax: [true, false]
    to_onehot: '@num_classes'
  - _target_: ToTensord
    keys: ['@pred', '@label']
    device: '@device'

# validation handlers to gather statistics, log these to a file, and save the best checkpoint
val_handlers:
- _target_: StatsHandler
  name: null  # use engine.logger as the Logger object to log to
  output_transform: '$lambda x: None'
- _target_: LogfileHandler  # log outputs from the validation engine
  output_dir: '@output_dir'
- _target_: CheckpointSaver
  _disabled_: '@is_not_rank0'  # only rank 0 needs to save
  save_dir: '@output_dir'
  save_dict:
    model: '@network'
  save_interval: 0  # don't save at iterations, just when the metric improves
  save_final: false
  epoch_level: false
  save_key_metric: true
  key_metric_name: val_accuracy  # save the checkpoint when this value improves

# engine for running validation, ties together the objects defined above and holds the metric definitions
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@val_dataloader'
  network: '@network'
  postprocessing: '@postprocessing'
  key_val_metric:
    val_accuracy:
      _target_: ignite.metrics.Accuracy
      output_transform: $monai.handlers.from_engine([@pred, @label])
  additional_metrics:
    val_f1:  # other metrics can be added here
      _target_: ConfusionMatrix
      metric_name: 'f1 score'
      output_transform: $monai.handlers.from_engine([@pred, @label])
  val_handlers: '@val_handlers'

# gathers the loss and validation values for each iteration, referred to by CheckpointSaver so defined separately
metriclogger:
  _target_: MetricLogger
  evaluator: '@evaluator'

# handlers attached to the training engine
handlers:
- '@metriclogger'
- _target_: CheckpointLoader
  _disabled_: $not os.path.exists(@ckpt_path)  # disabled if there is no checkpoint to load
  load_path: '@ckpt_path'
  load_dict:
    model: '@network'
- _target_: ValidationHandler  # run validation at the set interval, the bridge between the trainer and evaluator objects
  validator: '@evaluator'
  epoch_level: true
  interval: '@val_interval'
- _target_: CheckpointSaver
  _disabled_: '@is_not_rank0'  # only rank 0 needs to save
  save_dir: '@output_dir'
  save_dict:  # every epoch checkpoint saves the network and the metric logger in a dictionary
    model: '@network'
    logger: '@metriclogger'
  save_interval: '@ckpt_interval'
  save_final: true
  epoch_level: true
- _target_: StatsHandler
  name: null  # use engine.logger as the Logger object to log to
  tag_name: train_loss
  output_transform: $monai.handlers.from_engine(['loss'], first=True)  # log the loss value
- _target_: LogfileHandler  # log outputs from the training engine
  output_dir: '@output_dir'
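
# an optional extension, shown as a sketch only: a learning-rate scheduler could be driven by
# appending MONAI's LrScheduleHandler to the list above, e.g. with a StepLR object (the
# `lr_scheduler` definition below is hypothetical and not part of this bundle):
# lr_scheduler:
#   _target_: torch.optim.lr_scheduler.StepLR
#   optimizer: '@optimizer'
#   step_size: 5
#   gamma: 0.1
# ...and appended to `handlers`:
# - _target_: LrScheduleHandler
#   lr_scheduler: '@lr_scheduler'
#   print_lr: true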

# engine for training, ties the values defined above together into the main engine for the training process
trainer:
  _target_: SupervisedTrainer
  max_epochs: '@num_epochs'
  device: '@device'
  train_data_loader: '@train_dataloader'
  network: '@network'
  inferer: '@inferer'  # unnecessary since SimpleInferer is the default if this isn't provided
  loss_function: '@lossfn'
  optimizer: '@optimizer'
  # postprocessing: '@postprocessing'  # uncomment if you have train metrics that need post-processing
  key_train_metric: null
  train_handlers: '@handlers'

# expressions evaluated before the run section, setting determinism and enabling cuDNN benchmarking
initialize:
- "$monai.utils.set_determinism(seed=123)"
- "$setattr(torch.backends.cudnn, 'benchmark', True)"

# the expression sequence to execute for training, here just starting the trainer
run:
- "$@trainer.run()"
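
# a typical invocation, assuming this file is saved as configs/train.yaml inside the bundle
# (the exact run-id argument may vary with MONAI version):
#   python -m monai.bundle run run --config_file configs/train.yaml
# the hyperparameters above can be overridden on the command line, e.g.:
#   python -m monai.bundle run run --config_file configs/train.yaml --num_epochs 20 --batch_size 10
# for multi-GPU training, combine with the DDP config mentioned at the top:
#   python -m monai.bundle run run --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']"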