nomri / models /lightning /no_shared_module.py
samaonline
init
1b34a12
from argparse import ArgumentParser
from typing import Tuple
import torch
import fastmri
from fastmri import transforms
from models.no_shared import NOShared
from models.lightning.mri_module import MriModule
from type_utils import tuple_type
class NOSharedModule(MriModule):
    """
    NO-Shared training module.

    Wraps a ``NOShared`` reconstruction network. It trains with
    ``fastmri.SSIMLoss`` and optimizes with Adam under a ``StepLR``
    learning-rate schedule.
    """

    def __init__(
        self,
        num_cascades: int = 12,
        pools: int = 4,
        chans: int = 18,
        sens_pools: int = 4,
        sens_chans: int = 8,
        gno_pools: int = 4,
        gno_chans: int = 16,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.02,
        kernel_shape: Tuple[int, int] = (6, 7),
        in_shape: Tuple[int, int] = (320, 320),
        use_dc_term: bool = True,
        lr: float = 0.0003,
        lr_step_size: int = 40,
        lr_gamma: float = 0.1,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        pools : int
            Number of downsampling and upsampling layers for the cascade U-Net.
        chans : int
            Number of channels for the cascade U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for the sensitivity map U-Net.
        sens_chans : int
            Number of channels for the sensitivity map U-Net.
        gno_pools : int
            Number of pooling layers for the GNO.
        gno_chans : int
            Number of channels for the GNO.
        gno_radius_cutoff : float
            Radius cutoff of the GNO module.
        gno_kernel_shape : Tuple[int, int]
            Kernel shape of the GNO module, e.g. (6, 7).
        radius_cutoff : float
            Radius cutoff of the DISCO module.
        kernel_shape : Tuple[int, int]
            Kernel shape of the DISCO module, e.g. (6, 7).
        in_shape : Tuple[int, int]
            Spatial dimensions of the ``masked_kspace`` samples.
        use_dc_term : bool
            Whether to use the DC term in the unrolled iterative update step.
        lr : float
            Learning rate.
        lr_step_size : int
            Learning rate step size.
        lr_gamma : float
            Learning rate gamma decay.
        weight_decay : float
            Parameter for penalizing weights norm.
        **kwargs
            Forwarded unchanged to ``MriModule.__init__``.
        """
        super().__init__(**kwargs)
        self.save_hyperparameters()

        self.num_cascades = num_cascades
        self.pools = pools
        self.chans = chans
        self.sens_pools = sens_pools
        self.sens_chans = sens_chans
        self.gno_pools = gno_pools
        self.gno_chans = gno_chans
        self.gno_radius_cutoff = gno_radius_cutoff
        self.gno_kernel_shape = gno_kernel_shape
        self.radius_cutoff = radius_cutoff
        self.kernel_shape = kernel_shape
        self.in_shape = in_shape
        self.use_dc_term = use_dc_term
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        # Every network argument is taken from the stored attributes so that
        # all constructor options flow into the model through one path.
        self.model = NOShared(
            num_cascades=self.num_cascades,
            sens_chans=self.sens_chans,
            sens_pools=self.sens_pools,
            chans=self.chans,
            pools=self.pools,
            gno_chans=self.gno_chans,
            gno_pools=self.gno_pools,
            gno_radius_cutoff=self.gno_radius_cutoff,
            gno_kernel_shape=self.gno_kernel_shape,
            radius_cutoff=self.radius_cutoff,
            kernel_shape=self.kernel_shape,
            in_shape=self.in_shape,
            use_dc_term=self.use_dc_term,
        )

        self.criterion = fastmri.SSIMLoss()
        # Total element count over every parameter registered on this module.
        self.num_params = sum(p.numel() for p in self.parameters())

    def forward(self, masked_kspace, mask, num_low_frequencies):
        """Run the wrapped ``NOShared`` model on a batch and return its output."""
        return self.model(masked_kspace, mask, num_low_frequencies)

    def training_step(self, batch, batch_idx):
        """
        Compute the SSIM training loss for one batch.

        The reconstruction and the target are first center-cropped to their
        common (smallest) size. A channel dimension is then inserted at
        dim 1 before the loss is evaluated. Logs ``train_loss`` and the
        current epoch, and returns the loss.
        """
        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )
        target, output = transforms.center_crop_to_smallest(
            batch.target, output
        )
        loss = self.criterion(
            output.unsqueeze(1), target.unsqueeze(1), data_range=batch.max_value
        )
        self.log("train_loss", loss, on_step=True, on_epoch=True)
        self.log("epoch", int(self.current_epoch), on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Reconstruct one validation batch and compute its SSIM loss.

        Returns a dict holding the slug of the validation dataloader, the
        batch metadata (``fname``, ``slice_num``, ``max_value``), the cropped
        output and target, and ``val_loss``. Aggregation of these outputs is
        left to the parent module.
        """
        # ``val_dataloaders`` is used as a mapping here (``.keys()`` is
        # called). The key at position ``dataloader_idx`` is reported as the
        # slug. NOTE(review): this relies on key order matching the
        # dataloader index order; confirm with how the loaders are built.
        dataloaders = self.trainer.val_dataloaders
        slug = list(dataloaders.keys())[dataloader_idx]

        output = self.forward(
            batch.masked_kspace, batch.mask, batch.num_low_frequencies
        )
        target, output = transforms.center_crop_to_smallest(
            batch.target, output
        )
        loss = self.criterion(
            output.unsqueeze(1),
            target.unsqueeze(1),
            data_range=batch.max_value,
        )
        return {
            "slug": slug,
            "fname": batch.fname,
            "slice_num": batch.slice_num,
            "max_value": batch.max_value,
            "output": output,
            "target": target,
            "val_loss": loss,
        }

    def configure_optimizers(self):
        """
        Build an Adam optimizer and a StepLR scheduler.

        The scheduler multiplies the learning rate by ``lr_gamma`` every
        ``lr_step_size`` scheduler steps.
        """
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.StepLR(
            optim, self.lr_step_size, self.lr_gamma
        )
        return [optim], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Define parameters that only apply to this model.

        Extends ``parent_parser`` (via ``MriModule.add_model_specific_args``)
        with the network and optimizer options of this module, and returns
        the new parser.
        """

        def str2bool(value):
            # argparse's ``type=bool`` maps every non-empty string, including
            # "False", to True, so parse the usual spellings explicitly.
            # argparse reports a raised ValueError as a usage error.
            if isinstance(value, bool):
                return value
            lowered = str(value).strip().lower()
            if lowered in ("true", "t", "yes", "y", "1"):
                return True
            if lowered in ("false", "f", "no", "n", "0"):
                return False
            raise ValueError(f"Expected a boolean value, got {value!r}")

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser = MriModule.add_model_specific_args(parser)

        # network params
        parser.add_argument(
            "--num_cascades",
            default=12,
            type=int,
            help="Number of VarNet cascades",
        )
        parser.add_argument(
            "--pools",
            default=4,
            type=int,
            help="Number of U-Net pooling layers in VarNet blocks",
        )
        parser.add_argument(
            "--chans",
            default=18,
            type=int,
            help="Number of channels for U-Net in VarNet blocks",
        )
        parser.add_argument(
            "--sens_pools",
            default=4,
            type=int,
            help=(
                "Number of pooling layers for sense map estimation U-Net in"
                " VarNet"
            ),
        )
        # A channel count must be an integer (was ``type=float``).
        parser.add_argument(
            "--sens_chans",
            default=8,
            type=int,
            help="Number of channels for sense map estimation U-Net in VarNet",
        )
        parser.add_argument(
            "--gno_pools",
            default=4,
            type=int,
            help=("Number of pooling layers for GNO"),
        )
        parser.add_argument(
            "--gno_chans",
            default=16,
            type=int,
            help="Number of channels for GNO",
        )
        parser.add_argument(
            "--gno_radius_cutoff",
            default=0.02,
            type=float,
            help="GNO module radius_cutoff",
        )
        parser.add_argument(
            "--gno_kernel_shape",
            default=(6, 7),
            type=tuple_type,
            help="GNO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--radius_cutoff",
            default=0.02,
            type=float,
            help="DISCO module radius_cutoff",
        )
        parser.add_argument(
            "--kernel_shape",
            default=(6, 7),
            type=tuple_type,
            help="DISCO module kernel_shape. Ex: (6, 7)",
        )
        parser.add_argument(
            "--in_shape",
            default=(320, 320),
            type=tuple_type,
            help="Spatial dimensions of masked_kspace samples. Ex: (640, 320)",
        )
        parser.add_argument(
            "--use_dc_term",
            default=True,
            type=str2bool,
            help="Whether to use the DC term in the unrolled iterative update step",
        )

        # training params (opt)
        parser.add_argument(
            "--lr", default=0.0003, type=float, help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size",
            default=40,
            type=int,
            help="Epoch at which to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Extent to which step size should be decreased",
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Strength of weight decay regularization",
        )
        return parser