import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import fastmri
from fastmri.transforms import (
batched_mask_center,
batch_chans_to_chan_dim,
chans_to_batch_dim,
sens_reduce,
sens_expand,
)
from models.udno import UDNO
class NormUDNO(nn.Module):
"""
Normalized UDNO model.
Inputs are normalized before the UDNO for numerically stable training.
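
    Example (shape sketch; the hyperparameters are illustrative only, and the
    spatial size is assumed to be preserved by the underlying UDNO):

    >>> net = NormUDNO(chans=8, num_pool_layers=4, radius_cutoff=0.02,
    ...                in_shape=(320, 320), kernel_shape=(6, 7))
    >>> x = torch.rand(1, 1, 320, 320, 2)  # (B, C, H, W, 2), complex last
    >>> net(x).shape
    torch.Size([1, 1, 320, 320, 2])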
"""
def __init__(
self,
chans: int,
num_pool_layers: int,
radius_cutoff: float,
in_shape: Tuple[int, int],
kernel_shape: Tuple[int, int],
in_chans: int = 2,
out_chans: int = 2,
drop_prob: float = 0.0,
):
"""
Initialize the VarNet model.
Parameters
----------
chans : int
Number of output channels of the first convolution layer.
num_pools : int
Number of down-sampling and up-sampling layers.
in_chans : int, optional
Number of channels in the input to the U-Net model. Default is 2.
out_chans : int, optional
Number of channels in the output to the U-Net model. Default is 2.
drop_prob : float, optional
Dropout probability. Default is 0.0.
"""
super().__init__()
self.udno = UDNO(
in_chans=in_chans,
out_chans=out_chans,
radius_cutoff=radius_cutoff,
chans=chans,
num_pool_layers=num_pool_layers,
drop_prob=drop_prob,
in_shape=in_shape,
kernel_shape=kernel_shape,
)
def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
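        # (b, c, h, w, 2) complex-last layout -> (b, 2 * c, h, w), with all
        # real channels first, then all imaginary channels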
b, c, h, w, two = x.shape
assert two == 2
return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)
def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
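        # inverse of complex_to_chan_dim: (b, 2 * c, h, w) -> (b, c, h, w, 2)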
b, c2, h, w = x.shape
assert c2 % 2 == 0
c = c2 // 2
return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
def norm(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# group norm
b, c, h, w = x.shape
x = x.view(b, 2, c // 2 * h * w)
mean = x.mean(dim=2).view(b, 2, 1, 1)
std = x.std(dim=2).view(b, 2, 1, 1)
x = x.view(b, c, h, w)
return (x - mean) / std, mean, std
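    # `norm` above broadcasts its (b, 2, 1, 1) statistics against (b, c, h, w),
    # which is only valid when c == 2; `norm_new` generalizes the grouping to
    # any even channel count.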
    def norm_new(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # group norm
        b, c, h, w = x.shape
        num_groups = 2
        assert (
            c % num_groups == 0
        ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
        x = x.view(b, num_groups, c // num_groups * h * w)
        mean = x.mean(dim=2)
        std = x.std(dim=2)
        x = x.view(b, c, h, w)
        # expand the per-group statistics to per-channel (b, c, 1, 1) tensors;
        # repeat_interleave keeps each group's statistic aligned with its own
        # contiguous block of channels (a plain repeat would interleave them)
        mean = torch.repeat_interleave(
            mean, c // num_groups, dim=1
        ).view(b, c, 1, 1)
        std = torch.repeat_interleave(
            std, c // num_groups, dim=1
        ).view(b, c, 1, 1)
        return (x - mean) / std, mean, std
def unnorm(
self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
) -> torch.Tensor:
return x * std + mean
def pad(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
_, _, h, w = x.shape
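        # round h and w up to the next multiple of 16 via ((n - 1) | 15) + 1,
        # so that repeated 2x down- and up-sampling divides evenly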
w_mult = ((w - 1) | 15) + 1
h_mult = ((h - 1) | 15) + 1
w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
# TODO: fix this type when PyTorch fixes theirs
# the documentation lies - this actually takes a list
# https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
# https://github.com/pytorch/pytorch/pull/16949
x = F.pad(x, w_pad + h_pad)
return x, (h_pad, w_pad, h_mult, w_mult)
def unpad(
self,
x: torch.Tensor,
h_pad: List[int],
w_pad: List[int],
h_mult: int,
w_mult: int,
) -> torch.Tensor:
return x[
..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
]
def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] != 2:
raise ValueError("Last dimension must be 2 for complex.")
chans = x.shape[1]
if chans == 2:
# FIXME: hard coded skip norm/pad temporarily to avoid group norm bug
x = self.complex_to_chan_dim(x)
x = self.udno(x)
return self.chan_complex_to_last_dim(x)
# get shapes for unet and normalize
x = self.complex_to_chan_dim(x)
x, mean, std = self.norm(x)
x, pad_sizes = self.pad(x)
x = self.udno(x)
# get shapes back and unnormalize
x = self.unpad(x, *pad_sizes)
x = self.unnorm(x, mean, std)
x = self.chan_complex_to_last_dim(x)
return x
class SensitivityModel(nn.Module):
"""
Learn sensitivity maps
"""
def __init__(
self,
chans: int,
num_pools: int,
radius_cutoff: float,
in_shape: Tuple[int, int],
kernel_shape: Tuple[int, int],
in_chans: int = 2,
out_chans: int = 2,
drop_prob: float = 0.0,
mask_center: bool = True,
):
"""
Parameters
----------
chans : int
Number of output channels of the first convolution layer.
num_pools : int
Number of down-sampling and up-sampling layers.
in_chans : int, optional
Number of channels in the input to the U-Net model. Default is 2.
out_chans : int, optional
Number of channels in the output to the U-Net model. Default is 2.
drop_prob : float, optional
Dropout probability. Default is 0.0.
mask_center : bool, optional
Whether to mask center of k-space for sensitivity map calculation.
Default is True.
"""
super().__init__()
self.mask_center = mask_center
self.norm_udno = NormUDNO(
chans,
num_pools,
radius_cutoff,
in_shape,
kernel_shape,
in_chans=in_chans,
out_chans=out_chans,
drop_prob=drop_prob,
)
def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
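        # normalize so the root-sum-of-squares over coils is 1 at every pixel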
return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)
    def get_pad_and_num_low_freqs(
        self,
        mask: torch.Tensor,
        num_low_frequencies: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
if num_low_frequencies is None or any(
torch.any(t == 0) for t in num_low_frequencies
):
# get low frequency line locations and mask them out
squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
cent = squeezed_mask.shape[1] // 2
            # argmin on the int8 mask returns the index of the first zero,
            # i.e. the first unsampled line on each side of the center
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless it is 1
else:
num_low_frequencies_tensor = num_low_frequencies * torch.ones(
mask.shape[0], dtype=mask.dtype, device=mask.device
)
pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
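        # e.g. for a width-320 mask with 20 low-frequency lines:
        # pad = (320 - 20 + 1) // 2 = 150, so lines [150, 170) are kept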
return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
def forward(
self,
masked_kspace: torch.Tensor,
mask: torch.Tensor,
        num_low_frequencies: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.mask_center:
pad, num_low_freqs = self.get_pad_and_num_low_freqs(
mask, num_low_frequencies
)
masked_kspace = batched_mask_center(
masked_kspace, pad, pad + num_low_freqs
)
# convert to image space
images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))
# estimate sensitivities
return self.divide_root_sum_of_squares(
batch_chans_to_chan_dim(self.norm_udno(images), batches)
)
class VarNetBlock(nn.Module):
"""
Model block for iterative refinement of k-space data.
This model applies a combination of soft data consistency with the input
model as a regularizer. A series of these blocks can be stacked to form
the full variational network.
    This is the refinement module in Fig. 1 of the E2E-VarNet paper.
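
    With the data consistency term enabled, the k-space update computed in
    `forward` is

        k_out = k_in - dc_weight * M * (k_in - k_ref) - model_term,

    where M is the sampling mask.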
"""
def __init__(self, kno: nn.Module, ino: nn.Module):
"""
Args:
model: Module for "regularization" component of variational
network.
"""
super().__init__()
self.kno = kno
self.ino = ino
self.dc_weight = nn.Parameter(torch.ones(1))
def forward(
self,
current_kspace: torch.Tensor,
ref_kspace: torch.Tensor,
mask: torch.Tensor,
sens_maps: torch.Tensor,
use_dc_term: bool = True,
) -> torch.Tensor:
"""
Args:
current_kspace: The current k-space data (frequency domain data)
being processed by the network. (torch.Tensor)
ref_kspace: Original subsampled k-space data (from which we are
reconstrucintg the image (reference k-space). (torch.Tensor)
mask: A binary mask indicating the locations in k-space where
data consistency should be enforced. (torch.Tensor)
sens_maps: Sensitivity maps for the different coils in parallel
imaging. (torch.Tensor)
"""
        # model term: see the orange box of Fig. 1 in the E2E-VarNet paper
        # multi-channel k-space -> single-channel image space
# ======= kNO in measurement (k) space ========
current_kspace, b = chans_to_batch_dim(current_kspace) # reduce
current_kspace = self.kno(current_kspace) # inpaint
current_kspace = batch_chans_to_chan_dim(current_kspace, b) # expand
# ======= iNO in image (i) space ========
reduced_image = sens_reduce(current_kspace, sens_maps)
# single channel image-space
refined_image = self.ino(reduced_image)
# single channel image-space -> multi channel k-space
model_term = sens_expand(refined_image, sens_maps)
# only use first 15 channels (masked_kspace) in the update
# current_kspace = current_kspace[:, :15, :, :, :]
if not use_dc_term:
return current_kspace - model_term
"""
Soft data consistency term:
- Calculates the difference between current k-space and reference k-space where the mask is true.
- Multiplies this difference by the data consistency weight.
"""
# dc_term: see green box of Fig 1 in E2E-VarNet paper!
zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
soft_dc = (
torch.where(mask, current_kspace - ref_kspace, zero)
* self.dc_weight
)
return current_kspace - soft_dc - model_term
class NOShared(nn.Module):
"""
    Neural Operator model with shared cascade parameters for MRI
    reconstruction. Uses a variational architecture (iterative updates) with
    a learned sensitivity model. All operations are resolution invariant,
    employing neural operator modules.
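
    The input masked_kspace has shape (B, C, H, W, 2) (complex last); the
    output is a root-sum-of-squares magnitude image of shape (B, H, W).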
"""
def __init__(
self,
num_cascades: int = 12,
sens_chans: int = 8,
sens_pools: int = 4,
chans: int = 18,
pools: int = 4,
gno_chans: int = 16,
gno_pools: int = 4,
gno_radius_cutoff: float = 0.02,
gno_kernel_shape: Tuple[int, int] = (6, 7),
radius_cutoff: float = 0.01,
kernel_shape: Tuple[int, int] = (3, 4),
in_shape: Tuple[int, int] = (320, 320),
mask_center: bool = True,
use_dc_term: bool = True,
):
"""
Parameters
----------
num_cascades : int
Number of cascades (i.e., layers) for variational network.
sens_chans : int
Number of channels for sensitivity map U-Net.
sens_pools : int
Number of downsampling and upsampling layers for sensitivity map U-Net.
chans : int
Number of channels for cascade U-Net.
pools : int
Number of downsampling and upsampling layers for cascade U-Net.
mask_center : bool
Whether to mask center of k-space for sensitivity map calculation.
use_dc_term : bool
Whether to use the data consistency term.
"""
super().__init__()
self.num_cascades = num_cascades
        self.sens_net = SensitivityModel(
            sens_chans,
            sens_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            # NOTE: hard-coded to False here, overriding the `mask_center`
            # constructor argument
            mask_center=False,
        )
self.kno = NormUDNO(
gno_chans,
gno_pools,
in_shape=in_shape,
radius_cutoff=gno_radius_cutoff,
kernel_shape=gno_kernel_shape,
in_chans=2,
out_chans=2,
)
self.ino = NormUDNO(
chans,
pools,
radius_cutoff,
in_shape,
kernel_shape,
in_chans=2,
out_chans=2,
)
self.cascade = VarNetBlock(self.kno, self.ino)
self.use_dc_term = use_dc_term
def forward(
self,
masked_kspace: torch.Tensor,
mask: torch.Tensor,
        num_low_frequencies: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# (B, C, X, Y, 2)
kspace_pred = masked_kspace
# iterative update
for _ in range(self.num_cascades):
# sens model
sens_maps = self.sens_net(kspace_pred, mask, num_low_frequencies)
# kno + ino (cascade)
kspace_pred = self.cascade(
kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
)
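        # final reconstruction: inverse FFT to image space, complex magnitude,
        # then root-sum-of-squares over the coil dimension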
spatial_pred = fastmri.ifft2c(kspace_pred)
spatial_pred_abs = fastmri.complex_abs(spatial_pred)
combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)
return combined_spatial
if __name__ == "__main__":
model = NOShared(
num_cascades=4,
radius_cutoff=0.02,
kernel_shape=(6, 7),
)
x = torch.rand((2, 15, 320, 320, 2))
o = model(x, x.bool(), None)