Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copyright (c) Facebook, Inc. and its affiliates. | |
This source code is licensed under the MIT license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import math | |
import os | |
from typing import List, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import fastmri | |
from fastmri import transforms | |
from models.unet import Unet | |
class NormUnet(nn.Module):
    """
    Normalized U-Net model.

    This is the same as a regular U-Net, but with normalization applied to the
    input before the U-Net. This keeps the values more numerically stable
    during training.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUnet model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()
        self.unet = Unet(
            in_chans=in_chans,
            out_chans=out_chans,
            chans=chans,
            num_pool_layers=num_pools,
            drop_prob=drop_prob,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Fold the trailing complex axis into the channel axis.

        Maps ``(b, c, h, w, 2)`` to ``(b, 2 * c, h, w)``; the first ``c``
        output channels hold index 0 of the last axis and the remaining ``c``
        hold index 1.
        """
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of :meth:`complex_to_chan_dim`.

        Maps ``(b, 2 * c, h, w)`` back to a contiguous ``(b, c, h, w, 2)``
        tensor.
        """
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        # reshape (rather than view) so non-contiguous inputs are accepted too
        return x.reshape(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()

    def norm(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalize each of the two channel groups of every batch element.

        The ``c`` channels are split into two groups of ``c // 2`` channels
        and each group is shifted/scaled by its own mean and (unbiased)
        standard deviation.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape ``(b, c, h, w)``. The statistics have shape
            ``(b, 2, 1, 1)``, so the broadcast below only lines up when
            ``c == 2``.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            The normalized tensor, the mean and the standard deviation that
            was actually used for scaling (both of shape ``(b, 2, 1, 1)``).
        """
        # group norm
        b, c, h, w = x.shape
        # reshape (rather than view) so non-contiguous inputs are accepted too
        grouped = x.reshape(b, 2, c // 2 * h * w)
        mean = grouped.mean(dim=2).view(b, 2, 1, 1)
        std = grouped.std(dim=2).view(b, 2, 1, 1)
        # A constant group (e.g. an all-zero image) has std == 0; dividing by
        # it would turn the whole output into NaN. Such a group is already
        # centred at 0 after subtracting the mean, so scale it by 1 instead.
        std = torch.where(std == 0, torch.ones_like(std), std)
        x = grouped.reshape(b, c, h, w)
        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        """Undo :meth:`norm` using the statistics it returned."""
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        """
        Zero-pad the two trailing axes up to the next multiple of 16.

        Returns
        -------
        Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]
            The padded tensor and ``(h_pad, w_pad, h_mult, w_mult)``, which
            is exactly the argument list :meth:`unpad` expects.
        """
        _, _, h, w = x.shape
        # ((n - 1) | 15) + 1 is the smallest multiple of 16 that is >= n
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        # split the padding as evenly as possible; the extra pixel goes last
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # TODO: fix this type when PyTorch fixes theirs
        # the documentation lies - this actually takes a list
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
        # https://github.com/pytorch/pytorch/pull/16949
        x = F.pad(x, w_pad + h_pad)
        return x, (h_pad, w_pad, h_mult, w_mult)

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        """Crop away the border added by :meth:`pad`."""
        return x[
            ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the U-Net on a complex-valued input.

        Parameters
        ----------
        x : torch.Tensor
            Tensor of shape ``(b, c, h, w, 2)``.

        Returns
        -------
        torch.Tensor
            The U-Net output, un-padded, un-normalized and converted back to
            a trailing complex axis.

        Raises
        ------
        ValueError
            If the last dimension of ``x`` is not 2.
        """
        if not x.shape[-1] == 2:
            raise ValueError("Last dimension must be 2 for complex.")
        # get shapes for unet and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)
        x = self.unet(x)
        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)
        return x
class SensitivityModel(nn.Module):
    """
    Model for learning sensitivity estimation from k-space data.

    This model applies an IFFT to multichannel k-space data and then a U-Net
    to the coil images to estimate coil sensitivities. It can be used with the
    end-to-end variational network.

    Input: multi-coil k-space data
    Output: multi-coil spatial domain sensitivity maps
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        in_chans : int, optional
            Number of channels in the input to the U-Net model. Default is 2.
        out_chans : int, optional
            Number of channels in the output to the U-Net model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_unet = NormUnet(
            chans,
            num_pools,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )

    def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        """
        Move the coil axis into the batch axis.

        Maps ``(b, c, h, w, comp)`` to ``(b * c, 1, h, w, comp)`` and also
        returns the original batch size ``b`` so the move can be undone.
        """
        b, c, h, w, comp = x.shape
        return x.view(b * c, 1, h, w, comp), b

    def batch_chans_to_chan_dim(
        self,
        x: torch.Tensor,
        batch_size: int,
    ) -> torch.Tensor:
        """Inverse of :meth:`chans_to_batch_dim` for the given batch size."""
        bc, _, h, w, comp = x.shape
        c = bc // batch_size
        return x.view(batch_size, c, h, w, comp)

    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by its complex root-sum-of-squares taken over dim 1."""
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Work out, per batch element, where the low-frequency band sits.

        Parameters
        ----------
        mask : torch.Tensor
            Sampling mask, indexed here as ``mask[:, 0, 0, :, 0]``, i.e. a
            5-D tensor whose fourth axis runs over the k-space lines.
        num_low_frequencies : optional
            Number of low-frequency lines. May be a single int, a tensor, or
            a sequence of ints/tensors (one value per batch element, or one
            value shared by the whole batch). If it is ``None`` or if any
            value equals 0, the count is detected from ``mask`` for the
            whole batch instead.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            ``(pad, num_low_frequencies)``, both ``torch.long`` tensors of
            length ``mask.shape[0]``; ``pad`` is the offset that centres the
            band inside ``mask.shape[-2]`` lines.
        """
        batch_size = mask.shape[0]

        # Normalise the many accepted spellings to a flat tensor so that a
        # plain int (which cannot be iterated) works as well as a batch.
        requested = None
        if num_low_frequencies is not None:
            if isinstance(num_low_frequencies, (list, tuple)):
                requested = torch.cat(
                    [
                        torch.as_tensor(value, device=mask.device).reshape(-1)
                        for value in num_low_frequencies
                    ]
                )
            else:
                requested = torch.as_tensor(
                    num_low_frequencies, device=mask.device
                ).reshape(-1)

        if requested is None or bool((requested == 0).any()):
            # get low frequency line locations and mask them out
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # On a 0/1 mask argmin returns the index of the first zero, i.e.
            # the number of contiguous sampled lines walking away from the
            # centre.
            # NOTE(review): a half with no zeros at all yields index 0 here.
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless 1
        else:
            # broadcast a single value (or keep one value per element)
            num_low_frequencies_tensor = requested.to(torch.long) * torch.ones(
                batch_size, dtype=torch.long, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Estimate coil sensitivity maps from under-sampled k-space.

        Parameters
        ----------
        masked_kspace : torch.Tensor
            Multi-coil k-space data; after the inverse FFT it is unpacked as
            ``(b, c, h, w, comp)``.
        mask : torch.Tensor
            Sampling mask (see :meth:`get_pad_and_num_low_freqs`).
        num_low_frequencies : optional
            Forwarded to :meth:`get_pad_and_num_low_freqs`.

        Returns
        -------
        torch.Tensor
            U-Net output per coil, divided by its root-sum-of-squares over
            the coil axis.
        """
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            # presumably keeps only lines in [pad, pad + num_low_freqs) -
            # confirm against transforms.batched_mask_center
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )
        # convert to image space
        images, batches = self.chans_to_batch_dim(fastmri.ifft2c(masked_kspace))
        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            self.batch_chans_to_chan_dim(self.norm_unet(images), batches)
        )
class VarNet(nn.Module):
    """
    End-to-end variational network.

    A learned sensitivity-map estimator is followed by a chain of refinement
    cascades, each combining soft data consistency with a U-Net regularizer.
    To use non-U-Net regularizers, build the cascades from VarNetBlock
    directly.

    Input: multi-channel k-space data
    Output: single-channel combined image (root-sum-of-squares, or MVUE when
    the module is in training mode and the ``MVUE`` environment variable is
    set to one of ``yes``/``1``/``true``/``True``)
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        num_cascades : int
            How many refinement cascades (layers) to stack.
        sens_chans : int
            Channel count of the sensitivity-map U-Net.
        sens_pools : int
            Down-/up-sampling depth of the sensitivity-map U-Net.
        chans : int
            Channel count of each cascade's U-Net.
        pools : int
            Down-/up-sampling depth of each cascade's U-Net.
        mask_center : bool
            Whether the sensitivity model masks the centre of k-space before
            estimating the maps.
        """
        super().__init__()

        self.sens_net = SensitivityModel(
            chans=sens_chans,
            num_pools=sens_pools,
            mask_center=mask_center,
        )

        # one independent regularizer per cascade
        refinement_blocks = []
        for _ in range(num_cascades):
            refinement_blocks.append(VarNetBlock(NormUnet(chans, pools)))
        self.cascades = nn.ModuleList(refinement_blocks)

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Reconstruct an image from under-sampled multi-coil k-space.

        Parameters
        ----------
        masked_kspace : torch.Tensor
            Measured (under-sampled) k-space data.
        mask : torch.Tensor
            Sampling mask handed to the sensitivity model and every cascade.
        num_low_frequencies : optional
            Forwarded unchanged to the sensitivity model.

        Returns
        -------
        torch.Tensor
            Coil-combined image (combination is taken over dim 1).
        """
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)

        # every cascade refines the running estimate against the measurement
        running_kspace = masked_kspace.clone()
        for refine in self.cascades:
            running_kspace = refine(running_kspace, masked_kspace, mask, sens_maps)

        coil_images = fastmri.ifft2c(running_kspace)

        # ---------> FIXME: CHANGE FOR MVUE MODE
        if not (
            self.training and os.getenv("MVUE") in ["yes", "1", "true", "True"]
        ):
            coil_magnitudes = fastmri.complex_abs(coil_images)
            return fastmri.rss(coil_magnitudes, dim=1)

        return fastmri.mvue(coil_images, sens_maps, dim=1)
class VarNetBlock(nn.Module):
    """
    Model block for end-to-end variational network (refinement module).

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    Input: multi-channel k-space data
    Output: multi-channel k-space data
    """

    def __init__(self, model: nn.Module):
        """
        Parameters
        ----------
        model : nn.Module
            Module for "regularization" component of variational network.
        """
        super().__init__()
        self.model = model
        # learnable scalar weight of the soft data-consistency term
        self.dc_weight = nn.Parameter(torch.ones(1))

    def sens_expand(
        self, x: torch.Tensor, sens_maps: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculates F (x sens_maps)
        """
        return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))

    def sens_reduce(
        self, x: torch.Tensor, sens_maps: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Calculates F^{-1}(x) \overline{sens_maps}
        where \overline{sens_maps} is the element-wise applied complex conjugate

        The product is summed over dim 1 with ``keepdim=True``.
        """
        return fastmri.complex_mul(
            fastmri.ifft2c(x), fastmri.complex_conj(sens_maps)
        ).sum(dim=1, keepdim=True)

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        current_kspace : torch.Tensor
            The current k-space data (frequency domain data) being processed by the network.
        ref_kspace : torch.Tensor
            The reference k-space data (measured data) used for data consistency.
        mask : torch.Tensor
            A binary mask indicating the locations in k-space where data consistency should be enforced.
            It is used as the condition of ``torch.where``, so it has to be a boolean tensor.
        sens_maps : torch.Tensor
            Sensitivity maps for the different coils in parallel imaging.

        Returns
        -------
        torch.Tensor
            The output k-space data after applying the variational network block.
        """
        # Model term:
        # - Reduces the current k-space data using the sensitivity maps
        #   (inverse Fourier transform followed by element-wise multiplication
        #   and summation).
        # - Applies the neural network model to the reduced data.
        # - Expands the output of the model using the sensitivity maps
        #   (element-wise multiplication followed by Fourier transform).
        model_term = self.sens_expand(
            self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps
        )

        # Soft data consistency term:
        # - Calculates the difference between current k-space and reference
        #   k-space where the mask is true (zero elsewhere).
        # - Multiplies this difference by the data consistency weight.
        # `.to(current_kspace)` matches the dtype and device of the k-space.
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = (
            torch.where(mask, current_kspace - ref_kspace, zero)
            * self.dc_weight
        )

        # with data consistency term (removed for single cascade experiments)
        return current_kspace - soft_dc - model_term