|
|
|
|
|
|
|
|
|
|
|
# Python modules made available to the `$` expressions in this config.
imports:
  - '$import os'
  - '$import datetime'
  - '$import torch'
  - '$import glob'
|
|
|
|
|
# Dictionary keys shared by the transforms, engines and handlers below.
image: '$monai.utils.CommonKeys.IMAGE'
label: '$monai.utils.CommonKeys.LABEL'
pred: '$monai.utils.CommonKeys.PRED'
both_keys:
  - '@image'
  - '@label'
|
|
|
|
|
# Process rank; the checkpoint savers use is_not_rank0 to disable themselves
# on every rank other than 0.
rank: 0
is_not_rank0: '$0 < @rank'
|
|
|
|
|
# Training schedule.
num_epochs: 20
val_interval: 1   # run the evaluator every N epochs
ckpt_interval: 1  # write a training checkpoint every N epochs
num_substeps: 1   # passed to the training ThreadDataLoader as `repeats`

# Data loading and optimisation.
batch_size: 5
num_workers: 4
learning_rate: 0.001

# Augmentation probability shared by every random transform.
rand_prob: 0.5

# Segmentation classes: network output channels and one-hot width.
num_classes: 4

device: "$torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')"
|
|
|
|
|
# Filesystem layout, relative to the bundle root.
bundle_root: '.'
ckpt_path: "$os.path.join(@bundle_root, 'models', 'model.pt')"
dataset_dir: "$os.path.join(@bundle_root, 'train_data')"
results_dir: "$os.path.join(@bundle_root, 'results')"

# Per-run output directory stamped with the start time. Only the literal
# file-name pattern goes through strftime, so a '%' that happens to appear in
# @results_dir is kept verbatim instead of being read as a format directive.
output_dir: "$os.path.join(@results_dir, datetime.datetime.now().strftime('output_%y%m%d_%H%M%S'))"
|
|
|
|
|
# 3D UNet: single-channel input volume, one output channel per class.
network_def:
  _target_: UNet
  spatial_dims: 3
  in_channels: 1
  out_channels: '@num_classes'
  num_res_units: 2
  channels: [8, 16, 32, 64]
  strides: [2, 2, 2]

# Instantiated network, moved to the configured device.
network: '$@network_def.to(@device)'
|
|
|
|
|
# Image volumes are the files named img*.nii.gz in the dataset directory.
imgs: "$sorted(glob.glob(os.path.join(@dataset_dir, 'img*.nii.gz')))"

# Each label file shares its image's name with the leading 'img' replaced by
# 'lbl' (img*.nii.gz -> lbl*.nii.gz). Only the file name is rewritten, so a
# directory whose path contains 'img' no longer corrupts the label path.
lbls: "$[os.path.join(os.path.dirname(i), os.path.basename(i).replace('img', 'lbl', 1)) for i in @imgs]"

all_pairs: "$[{@image: i, @label: l} for i, l in zip(@imgs, @lbls)]"

# 4:1 train/validation split, shuffled reproducibly with a fixed seed.
partitions: '$monai.data.partition_dataset(@all_pairs, (4, 1), shuffle=True, seed=0)'
train_sub: '$@partitions[0]'
val_sub: '$@partitions[1]'
|
|
|
|
|
# Loading steps shared by training and validation: read image and label from
# disk (image data only, no metadata) and put the channel dimension first.
base_transforms:
  - _target_: LoadImaged
    keys: '@both_keys'
    image_only: true
  - _target_: EnsureChannelFirstd
    keys: '@both_keys'
|
|
|
|
|
# Training-only steps. Axis flips and 90-degree rotations act on both image and
# label; Gaussian noise and intensity rescaling act on the image only.
train_transforms:
  - _target_: RandAxisFlipd
    keys: '@both_keys'
    prob: '@rand_prob'
  - _target_: RandRotate90d
    keys: '@both_keys'
    prob: '@rand_prob'
  # NOTE(review): noise is added before ScaleIntensityd, so std is expressed in
  # raw intensity units rather than the rescaled range. Confirm this is intended.
  - _target_: RandGaussianNoised
    keys: '@image'
    prob: '@rand_prob'
    std: 0.05
  - _target_: ScaleIntensityd
    keys: '@image'
|
|
|
|
|
# Validation-only steps: deterministic intensity rescaling of the image.
val_transforms:
  - _target_: ScaleIntensityd
    keys: '@image'
|
|
|
|
|
|
|
# Full per-sample pipelines: shared loading steps followed by the
# split-specific steps.
preprocessing:
  _target_: Compose
  transforms: '$@base_transforms + @train_transforms'

val_preprocessing:
  _target_: Compose
  transforms: '$@base_transforms + @val_transforms'
|
|
|
|
|
|
|
# Datasets pairing each split's {image, label} dictionaries with its pipeline.
train_dataset:
  _target_: Dataset
  data: '@train_sub'
  transform: '@preprocessing'

val_dataset:
  _target_: Dataset
  data: '@val_sub'
  transform: '@val_preprocessing'
|
|
|
|
|
|
|
# Training loader. `repeats` is taken from @num_substeps (per the
# ThreadDataLoader docs, the number of times each batch is yielded).
train_dataloader:
  _target_: ThreadDataLoader
  dataset: '@train_dataset'
  batch_size: '@batch_size'
  repeats: '@num_substeps'
  num_workers: '@num_workers'
  # Reshuffle the training samples every epoch. Without this, every epoch
  # visits the samples in the same fixed order.
  shuffle: true

# Validation loader: fixed sample order.
val_dataloader:
  _target_: DataLoader
  dataset: '@val_dataset'
  batch_size: '@batch_size'
  num_workers: '@num_workers'
|
|
|
|
|
|
|
# Dice loss on softmax-activated outputs; the label is one-hot encoded by the
# loss itself, and the background channel is included.
lossfn:
  _target_: DiceLoss
  softmax: true
  to_onehot_y: true
  include_background: true
|
|
|
|
|
# Adam over all network parameters.
optimizer:
  _target_: torch.optim.Adam
  lr: '@learning_rate'
  params: '$@network.parameters()'
|
|
|
|
|
# Plain forward pass over each batch (no sliding window).
inferer:
  _target_: SimpleInferer
|
|
|
|
|
# Evaluation post-processing: softmax the prediction, argmax it to a label map,
# then one-hot both prediction and label into @num_classes channels.
postprocessing:
  _target_: Compose
  transforms:
    - _target_: Activationsd
      keys: '@pred'
      softmax: true
    - _target_: AsDiscreted
      keys:
        - '@pred'
        - '@label'
      argmax:
        - true   # prediction
        - false  # label
      to_onehot: '@num_classes'
|
|
|
|
|
# Handlers attached to the evaluator.
val_handlers:
  # Console stats; output_transform discards the per-iteration output.
  - _target_: StatsHandler
    name: null
    output_transform: '$lambda x: None'
  - _target_: LogfileHandler
    output_dir: '@output_dir'
  # Saves the model based on the key metric val_mean_dice (rank 0 only);
  # interval-based and final saving are turned off here.
  - _target_: CheckpointSaver
    _disabled_: '@is_not_rank0'
    save_dir: '@output_dir'
    save_dict:
      model: '@network'
    save_key_metric: true
    key_metric_name: val_mean_dice
    save_interval: 0
    save_final: false
    epoch_level: false
|
|
|
|
|
# Validation engine run by the trainer's ValidationHandler.
evaluator:
  _target_: SupervisedEvaluator
  device: '@device'
  val_data_loader: '@val_dataloader'
  network: '@network'
  postprocessing: '@postprocessing'
  # Exactly one key metric. val_mean_dice is the metric named by the
  # key-metric CheckpointSaver in @val_handlers.
  key_val_metric:
    val_mean_dice:
      _target_: MeanDice
      include_background: false
      output_transform: '$monai.handlers.from_engine([@pred, @label])'
  # Secondary metrics: computed alongside, but not the key metric.
  additional_metrics:
    val_mean_iou:
      _target_: MeanIoUHandler
      include_background: false
      output_transform: '$monai.handlers.from_engine([@pred, @label])'
    val_mae:
      _target_: MeanAbsoluteError
      output_transform: '$monai.handlers.from_engine([@pred, @label])'
  val_handlers: '@val_handlers'
|
|
|
|
|
# Metric logger bound to the evaluator; the training CheckpointSaver also
# stores it under the `logger` entry of its save_dict.
metriclogger:
  _target_: MetricLogger
  evaluator: '@evaluator'
|
|
|
# Handlers attached to the trainer (list order kept as written).
handlers:
  - '@metriclogger'
  # Restore model weights only when a checkpoint already exists at @ckpt_path.
  - _target_: CheckpointLoader
    _disabled_: '$not os.path.exists(@ckpt_path)'
    load_path: '@ckpt_path'
    load_dict:
      model: '@network'
  # Run the evaluator every @val_interval epochs.
  - _target_: ValidationHandler
    validator: '@evaluator'
    epoch_level: true
    interval: '@val_interval'
  # Epoch-level and final checkpoints of the model and metric logger
  # (rank 0 only).
  - _target_: CheckpointSaver
    _disabled_: '@is_not_rank0'
    save_dir: '@output_dir'
    save_dict:
      model: '@network'
      logger: '@metriclogger'
    save_interval: '@ckpt_interval'
    save_final: true
    epoch_level: true
  # Report the scalar training loss under the tag train_loss.
  - _target_: StatsHandler
    name: null
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  - _target_: LogfileHandler
    output_dir: '@output_dir'
|
|
|
|
|
# Supervised training engine; no training-time metric is computed.
trainer:
  _target_: SupervisedTrainer
  device: '@device'
  max_epochs: '@num_epochs'
  train_data_loader: '@train_dataloader'
  network: '@network'
  optimizer: '@optimizer'
  loss_function: '@lossfn'
  inferer: '@inferer'
  key_train_metric: null
  train_handlers: '@handlers'
|
|
|
# Entry point: start training.
run:
  - '$@trainer.run()'
|
|