nomri / models /varnet.py
samaonline
init
1b34a12
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import math
import os
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri import transforms
from models.unet import Unet
class NormUnet(nn.Module):
    """
    Normalized U-Net model.

    This is the same as a regular U-Net, but with normalization applied to the
    input before the U-Net. This keeps the values more numerically stable
    during training.

    ``forward`` folds the trailing complex axis into the channel axis,
    normalizes the two channel halves (index 0 / index 1 of the complex axis)
    separately, zero-pads H and W up to a multiple of 16, runs the U-Net and
    then undoes the padding, the normalization and the channel folding.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUnet model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model, i.e. twice the
            number of complex channels of the input. Must be even.
            Default is 2.
        out_chans : int, optional
            Number of channels in the output of the U-Net model. Must be even
            so it can be folded back into a complex axis. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()
        self.unet = Unet(
            in_chans=in_chans,
            out_chans=out_chans,
            chans=chans,
            num_pool_layers=num_pools,
            drop_prob=drop_prob,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Fold the trailing complex axis into the channel axis.

        (b, c, h, w, 2) -> (b, 2 * c, h, w). Channels [0, c) hold index 0 of
        the complex axis and channels [c, 2c) hold index 1.
        """
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of ``complex_to_chan_dim``: (b, 2 * c, h, w) -> (b, c, h, w, 2).
        """
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        # reshape (not view) so non-contiguous inputs are accepted too
        return x.reshape(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Group-normalize ``x`` using two groups: the first and the second half
        of the channels.

        Returns
        -------
        tuple
            ``(normalized, mean, std)``. ``normalized`` has the shape of ``x``;
            ``mean`` and ``std`` have shape (b, 2, 1, 1), one entry per group.
            ``std`` is ``Tensor.std``'s default (unbiased) estimate, so a group
            with zero spread yields non-finite values.
        """
        b, c, h, w = x.shape
        groups = x.reshape(b, 2, c // 2 * h * w)
        mean = groups.mean(dim=2).view(b, 2, 1, 1)
        std = groups.std(dim=2).view(b, 2, 1, 1)

        # Each group spans c // 2 consecutive channels. For c > 2 the (b, 2)
        # statistics must be repeated per channel, otherwise broadcasting
        # (b, c, h, w) against (b, 2, 1, 1) fails. For c == 2 use them as-is.
        reps = c // 2
        if reps > 1:
            mean_b = mean.repeat_interleave(reps, dim=1)
            std_b = std.repeat_interleave(reps, dim=1)
        else:
            mean_b, std_b = mean, std
        return (x - mean_b) / std_b, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        """
        Undo ``norm``: ``x * std + mean``.

        The per-group statistics returned by ``norm`` are repeated across the
        channels of ``x`` when ``x`` has more channels than groups, so the
        U-Net output may carry a different (even) channel count than its
        input.
        """
        channels, n_groups = x.shape[1], mean.shape[1]
        if channels > n_groups and channels % n_groups == 0:
            reps = channels // n_groups
            mean = mean.repeat_interleave(reps, dim=1)
            std = std.repeat_interleave(reps, dim=1)
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        """
        Zero-pad H and W (split as evenly as possible on both sides) up to the
        next multiple of 16. The multiple is fixed at 16 regardless of
        ``num_pools``.

        Returns
        -------
        tuple
            The padded tensor and ``(h_pad, w_pad, h_mult, w_mult)``, which is
            exactly what ``unpad`` needs to crop back.
        """
        _, _, h, w = x.shape
        # ((n - 1) | 15) + 1 rounds n up to a multiple of 16
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # F.pad lists the last dimension first, so width precedes height.
        # TODO: fix this type when PyTorch fixes theirs
        # https://github.com/pytorch/pytorch/pull/16949
        x = F.pad(x, w_pad + h_pad)
        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        """Crop away the padding added by ``pad``."""
        return x[
            ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the normalized U-Net to ``x`` of shape (b, c, h, w, 2).

        Raises
        ------
        ValueError
            If the last dimension of ``x`` is not 2.
        """
        if x.shape[-1] != 2:
            raise ValueError("Last dimension must be 2 for complex.")

        # get shapes for unet and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.unet(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)
        return x
class SensitivityModel(nn.Module):
    """
    Model for learning sensitivity estimation from k-space data.

    This model applies an IFFT to multichannel k-space data and then a U-Net
    to the coil images to estimate coil sensitivities. It can be used with the
    end-to-end variational network.

    Input: multi-coil k-space data
    Output: multi-coil spatial domain sensitivity maps
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_unet = NormUnet(
            chans,
            num_pools,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )

    def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """
        Move the channel (coil) axis into the batch axis so every coil image
        is processed independently by the U-Net.

        (b, c, h, w, comp) -> (b * c, 1, h, w, comp). ``b`` is returned as well
        so ``batch_chans_to_chan_dim`` can undo the reshape.
        """
        b, c, h, w, comp = x.shape
        return x.view(b * c, 1, h, w, comp), b

    def batch_chans_to_chan_dim(
        self,
        x: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        """
        Inverse of ``chans_to_batch_dim``:
        (b * c, 1, h, w, comp) -> (b, c, h, w, comp).
        """
        bc, _, h, w, comp = x.shape
        c = bc // batch_size
        return x.view(batch_size, c, h, w, comp)

    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        """
        Divide ``x`` by ``fastmri.rss_complex(x, dim=1)``, broadcast back over
        dim 1 and the trailing axis via the two unsqueezes. This assumes
        ``rss_complex`` drops both the channel and the complex axis (fastmri
        helper, not shown here).
        """
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self,
        mask: torch.Tensor,
        num_low_frequencies: Optional[Union[int, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Locate the central (low-frequency) block of k-space lines for every
        batch element.

        Parameters
        ----------
        mask : torch.Tensor
            Sampling mask. Only ``mask[:, 0, 0, :, 0]`` is read, so the
            sampled lines lie along dim -2.
        num_low_frequencies : int or torch.Tensor, optional
            Width of the central block. It may be a single value for the whole
            batch (Python int or 0-d tensor) or one value per batch element.
            If None, or if ANY entry is 0, the width is detected from the
            mask for every batch element.

        Returns
        -------
        pad : torch.Tensor
            long tensor of shape (b,); index of the first central line.
        num_low_frequencies : torch.Tensor
            long tensor of shape (b,); width of the central block.
        """
        nlf = None
        if num_low_frequencies is not None:
            # Accept ints, sequences and tensors (scalar or per-sample) and
            # put them on the mask's device; iterating a raw int would raise,
            # and a CPU tensor would not combine with a CUDA mask.
            nlf = torch.as_tensor(num_low_frequencies, device=mask.device).reshape(-1)

        if nlf is None or bool((nlf == 0).any()):
            # get low frequency line locations and mask them out
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # argmin returns the index of the first 0 (first unsampled line)
            # when scanning outward from the centre; 0 if that half is full
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless 1
        else:
            # broadcast a single value (or keep per-sample values) to (b,)
            num_low_frequencies_tensor = nlf * torch.ones(
                mask.shape[0], dtype=nlf.dtype, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[Union[int, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Estimate coil sensitivity maps from undersampled k-space.

        When ``mask_center`` is set, the k-space is first passed through
        ``transforms.batched_mask_center`` with the range
        ``[pad, pad + num_low_freqs)`` (presumably keeping only the central
        lines; fastmri helper, confirm). Each coil image is then refined by
        the NormUnet and the result is normalized by its root-sum-of-squares.
        """
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = self.chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            self.batch_chans_to_chan_dim(self.norm_unet(images), batches)
        )
class VarNet(nn.Module):
    """
    A full variational network model.

    Sensitivity maps are estimated once by a ``SensitivityModel``. The k-space
    estimate is then refined by a stack of ``VarNetBlock`` cascades, each of
    which uses a ``NormUnet`` as its regularizer. To use non-U-Net
    regularizers, build ``VarNetBlock`` instances directly.

    Input: multi-channel k-space data
    Output: single-channel combined image (root-sum-of-squares unless the
    MVUE switch described in ``forward`` is active)
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for variational network.
        sens_chans : int
            Number of channels for sensitivity map U-Net.
        sens_pools : int
            Number of downsampling and upsampling layers for sensitivity map U-Net.
        chans : int
            Number of channels for cascade U-Net.
        pools : int
            Number of downsampling and upsampling layers for cascade U-Net.
        mask_center : bool
            Whether to mask center of k-space for sensitivity map calculation.
        """
        super().__init__()
        self.sens_net = SensitivityModel(
            chans=sens_chans, num_pools=sens_pools, mask_center=mask_center
        )
        # one independently initialised NormUnet regularizer per cascade
        self.cascades = nn.ModuleList(
            VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)
        )

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Reconstruct an image from undersampled multi-coil k-space.

        ``num_low_frequencies`` is forwarded unchanged to the sensitivity
        model. The coil images of the final k-space estimate are combined
        with ``fastmri.mvue`` (using the sensitivity maps) only while the
        module is in training mode AND the ``MVUE`` environment variable is
        one of ``yes``/``1``/``true``/``True``. Otherwise, including always
        in eval mode, they are combined as the root-sum-of-squares of the
        coil magnitudes.
        """
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        kspace_pred = masked_kspace.clone()
        for cascade in self.cascades:
            kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps)

        coil_images = fastmri.ifft2c(kspace_pred)

        # FIXME: CHANGE FOR MVUE MODE -- experimental switch driven by the
        # environment; the variable is only consulted in training mode.
        mvue_requested = self.training and os.getenv("MVUE") in ["yes", "1", "true", "True"]
        if mvue_requested:
            return fastmri.mvue(coil_images, sens_maps, dim=1)
        return fastmri.rss(fastmri.complex_abs(coil_images), dim=1)
class VarNetBlock(nn.Module):
    """
    Model block for end-to-end variational network (refinement module).

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    Input: multi-channel k-space data
    Output: multi-channel k-space data
    """

    def __init__(self, model: nn.Module):
        """
        Parameters
        ----------
        model : nn.Module
            Module for "regularization" component of variational network.
        """
        super().__init__()
        self.model = model
        # learnable scalar weight of the soft data-consistency term, init 1
        self.dc_weight = nn.Parameter(torch.ones(1))

    def sens_expand(
        self, x: torch.Tensor, sens_maps: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculate F(x * sens_maps): multiply the image by every sensitivity
        map (complex product) and take the centered 2-D FFT.
        """
        return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))

    def sens_reduce(
        self, x: torch.Tensor, sens_maps: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculate sum over dim 1 of F^{-1}(x) * conj(sens_maps).

        The centered inverse 2-D FFT of ``x`` is multiplied element-wise by the
        complex conjugate of the sensitivity maps and summed over the channel
        (coil) dimension, dim 1, which is kept with size 1.
        """
        return fastmri.complex_mul(
            fastmri.ifft2c(x), fastmri.complex_conj(sens_maps)
        ).sum(dim=1, keepdim=True)

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        current_kspace : torch.Tensor
            The current k-space data (frequency domain data) being processed by the network.
        ref_kspace : torch.Tensor
            The reference k-space data (measured data) used for data consistency.
        mask : torch.Tensor
            Mask of the k-space locations where data consistency is enforced.
            It is the condition of ``torch.where``, so it must be a bool tensor
            broadcastable against the k-space.
        sens_maps : torch.Tensor
            Sensitivity maps for the different coils in parallel imaging.

        Returns
        -------
        torch.Tensor
            ``current_kspace - dc_weight * mask * (current_kspace - ref_kspace)
            - sens_expand(model(sens_reduce(current_kspace)))``.
        """
        # Model term: coil-combine the current k-space into one image, apply
        # the regularizer, and expand the result back to per-coil k-space.
        model_term = self.sens_expand(
            self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps
        )

        # Soft data-consistency term: the residual to the measured k-space at
        # masked locations (zero elsewhere), scaled by the learnable weight.
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = (
            torch.where(mask, current_kspace - ref_kspace, zero)
            * self.dc_weight
        )

        # NOTE(review): an earlier comment said the data-consistency term was
        # "removed for single cascade experiments", but it is applied here.
        return current_kspace - soft_dc - model_term