import math from typing import List, Literal, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F import fastmri from fastmri.transforms import ( batched_mask_center, batch_chans_to_chan_dim, chans_to_batch_dim, sens_reduce, sens_expand, ) from models.udno import UDNO class NormUDNO(nn.Module): """ Normalized UDNO model. Inputs are normalized before the UDNO for numerically stable training. """ def __init__( self, chans: int, num_pool_layers: int, radius_cutoff: float, in_shape: Tuple[int, int], kernel_shape: Tuple[int, int], in_chans: int = 2, out_chans: int = 2, drop_prob: float = 0.0, ): """ Initialize the VarNet model. Parameters ---------- chans : int Number of output channels of the first convolution layer. num_pools : int Number of down-sampling and up-sampling layers. in_chans : int, optional Number of channels in the input to the U-Net model. Default is 2. out_chans : int, optional Number of channels in the output to the U-Net model. Default is 2. drop_prob : float, optional Dropout probability. Default is 0.0. """ super().__init__() self.udno = UDNO( in_chans=in_chans, out_chans=out_chans, radius_cutoff=radius_cutoff, chans=chans, num_pool_layers=num_pool_layers, drop_prob=drop_prob, in_shape=in_shape, kernel_shape=kernel_shape, ) def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor: b, c, h, w, two = x.shape assert two == 2 return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w) def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor: b, c2, h, w = x.shape assert c2 % 2 == 0 c = c2 // 2 return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() def norm( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # group norm b, c, h, w = x.shape x = x.view(b, 2, c // 2 * h * w) mean = x.mean(dim=2).view(b, 2, 1, 1) std = x.std(dim=2).view(b, 2, 1, 1) x = x.view(b, c, h, w) return (x - mean) / std, mean, std def norm_new( self, x: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # group norm b, c, h, w = x.shape num_groups = 2 assert ( c % num_groups == 0 ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})." x = x.view(b, num_groups, c // num_groups * h * w) mean = x.mean(dim=2).view(b, num_groups, 1, 1) std = x.std(dim=2).view(b, num_groups, 1, 1) print(x.shape, mean.shape, std.shape) x = x.view(b, c, h, w) mean = ( mean.view(b, num_groups, 1, 1) .repeat(1, c // num_groups, h, w) .view(b, c, h, w) ) std = ( std.view(b, num_groups, 1, 1) .repeat(1, c // num_groups, h, w) .view(b, c, h, w) ) return (x - mean) / std, mean, std def unnorm( self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor ) -> torch.Tensor: return x * std + mean def pad( self, x: torch.Tensor ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: _, _, h, w = x.shape w_mult = ((w - 1) | 15) + 1 h_mult = ((h - 1) | 15) + 1 w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] # TODO: fix this type when PyTorch fixes theirs # the documentation lies - this actually takes a list # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457 # https://github.com/pytorch/pytorch/pull/16949 x = F.pad(x, w_pad + h_pad) return x, (h_pad, w_pad, h_mult, w_mult) def unpad( self, x: torch.Tensor, h_pad: List[int], w_pad: List[int], h_mult: int, w_mult: int, ) -> torch.Tensor: return x[ ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1] ] def forward(self, x: torch.Tensor) -> torch.Tensor: if not x.shape[-1] == 2: raise ValueError("Last dimension must be 2 for complex.") chans = x.shape[1] if chans == 2: # FIXME: hard coded skip norm/pad temporarily to avoid group norm bug x = self.complex_to_chan_dim(x) x = self.udno(x) return self.chan_complex_to_last_dim(x) # get shapes for unet and normalize x = self.complex_to_chan_dim(x) x, mean, std = self.norm(x) x, pad_sizes = self.pad(x) x = self.udno(x) # get shapes back and unnormalize x = self.unpad(x, *pad_sizes) x = self.unnorm(x, mean, std) x = self.chan_complex_to_last_dim(x) return x class SensitivityModel(nn.Module): """ Learn sensitivity maps """ def __init__( self, chans: int, num_pools: int, radius_cutoff: float, in_shape: Tuple[int, int], kernel_shape: Tuple[int, int], in_chans: int = 2, out_chans: int = 2, drop_prob: float = 0.0, mask_center: bool = True, ): """ Parameters ---------- chans : int Number of output channels of the first convolution layer. num_pools : int Number of down-sampling and up-sampling layers. in_chans : int, optional Number of channels in the input to the U-Net model. Default is 2. out_chans : int, optional Number of channels in the output to the U-Net model. Default is 2. drop_prob : float, optional Dropout probability. Default is 0.0. mask_center : bool, optional Whether to mask center of k-space for sensitivity map calculation. Default is True. """ super().__init__() self.mask_center = mask_center self.norm_udno = NormUDNO( chans, num_pools, radius_cutoff, in_shape, kernel_shape, in_chans=in_chans, out_chans=out_chans, drop_prob=drop_prob, ) def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor: return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1) def get_pad_and_num_low_freqs( self, mask: torch.Tensor, num_low_frequencies=None ) -> Tuple[torch.Tensor, torch.Tensor]: if num_low_frequencies is None or any( torch.any(t == 0) for t in num_low_frequencies ): # get low frequency line locations and mask them out squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8) cent = squeezed_mask.shape[1] // 2 # running argmin returns the first non-zero left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1) right = torch.argmin(squeezed_mask[:, cent:], dim=1) num_low_frequencies_tensor = torch.max( 2 * torch.min(left, right), torch.ones_like(left) ) # force a symmetric center unless 1 else: num_low_frequencies_tensor = num_low_frequencies * torch.ones( mask.shape[0], dtype=mask.dtype, device=mask.device ) pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2 return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long) def forward( self, masked_kspace: torch.Tensor, mask: torch.Tensor, num_low_frequencies: Optional[int] = None, ) -> torch.Tensor: if self.mask_center: pad, num_low_freqs = self.get_pad_and_num_low_freqs( mask, num_low_frequencies ) masked_kspace = batched_mask_center( masked_kspace, pad, pad + num_low_freqs ) # convert to image space images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace)) # estimate sensitivities return self.divide_root_sum_of_squares( batch_chans_to_chan_dim(self.norm_udno(images), batches) ) class VarNetBlock(nn.Module): """ Model block for iterative refinement of k-space data. This model applies a combination of soft data consistency with the input model as a regularizer. A series of these blocks can be stacked to form the full variational network. aka Refinement Module in Fig 1 """ def __init__(self, kno: nn.Module, ino: nn.Module): """ Args: model: Module for "regularization" component of variational network. """ super().__init__() self.kno = kno self.ino = ino self.dc_weight = nn.Parameter(torch.ones(1)) def forward( self, current_kspace: torch.Tensor, ref_kspace: torch.Tensor, mask: torch.Tensor, sens_maps: torch.Tensor, use_dc_term: bool = True, ) -> torch.Tensor: """ Args: current_kspace: The current k-space data (frequency domain data) being processed by the network. (torch.Tensor) ref_kspace: Original subsampled k-space data (from which we are reconstrucintg the image (reference k-space). (torch.Tensor) mask: A binary mask indicating the locations in k-space where data consistency should be enforced. (torch.Tensor) sens_maps: Sensitivity maps for the different coils in parallel imaging. (torch.Tensor) """ # model-term see orange box of Fig 1 in E2E-VarNet paper! # multi channel k-space -> single channel image-space b, c, h, w, _ = current_kspace.shape # ======= kNO in measurement (k) space ======== current_kspace, b = chans_to_batch_dim(current_kspace) # reduce current_kspace = self.kno(current_kspace) # inpaint current_kspace = batch_chans_to_chan_dim(current_kspace, b) # expand # ======= iNO in image (i) space ======== reduced_image = sens_reduce(current_kspace, sens_maps) # single channel image-space refined_image = self.ino(reduced_image) # single channel image-space -> multi channel k-space model_term = sens_expand(refined_image, sens_maps) # only use first 15 channels (masked_kspace) in the update # current_kspace = current_kspace[:, :15, :, :, :] if not use_dc_term: return current_kspace - model_term """ Soft data consistency term: - Calculates the difference between current k-space and reference k-space where the mask is true. - Multiplies this difference by the data consistency weight. """ # dc_term: see green box of Fig 1 in E2E-VarNet paper! zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace) soft_dc = ( torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight ) return current_kspace - soft_dc - model_term class NOShared(nn.Module): """ Neural Operator model with shared cascade parameters for MRI reconstruction. Uses a variational architecture (iterative updates) with a learned sensitivity model. All operations are resolution invariant employing neural operator modules. """ def __init__( self, num_cascades: int = 12, sens_chans: int = 8, sens_pools: int = 4, chans: int = 18, pools: int = 4, gno_chans: int = 16, gno_pools: int = 4, gno_radius_cutoff: float = 0.02, gno_kernel_shape: Tuple[int, int] = (6, 7), radius_cutoff: float = 0.01, kernel_shape: Tuple[int, int] = (3, 4), in_shape: Tuple[int, int] = (320, 320), mask_center: bool = True, use_dc_term: bool = True, ): """ Parameters ---------- num_cascades : int Number of cascades (i.e., layers) for variational network. sens_chans : int Number of channels for sensitivity map U-Net. sens_pools : int Number of downsampling and upsampling layers for sensitivity map U-Net. chans : int Number of channels for cascade U-Net. pools : int Number of downsampling and upsampling layers for cascade U-Net. mask_center : bool Whether to mask center of k-space for sensitivity map calculation. use_dc_term : bool Whether to use the data consistency term. """ super().__init__() self.num_cascades = num_cascades self.sens_net = SensitivityModel( sens_chans, sens_pools, radius_cutoff, in_shape, kernel_shape, mask_center=False, ) self.kno = NormUDNO( gno_chans, gno_pools, in_shape=in_shape, radius_cutoff=gno_radius_cutoff, kernel_shape=gno_kernel_shape, in_chans=2, out_chans=2, ) self.ino = NormUDNO( chans, pools, radius_cutoff, in_shape, kernel_shape, in_chans=2, out_chans=2, ) self.cascade = VarNetBlock(self.kno, self.ino) self.use_dc_term = use_dc_term def forward( self, masked_kspace: torch.Tensor, mask: torch.Tensor, num_low_frequencies: Optional[int] = None, ) -> torch.Tensor: # (B, C, X, Y, 2) kspace_pred = masked_kspace # iterative update for _ in range(self.num_cascades): # sens model sens_maps = self.sens_net(kspace_pred, mask, num_low_frequencies) # kno + ino (cascade) kspace_pred = self.cascade( kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term ) spatial_pred = fastmri.ifft2c(kspace_pred) spatial_pred_abs = fastmri.complex_abs(spatial_pred) combined_spatial = fastmri.rss(spatial_pred_abs, dim=1) return combined_spatial if __name__ == "__main__": model = NOShared( num_cascades=4, radius_cutoff=0.02, kernel_shape=(6, 7), ) x = torch.rand((2, 15, 320, 320, 2)) o = model(x, x.bool(), None)