""" | |
NO Varnet WITHOUT KNO for ablation | |
""" | |
import math
from typing import Iterable, List, Literal, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

import fastmri
from fastmri import transforms
from fastmri.datasets import SliceDatasetLMDB, SliceSample

from models.udno import UDNO


def sens_expand(x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F(x * sens_maps), i.e. expands a coil-combined image to
    multi-coil k-space.

    Parameters
    ----------
    x : torch.Tensor
        Single-channel complex image of shape (..., H, W, 2)
    sens_maps : torch.Tensor
        Coil sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation F(x * sens_maps)
    """
    return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))
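
# Shape sketch (assumed shapes, matching how sens_expand is used below):
# with x of shape (B, 1, H, W, 2) and sens_maps of shape (B, C, H, W, 2),
# complex_mul broadcasts x across the C coil maps and fft2c yields
# multi-coil k-space of shape (B, C, H, W, 2).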


def sens_reduce(k: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
    """
    Calculates F^{-1}(k) * conj(sens_maps), summed over the coil dimension,
    where conj(sens_maps) is the element-wise complex conjugate.

    Parameters
    ----------
    k : torch.Tensor
        Multi-channel k-space of shape (B, C, H, W, 2)
    sens_maps : torch.Tensor
        Coil sensitivity maps (image space)

    Returns
    -------
    torch.Tensor
        Result of the operation sum_c F^{-1}(k) * conj(sens_maps)
    """
    return fastmri.complex_mul(
        fastmri.ifft2c(k), fastmri.complex_conj(sens_maps)
    ).sum(dim=1, keepdim=True)
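
# sens_reduce is the adjoint of sens_expand: it maps multi-coil k-space of
# shape (B, C, H, W, 2) back to a coil-combined image of shape
# (B, 1, H, W, 2). The coil dimension is summed out; keepdim=True keeps a
# singleton channel axis.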


def chans_to_batch_dim(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Reshapes batched multi-channel samples into multiple single-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        x has shape (b, c, h, w, 2)

    Returns
    -------
    Tuple[torch.Tensor, int]
        tensor of shape (b * c, 1, h, w, 2), b
    """
    b, c, h, w, comp = x.shape
    return x.view(b * c, 1, h, w, comp), b


def batch_chans_to_chan_dim(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Reshapes batched independent samples into original multi-channel samples.

    Parameters
    ----------
    x : torch.Tensor
        tensor of shape (b * c, 1, h, w, 2)
    batch_size : int
        batch size

    Returns
    -------
    torch.Tensor
        original multi-channel tensor of shape (b, c, h, w, 2)
    """
    bc, _, h, w, comp = x.shape
    c = bc // batch_size
    return x.view(batch_size, c, h, w, comp)
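
# The two helpers above are inverses of each other; a hypothetical round trip:
#
#     y, b = chans_to_batch_dim(x)          # (b * c, 1, h, w, 2)
#     assert torch.equal(batch_chans_to_chan_dim(y, b), x)
#
# This lets a single-channel network (the NormUDNO below) process each coil
# image independently.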


class NormUDNO(nn.Module):
    """
    Normalized UDNO model.

    Inputs are normalized before the UDNO for numerically stable training.
    """

    def __init__(
        self,
        chans: int,
        num_pool_layers: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
    ):
        """
        Initialize the NormUDNO model.

        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pool_layers : int
            Number of down-sampling and up-sampling layers.
        radius_cutoff : float
            Radius cutoff of the UDNO integral kernel.
        in_shape : Tuple[int, int]
            Spatial shape (H, W) of the input.
        kernel_shape : Tuple[int, int]
            Shape of the UDNO kernel.
        in_chans : int, optional
            Number of channels in the input to the UDNO model. Default is 2.
        out_chans : int, optional
            Number of channels in the output of the UDNO model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        """
        super().__init__()
        self.udno = UDNO(
            in_chans=in_chans,
            out_chans=out_chans,
            radius_cutoff=radius_cutoff,
            chans=chans,
            num_pool_layers=num_pool_layers,
            drop_prob=drop_prob,
            in_shape=in_shape,
            kernel_shape=kernel_shape,
        )

    def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor:
        # (b, c, h, w, 2) -> (b, 2 * c, h, w)
        b, c, h, w, two = x.shape
        assert two == 2
        return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w)

    def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor:
        # (b, 2 * c, h, w) -> (b, c, h, w, 2)
        b, c2, h, w = x.shape
        assert c2 % 2 == 0
        c = c2 // 2
        return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous()
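
    # Channel layout sketch: permute(0, 4, 1, 2, 3) moves the complex axis
    # in front of the coils, so the 2*c channels are ordered
    # [real_0, ..., real_{c-1}, imag_0, ..., imag_{c-1}]. norm() below
    # relies on this: its two groups are exactly the real and imaginary
    # halves.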

    def norm(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # group norm over two groups (the real and imaginary channel halves)
        b, c, h, w = x.shape
        x = x.view(b, 2, c // 2 * h * w)
        mean = x.mean(dim=2).view(b, 2, 1, 1)
        std = x.std(dim=2).view(b, 2, 1, 1)
        x = x.view(b, c, h, w)
        return (x - mean) / std, mean, std
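
    # NOTE: norm() broadcasts (b, 2, 1, 1) statistics against (b, c, h, w),
    # which only works when c == 2, i.e. a single complex channel after
    # complex_to_chan_dim. norm_new() below generalizes this to any even c.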

    def norm_new(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Generalized group norm for an arbitrary even channel count.
        # The earlier version tiled the group statistics with repeat(),
        # which interleaves the groups in the wrong channel order;
        # repeat_interleave() expands each group's statistics over its
        # contiguous channel block instead.
        b, c, h, w = x.shape
        num_groups = 2
        assert (
            c % num_groups == 0
        ), f"Number of channels ({c}) must be divisible by number of groups ({num_groups})."
        grouped = x.view(b, num_groups, c // num_groups * h * w)
        mean = grouped.mean(dim=2).view(b, num_groups, 1, 1)
        std = grouped.std(dim=2).view(b, num_groups, 1, 1)
        # expand (b, num_groups, 1, 1) -> (b, c, 1, 1) so the statistics
        # broadcast against (b, c, h, w)
        mean = mean.repeat_interleave(c // num_groups, dim=1)
        std = std.repeat_interleave(c // num_groups, dim=1)
        return (x - mean) / std, mean, std

    def unnorm(
        self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor
    ) -> torch.Tensor:
        return x * std + mean

    def pad(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]:
        _, _, h, w = x.shape
        # ((n - 1) | 15) + 1 rounds n up to the next multiple of 16, so the
        # spatial dims divide evenly through the down-sampling layers
        w_mult = ((w - 1) | 15) + 1
        h_mult = ((h - 1) | 15) + 1
        w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)]
        h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)]
        # TODO: fix this type when PyTorch fixes theirs
        # the documentation lies - this actually takes a list
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3457
        # https://github.com/pytorch/pytorch/pull/16949
        x = F.pad(x, w_pad + h_pad)
        return x, (h_pad, w_pad, h_mult, w_mult)
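
    # Worked example (hypothetical sizes): for w = 322,
    # ((322 - 1) | 15) + 1 = 336, so w_pad = [7, 7]; for w = 320, already a
    # multiple of 16, w_mult = 320 and no padding is applied.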

    def unpad(
        self,
        x: torch.Tensor,
        h_pad: List[int],
        w_pad: List[int],
        h_mult: int,
        w_mult: int,
    ) -> torch.Tensor:
        return x[
            ..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]
        ]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] != 2:
            raise ValueError("Last dimension must be 2 for complex.")

        chans = x.shape[1]
        if chans == 2:
            # FIXME: hard coded skip of norm/pad temporarily to avoid the
            # group norm bug
            x = self.complex_to_chan_dim(x)
            x = self.udno(x)
            return self.chan_complex_to_last_dim(x)

        # get shapes for the UDNO and normalize
        x = self.complex_to_chan_dim(x)
        x, mean, std = self.norm(x)
        x, pad_sizes = self.pad(x)

        x = self.udno(x)

        # get shapes back and unnormalize
        x = self.unpad(x, *pad_sizes)
        x = self.unnorm(x, mean, std)
        x = self.chan_complex_to_last_dim(x)

        return x
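
    # Minimal usage sketch (hypothetical sizes, assuming the UDNO preserves
    # spatial shape): a NormUDNO configured like the cascades below maps a
    # complex image to a complex image of the same shape.
    #
    #     net = NormUDNO(chans=18, num_pool_layers=4, radius_cutoff=0.01,
    #                    in_shape=(640, 320), kernel_shape=(3, 4))
    #     y = net(torch.randn(1, 1, 640, 320, 2))  # -> (1, 1, 640, 320, 2)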


class SensitivityModel(nn.Module):
    """
    Learns coil sensitivity maps from the masked k-space data.
    """

    def __init__(
        self,
        chans: int,
        num_pools: int,
        radius_cutoff: float,
        in_shape: Tuple[int, int],
        kernel_shape: Tuple[int, int],
        in_chans: int = 2,
        out_chans: int = 2,
        drop_prob: float = 0.0,
        mask_center: bool = True,
    ):
        """
        Parameters
        ----------
        chans : int
            Number of output channels of the first convolution layer.
        num_pools : int
            Number of down-sampling and up-sampling layers.
        radius_cutoff : float
            Radius cutoff of the UDNO integral kernel.
        in_shape : Tuple[int, int]
            Spatial shape (H, W) of the input.
        kernel_shape : Tuple[int, int]
            Shape of the UDNO kernel.
        in_chans : int, optional
            Number of channels in the input to the UDNO model. Default is 2.
        out_chans : int, optional
            Number of channels in the output of the UDNO model. Default is 2.
        drop_prob : float, optional
            Dropout probability. Default is 0.0.
        mask_center : bool, optional
            Whether to mask center of k-space for sensitivity map calculation.
            Default is True.
        """
        super().__init__()
        self.mask_center = mask_center
        self.norm_udno = NormUDNO(
            chans,
            num_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            in_chans=in_chans,
            out_chans=out_chans,
            drop_prob=drop_prob,
        )

    def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor:
        return x / fastmri.rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1)

    def get_pad_and_num_low_freqs(
        self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if num_low_frequencies is None or (
            isinstance(num_low_frequencies, Iterable)
            and any(torch.any(t == 0) for t in num_low_frequencies)
        ):
            # get low frequency line locations and mask them out
            squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8)
            cent = squeezed_mask.shape[1] // 2
            # argmin returns the index of the first zero, i.e. the first
            # unsampled line moving away from the center
            left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1)
            right = torch.argmin(squeezed_mask[:, cent:], dim=1)
            num_low_frequencies_tensor = torch.max(
                2 * torch.min(left, right), torch.ones_like(left)
            )  # force a symmetric center unless 1
        else:
            num_low_frequencies_tensor = num_low_frequencies * torch.ones(
                mask.shape[0], dtype=mask.dtype, device=mask.device
            )

        pad = (mask.shape[-2] - num_low_frequencies_tensor + 1) // 2
        return pad.type(torch.long), num_low_frequencies_tensor.type(torch.long)
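
    # Worked example (hypothetical numbers): for a mask with 318 k-space
    # columns and 22 symmetric fully sampled central lines,
    # num_low_frequencies = 22 and pad = (318 - 22 + 1) // 2 = 148, so the
    # retained center spans columns [148, 170).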

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        if self.mask_center:
            pad, num_low_freqs = self.get_pad_and_num_low_freqs(
                mask, num_low_frequencies
            )
            masked_kspace = transforms.batched_mask_center(
                masked_kspace, pad, pad + num_low_freqs
            )

        # convert to image space
        images, batches = chans_to_batch_dim(fastmri.ifft2c(masked_kspace))

        # estimate sensitivities
        return self.divide_root_sum_of_squares(
            batch_chans_to_chan_dim(self.norm_udno(images), batches)
        )
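
    # Pipeline sketch: keep only the central low-frequency k-space lines,
    # inverse-FFT each coil, refine each coil image independently through
    # the NormUDNO (coils folded into the batch dim), then divide by the
    # root sum of squares so the maps have unit combined magnitude per
    # pixel.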


class VarNetBlock(nn.Module):
    """
    Model block for iterative refinement of k-space data.

    This model applies a combination of soft data consistency with the input
    model as a regularizer. A series of these blocks can be stacked to form
    the full variational network.

    a.k.a. the Refinement Module in Fig. 1
    """

    def __init__(self, model: nn.Module):
        """
        Args:
            model: Module for "regularization" component of variational
                network.
        """
        super().__init__()
        self.model = model
        self.dc_weight = nn.Parameter(torch.ones(1))

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
        use_dc_term: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            current_kspace: The current k-space data (frequency domain data)
                being processed by the network. (torch.Tensor)
            ref_kspace: Original subsampled k-space data from which we are
                reconstructing the image (reference k-space). (torch.Tensor)
            mask: A binary mask indicating the locations in k-space where
                data consistency should be enforced. (torch.Tensor)
            sens_maps: Sensitivity maps for the different coils in parallel
                imaging. (torch.Tensor)
        """
        # model term: see the orange box of Fig. 1 in the E2E-VarNet paper!
        # multi-channel k-space -> single-channel image space
        b, c, h, w, _ = current_kspace.shape
        if c == 30:
            # 30 channels = 15 coils of k-space stacked with 15 coils of
            # inpainted k-space; split and reduce each separately
            kspace = current_kspace[:, :15, :, :, :]
            in_kspace = current_kspace[:, 15:, :, :, :]
            # convert to image space
            image = sens_reduce(kspace, sens_maps)
            in_image = sens_reduce(in_kspace, sens_maps)
            # concatenate both onto each other
            reduced_image = torch.cat([image, in_image], dim=1)
        else:
            reduced_image = sens_reduce(current_kspace, sens_maps)

        # single-channel image space
        refined_image = self.model(reduced_image)

        # single-channel image space -> multi-channel k-space
        model_term = sens_expand(refined_image, sens_maps)

        # only use first 15 channels (masked_kspace) in the update
        # current_kspace = current_kspace[:, :15, :, :, :]

        if not use_dc_term:
            return current_kspace - model_term

        # Soft data consistency term (green box of Fig. 1 in the E2E-VarNet
        # paper): take the difference between the current and reference
        # k-space where the mask is set, scaled by the learned dc_weight.
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = (
            torch.where(mask, current_kspace - ref_kspace, zero)
            * self.dc_weight
        )

        return current_kspace - soft_dc - model_term
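
    # In symbols (ignoring the 30-channel concat path), each cascade
    # computes
    #
    #     k_{t+1} = k_t - lambda_t * M * (k_t - k_ref)
    #                   - F(S * model(sum_c conj(S) * F^{-1}(k_t)))
    #
    # where M is the sampling mask, lambda_t is the learned dc_weight, and
    # the S terms correspond to sens_expand / sens_reduce.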


class NOVarnet_no_KNO(nn.Module):
    """
    Neural Operator model for MRI reconstruction; ablation variant without
    the KNO.

    Uses a variational architecture (iterative updates) with a learned
    sensitivity model. All operations are resolution invariant, employing
    neural operator modules (GNO, UDNO).
    """

    def __init__(
        self,
        num_cascades: int = 12,
        sens_chans: int = 8,
        sens_pools: int = 4,
        chans: int = 18,
        pools: int = 4,
        gno_chans: int = 16,
        gno_pools: int = 4,
        gno_radius_cutoff: float = 0.02,
        gno_kernel_shape: Tuple[int, int] = (6, 7),
        radius_cutoff: float = 0.01,
        kernel_shape: Tuple[int, int] = (3, 4),
        in_shape: Tuple[int, int] = (640, 320),
        mask_center: bool = True,
        use_dc_term: bool = True,
        reduction_method: Literal["batch", "rss"] = "rss",
        skip_method: Literal["replace", "add", "add_inv", "concat"] = "add",
    ):
        """
        Parameters
        ----------
        num_cascades : int
            Number of cascades (i.e., layers) for the variational network.
        sens_chans : int
            Number of channels for the sensitivity map UDNO.
        sens_pools : int
            Number of down-sampling and up-sampling layers for the
            sensitivity map UDNO.
        chans : int
            Number of channels for the cascade UDNO.
        pools : int
            Number of down-sampling and up-sampling layers for the cascade
            UDNO.
        gno_chans, gno_pools, gno_radius_cutoff, gno_kernel_shape
            GNO settings; unused here since the GNO is ablated away.
        radius_cutoff : float
            Radius cutoff of the UDNO integral kernel.
        kernel_shape : Tuple[int, int]
            Shape of the UDNO kernel.
        in_shape : Tuple[int, int]
            Spatial shape (H, W) of the input.
        mask_center : bool
            Whether to mask center of k-space for sensitivity map calculation.
        use_dc_term : bool
            Whether to use the data consistency term.
        reduction_method : "batch" or "rss"
            Method for reducing sensitivity maps to a single channel.
            "batch" reduces to a single channel by stacking channels;
            "rss" reduces to a single channel by root sum of squares.
        skip_method : "replace", "add", "add_inv", or "concat"
            "replace" replaces the input with the output of the GNO;
            "add" adds the output of the GNO to the input;
            "add_inv" adds the output of the GNO to the input (only where
            samples are missing);
            "concat" concatenates the output of the GNO to the input.
        """
        super().__init__()

        self.sens_net = SensitivityModel(
            sens_chans,
            sens_pools,
            radius_cutoff,
            in_shape,
            kernel_shape,
            mask_center=mask_center,
        )
        # KNO ablated away for this variant:
        # self.gno = NormUDNO(
        #     gno_chans,
        #     gno_pools,
        #     in_shape=in_shape,
        #     radius_cutoff=radius_cutoff,
        #     kernel_shape=kernel_shape,
        #     # radius_cutoff=gno_radius_cutoff,
        #     # kernel_shape=gno_kernel_shape,
        #     in_chans=2,
        #     out_chans=2,
        # )
        self.cascades = nn.ModuleList(
            [
                VarNetBlock(
                    NormUDNO(
                        chans,
                        pools,
                        radius_cutoff,
                        in_shape,
                        kernel_shape,
                        in_chans=(
                            4
                            if skip_method == "concat" and cascade_idx == 0
                            else 2
                        ),
                        out_chans=2,
                    )
                )
                for cascade_idx in range(num_cascades)
            ]
        )
        self.use_dc_term = use_dc_term
        self.reduction_method = reduction_method  # not used anywhere anymore
        self.skip_method = skip_method  # not used anywhere anymore

    def forward(
        self,
        masked_kspace: torch.Tensor,
        mask: torch.Tensor,
        num_low_frequencies: Optional[int] = None,
    ) -> torch.Tensor:
        # (B, C, X, Y, 2)
        sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies)
        kspace_pred = masked_kspace.clone()

        # iterative update
        for cascade in self.cascades:
            kspace_pred = cascade(
                kspace_pred, masked_kspace, mask, sens_maps, self.use_dc_term
            )

        spatial_pred = fastmri.ifft2c(kspace_pred)
        spatial_pred_abs = fastmri.complex_abs(spatial_pred)
        combined_spatial = fastmri.rss(spatial_pred_abs, dim=1)

        return combined_spatial
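
    # End-to-end data flow: estimate sensitivity maps once, run
    # num_cascades refinement blocks in k-space, then inverse-FFT and
    # combine coils by root sum of squares to obtain the magnitude image
    # of shape (B, H, W).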


if __name__ == "__main__":
    ds = SliceDatasetLMDB(
        "knee",
        partition="train",
        mask_fns=None,  # type: ignore
        complex=False,
        sample_rate=0.5,
        crop_shape=(320, 320),
        coils=15,
    )
    sample: SliceSample = ds[0]
    kspace = sample.masked_kspace
    target = sample.target

    model = NOVarnet_no_KNO(1)
    res = model.forward(
        sample.masked_kspace.unsqueeze(0),
        sample.mask.unsqueeze(0),
        torch.tensor(sample.num_low_frequencies).unsqueeze(0),
    )
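    # Smoke test (hedged): res should be a magnitude image of shape
    # (1, H, W), with H and W determined by the dataset's crop.
    print(res.shape)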