Spaces:
Running
on
Zero
Running
on
Zero
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
from argparse import ArgumentParser | |
import torch | |
import fastmri | |
from fastmri import transforms | |
from ..varnet import VarNet | |
import wandb | |
from .mri_module import MriModule | |
class VarNetModule(MriModule):
    """
    VarNet training module.

    This can be used to train variational networks from the paper:

    A. Sriram et al. End-to-end variational networks for accelerated MRI
    reconstruction. In International Conference on Medical Image Computing and
    Computer-Assisted Intervention, 2020.

    which was inspired by the earlier paper:

    K. Hammernik et al. Learning a variational network for reconstruction of
    accelerated MRI data. Magnetic Resonance in Medicine, 79(6):3055–3071, 2018.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        lr : float
            Learning rate.
        lr_step_size : int
            Learning rate step size.
        lr_gamma : float
            Learning rate gamma decay.
        weight_decay : float
            Parameter for penalizing weights norm.
        **kwargs
            Any further keyword arguments; forwarded unchanged to
            ``MriModule.__init__``.

        Notes
        -----
        ``num_sense_lines`` is not accepted as a named parameter and is never
        passed to ``VarNet`` by this constructor; a caller supplying it would
        have it forwarded to ``MriModule.__init__`` through ``**kwargs``.
        """
        super().__init__(**kwargs)
        # Must be called from within __init__ so the constructor arguments
        # are the ones recorded as hyperparameters.
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        self.varnet = VarNet(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
        )
        self.criterion = fastmri.SSIMLoss()
        # Total element count over every parameter registered on this module.
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        """Run the wrapped ``VarNet`` on one batch and return its output."""
        return self.varnet(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        """
        Compute the training loss for one batch.

        The network output and ``batch.target`` are center-cropped to their
        common smallest size, then compared with ``self.criterion`` using
        ``batch.max_value`` as the data range. Logs ``train_loss`` and the
        current epoch index, and returns the loss.
        """
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )
        target, output = transforms.center_crop_to_smallest(batch.target, output)
        # unsqueeze(1) inserts a singleton dimension at index 1 before the
        # tensors are handed to the criterion.
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )

        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Evaluate one validation batch.

        ``self.trainer.val_dataloaders`` must expose ``.keys()`` (i.e. be a
        mapping); the key at position ``dataloader_idx`` is reported as
        ``"slug"`` so results can be attributed to their dataloader.

        Returns
        -------
        dict
            ``slug``, ``fname``, ``slice_num``, ``max_value``, the cropped
            ``output`` and ``target``, and ``val_loss``.
        """
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )
        target, output = transforms.center_crop_to_smallest(batch.target, output)
        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )

        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        """Return an Adam optimizer paired with a ``StepLR`` scheduler."""
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )

        return [optim], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """
        Define parameters that only apply to this model

        Builds a new ``ArgumentParser`` on top of ``parent_parser``, lets
        ``MriModule`` register its own arguments, then adds the VarNet
        network and optimizer options. Returns the extended parser.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades",
            default=12,
            type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools",
            default=4,
            type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans",
            default=18,
            type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools",
            default=4,
            type=int,
            help="Number of pooling layers for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--sens_chans",
            default=8,
            # A channel count must be an integer; parsing it as float would
            # hand e.g. 8.0 to the network whenever the flag is given.
            type=int,
            help="Number of channels for sense map estimation U-Net in VarNet",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size",
            default=40,
            type=int,
            help="Epoch at which to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Extent to which step size should be decreased",
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Strength of weight decay regularization",
        )

        return parser