Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copyright (c) Facebook, Inc. and its affiliates. | |
This source code is licensed under the MIT license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import os | |
from typing import Dict, Optional, Sequence, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.distributions as D | |
from networkx import center | |
from sigpy.mri import poisson, radial, spiral | |
class MaskFunc:
    """
    An object for GRAPPA-style sampling masks.

    This creates a sampling mask that densely samples the center while
    subsampling outer k-space regions based on the undersampling factor.

    When called, ``MaskFunc`` uses internal functions to create a mask by 1)
    creating a mask for the k-space center, 2) creating a mask outside of the
    k-space center, and 3) combining them into a total mask. The internals are
    handled by ``sample_mask``, which calls ``calculate_center_mask`` for (1)
    and ``calculate_acceleration_mask`` for (2). The combination is executed
    in the ``MaskFunc`` ``__call__`` function.

    If you would like to implement a new mask, subclass ``MaskFunc`` and
    override ``calculate_acceleration_mask`` (or ``sample_mask`` for more
    control). See examples in ``RandomMaskFunc`` and ``EquiSpacedMaskFunc``.
    """

    def __init__(
        self,
        center_fractions: Sequence[float],
        accelerations: Sequence[int],
        allow_any_combination: bool = False,
        seed: Optional[int] = None,
    ):
        """
        Args:
            center_fractions: Fraction of low-frequency columns to be retained.
                If multiple values are provided, then one of these numbers is
                chosen uniformly each time.
            accelerations: Amount of under-sampling. This should have the same
                length as center_fractions. If multiple values are provided,
                then one of these is chosen uniformly each time.
            allow_any_combination: Whether to allow cross combinations of
                elements from ``center_fractions`` and ``accelerations``.
            seed: Seed for starting the internal random number generator of the
                ``MaskFunc``.

        Raises:
            ValueError: If the two sequences differ in length and
                ``allow_any_combination`` is False.
        """
        if (
            len(center_fractions) != len(accelerations)
            and not allow_any_combination
        ):
            raise ValueError(
                "Number of center fractions should match number of"
                " accelerations if allow_any_combination is False."
            )

        self.center_fractions = center_fractions
        self.accelerations = accelerations
        self.allow_any_combination = allow_any_combination
        self.rng = np.random.RandomState(seed)

    def __call__(
        self,
        shape: Sequence[int],
        offset: Optional[int] = None,
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
    ) -> Tuple[torch.Tensor, int]:
        """
        Sample and return a k-space mask.

        Args:
            shape: Shape of k-space.
            offset: Offset from 0 to begin mask (for equispaced masks). If no
                offset is given, then one is selected randomly.
            seed: Seed for random number generator for reproducibility. When
                given, the internal generator is temporarily reseeded so the
                same seed always yields the same mask; its previous state is
                restored afterwards so the unseeded stream is unaffected.

        Returns:
            A 2-tuple containing 1) the k-space mask and 2) the number of
            center frequency lines.

        Raises:
            ValueError: If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        if seed is None:
            center_mask, accel_mask, num_low_frequencies = self.sample_mask(
                shape, offset
            )
        else:
            saved_state = self.rng.get_state()
            self.rng.seed(seed)
            try:
                center_mask, accel_mask, num_low_frequencies = self.sample_mask(
                    shape, offset
                )
            finally:
                # Restore even if sampling raises, so later calls are unaffected.
                self.rng.set_state(saved_state)

        # combine masks together
        return torch.max(center_mask, accel_mask), num_low_frequencies

    def sample_mask(
        self,
        shape: Sequence[int],
        offset: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Sample a new k-space mask.

        This function samples and returns two components of a k-space mask: 1)
        the center mask (e.g., for sensitivity map calculation) and 2) the
        acceleration mask (for the edge of k-space). Both of these masks, as
        well as the integer of low frequency samples, are returned.

        Args:
            shape: Shape of the k-space to subsample.
            offset: Offset from 0 to begin mask (for equispaced masks).

        Returns:
            A 3-tuple containing 1) the mask for the center of k-space, 2) the
            mask for the high frequencies of k-space, and 3) the integer count
            of low frequency samples.
        """
        num_cols = shape[-2]
        center_fraction, acceleration = self.choose_acceleration()
        num_low_frequencies = round(num_cols * center_fraction)
        center_mask = self.reshape_mask(
            self.calculate_center_mask(shape, num_low_frequencies), shape
        )
        acceleration_mask = self.reshape_mask(
            self.calculate_acceleration_mask(
                num_cols, acceleration, offset, num_low_frequencies
            ),
            shape,
        )

        return center_mask, acceleration_mask, num_low_frequencies

    def reshape_mask(
        self, mask: Union[np.ndarray, torch.Tensor], shape: Sequence[int]
    ) -> torch.Tensor:
        """
        Reshape a mask and expand it to the desired output shape.

        A 1-D mask over the ``shape[-2]`` columns is cast to float32 and laid
        out along the second-to-last axis (size 1 on every other axis) before
        being expanded, so k-space of any rank is supported. Masks of other
        ranks are converted to a tensor and expanded as-is.

        Args:
            mask: Mask as a NumPy array or tensor.
            shape: Target k-space shape.

        Returns:
            A tensor view of ``mask`` expanded to ``shape``.
        """
        if not isinstance(mask, torch.Tensor):
            mask = torch.tensor(mask)

        if mask.dim() == 1:
            view_shape = [1] * len(shape)
            view_shape[-2] = mask.shape[0]
            # float32 keeps center and acceleration masks the same dtype
            # (e.g. RandomMaskFunc-style masks are boolean).
            mask = mask.to(torch.float32).reshape(view_shape)

        return mask.expand(tuple(shape))

    def reshape_mask_old(
        self, mask: np.ndarray, shape: Sequence[int]
    ) -> torch.Tensor:
        """
        Reshape a 1-D NumPy mask to broadcast against ``shape``.

        Unlike ``reshape_mask`` the result is not expanded: it has size 1 on
        every axis except the second-to-last, which holds ``shape[-2]``
        columns. The values are cast to float32.
        """
        num_cols = shape[-2]
        mask_shape = [1 for s in shape]
        mask_shape[-2] = num_cols

        return torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))

    def calculate_acceleration_mask(
        self,
        num_cols: int,
        acceleration: int,
        offset: Optional[int],
        num_low_frequencies: int,
    ) -> np.ndarray:
        """
        Produce mask for non-central acceleration lines.

        Subclasses must implement this.

        Args:
            num_cols: Number of columns of k-space (2D subsampling).
            acceleration: Desired acceleration rate.
            offset: Offset from 0 to begin masking (for equispaced masks).
            num_low_frequencies: Integer count of low-frequency lines sampled.

        Returns:
            A mask for the high spatial frequencies of k-space.
        """
        raise NotImplementedError

    def calculate_center_mask(
        self, shape: Sequence[int], num_low_freqs: int
    ) -> np.ndarray:
        """
        Build center mask based on number of low frequencies.

        Args:
            shape: Shape of k-space to mask.
            num_low_freqs: Number of low-frequency lines to sample.

        Returns:
            A float32 mask for the low spatial frequencies of k-space, with
            ``num_low_freqs`` consecutive ones centred in ``shape[-2]``
            columns.
        """
        num_cols = shape[-2]
        mask = np.zeros(num_cols, dtype=np.float32)
        pad = (num_cols - num_low_freqs + 1) // 2
        mask[pad : pad + num_low_freqs] = 1
        # Fails if num_low_freqs does not fit in num_cols.
        assert mask.sum() == num_low_freqs

        return mask

    def choose_acceleration(self):
        """
        Choose a (center_fraction, acceleration) pair using ``self.rng``.

        With ``allow_any_combination`` the two values are drawn independently;
        otherwise one shared index is drawn so the pairing is preserved.
        """
        if self.allow_any_combination:
            return self.rng.choice(self.center_fractions), self.rng.choice(
                self.accelerations
            )
        else:
            choice = self.rng.randint(len(self.center_fractions))
            return self.center_fractions[choice], self.accelerations[choice]
class RandomMaskFunc(MaskFunc):
    """
    Creates a random sub-sampling mask of a given shape.

    The mask selects a subset of columns from the input k-space data. If the
    k-space data has N columns, the mask picks out:
        1. N_low_freqs = (N * center_fraction) columns in the center
           corresponding to low-frequencies.
        2. The other columns are selected uniformly at random with a
        probability equal to: prob = (N / acceleration - N_low_freqs) /
        (N - N_low_freqs). This ensures that the expected number of columns
        selected is equal to (N / acceleration).

    It is possible to use multiple center_fractions and accelerations, in which
    case one possible (center_fraction, acceleration) is chosen uniformly at
    random each time the ``RandomMaskFunc`` object is called.

    For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04],
    then there is a 50% probability that 4-fold acceleration with 8% center
    fraction is selected and a 50% probability that 8-fold acceleration with 4%
    center fraction is selected.
    """

    def calculate_acceleration_mask(
        self,
        num_cols: int,
        acceleration: int,
        offset: Optional[int],
        num_low_frequencies: int,
    ) -> np.ndarray:
        """
        Select each column independently with a fixed probability.

        Args:
            num_cols: Number of columns of k-space (2D subsampling).
            acceleration: Desired acceleration rate.
            offset: Not used.
            num_low_frequencies: Number of center columns sampled separately;
                used so the expected total is ``num_cols / acceleration``.

        Returns:
            A float32 mask of length ``num_cols`` for the high spatial
            frequencies of k-space.
        """
        num_high_frequencies = num_cols - num_low_frequencies
        if num_high_frequencies > 0:
            prob = (num_cols / acceleration - num_low_frequencies) / (
                num_high_frequencies
            )
        else:
            # The center already covers every column: nothing left to sample
            # (previously a ZeroDivisionError).
            prob = 0.0

        # Always draw num_cols uniforms so the RNG stream is the same
        # regardless of which branch was taken.
        selected = self.rng.uniform(size=num_cols) < prob
        return selected.astype(np.float32)
class EquiSpacedMaskFunc(MaskFunc):
    """
    Sample data with equally-spaced k-space lines.

    Lines are spaced exactly evenly, as in standard GRAPPA-style acquisitions.
    Because the densely-sampled center is added on top of these lines, the
    nominal ``acceleration`` is greater than the true acceleration rate.
    """

    def calculate_acceleration_mask(
        self,
        num_cols: int,
        acceleration: int,
        offset: Optional[int],
        num_low_frequencies: int,
    ) -> np.ndarray:
        """
        Select every ``acceleration``-th column starting at ``offset``.

        Args:
            num_cols: Number of columns of k-space (2D subsampling).
            acceleration: Spacing between selected columns.
            offset: Index of the first selected column. When None, it is
                drawn uniformly from ``[0, round(acceleration))``.
            num_low_frequencies: Not used.

        Returns:
            A float32 array of length ``num_cols`` with 1 at sampled columns.
        """
        start = offset
        if start is None:
            # Random phase of the sampling grid within one period.
            start = self.rng.randint(0, high=round(acceleration))

        sampled = np.zeros(num_cols, dtype=np.float32)
        sampled[slice(start, None, acceleration)] = 1
        return sampled
class EquispacedMaskFractionFunc(MaskFunc):
    """
    Equispaced mask with approximate acceleration matching.

    The mask selects a subset of columns from the input k-space data. If the
    k-space data has N columns, the mask picks out:
        1. N_low_freqs = (N * center_fraction) columns in the center
           corresponding to low-frequencies.
        2. The other columns are selected with equal spacing at a proportion
           that reaches the desired acceleration rate taking into consideration
           the number of low frequencies. This ensures that the expected number
           of columns selected is equal to (N / acceleration)

    It is possible to use multiple center_fractions and accelerations, in which
    case one possible (center_fraction, acceleration) is chosen uniformly at
    random each time the EquispacedMaskFractionFunc object is called.

    Note that this function may not give equispaced samples (documented in
    https://github.com/facebookresearch/fastMRI/issues/54), which will require
    modifications to standard GRAPPA approaches. Nonetheless, this aspect of
    the function has been preserved to match the public multicoil data.
    """

    def calculate_acceleration_mask(
        self,
        num_cols: int,
        acceleration: int,
        offset: Optional[int],
        num_low_frequencies: int,
    ) -> np.ndarray:
        """
        Produce mask for non-central acceleration lines.

        Sample positions are spaced by a (generally fractional) adjusted
        acceleration and rounded to the nearest column index. That rounding
        is why the result is not always exactly equispaced.

        Args:
            num_cols: Number of columns of k-space (2D subsampling).
            acceleration: Desired acceleration rate.
            offset: Offset from 0 to begin masking. If no offset is specified,
                then one is selected randomly.
            num_low_frequencies: Number of low frequencies. Used to adjust mask
                so the overall sampling rate approximates the target
                acceleration.

        Returns:
            A float32 mask of length ``num_cols`` for the high spatial
            frequencies of k-space.
        """
        # determine acceleration rate by adjusting for the number of low frequencies
        # Solving  num_low + (num_cols - num_low) / adjusted_accel
        #          == num_cols / acceleration
        # for adjusted_accel gives the expression below.
        # NOTE(review): if num_low_frequencies * acceleration >= num_cols the
        # denominator is zero or the result is non-positive; this case does not
        # appear to be handled here. Confirm callers avoid it.
        adjusted_accel = (acceleration * (num_low_frequencies - num_cols)) / (
            num_low_frequencies * acceleration - num_cols
        )
        if offset is None:
            offset = self.rng.randint(0, high=round(adjusted_accel))

        mask = np.zeros(num_cols, dtype=np.float32)
        # Fractional positions strictly below num_cols - 1, rounded to integer
        # indices. They span the full width, so some may coincide with center
        # lines. Kept as-is to reproduce the public data.
        accel_samples = np.arange(offset, num_cols - 1, adjusted_accel)
        accel_samples = np.around(accel_samples).astype(np.uint)
        mask[accel_samples] = 1.0

        return mask
class MagicMaskFunc(MaskFunc):
    """
    Masking function for exploiting conjugate symmetry via offset-sampling.

    This function applies the mask described in the following paper:

    Defazio, A. (2019). Offset Sampling Improves Deep Learning based
    Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint,
    arXiv:1912.01101.

    It is essentially an equispaced mask with an offset for the opposite side
    of k-space. Since MRI images often exhibit approximate conjugate k-space
    symmetry, this mask is generally more efficient than a standard equispaced
    mask.

    Similarly to ``EquiSpacedMaskFunc``, this mask will usually undershoot the
    target acceleration rate.
    """

    def calculate_acceleration_mask(
        self,
        num_cols: int,
        acceleration: int,
        offset: Optional[int],
        num_low_frequencies: int,
    ) -> np.ndarray:
        """
        Produce mask for non-central acceleration lines.

        The columns are split into a "positive" part (the first
        ``(num_cols + 1) // 2`` entries) and a "negative" part (the rest).
        Each part is sampled every ``acceleration`` entries from its own start
        offset. The negative part is then reversed, and the concatenation is
        passed through ``np.fft.fftshift``.

        Args:
            num_cols: Number of columns of k-space (2D subsampling).
            acceleration: Desired acceleration rate. Used as a slice step (and
                as the exclusive upper bound of the random offset), so it must
                be a non-zero integer.
            offset: Offset from 0 to begin masking. If no offset is specified,
                then one is selected randomly.
            num_low_frequencies: Not used.

        Returns:
            A float32 mask of length ``num_cols`` for the high spatial
            frequencies of k-space.
        """
        if offset is None:
            offset = self.rng.randint(0, high=acceleration)

        # Start offsets for the two parts: even offset -> (+1, +2),
        # odd offset -> (+2, -1). Presumably this keeps the positive and
        # mirrored negative grids from landing on conjugate-symmetric
        # positions, per the paper above. TODO confirm.
        if offset % 2 == 0:
            offset_pos = offset + 1
            offset_neg = offset + 2
        else:
            offset_pos = offset - 1 + 3
            offset_neg = offset - 1 + 0

        poslen = (num_cols + 1) // 2
        neglen = num_cols - (num_cols + 1) // 2
        mask_positive = np.zeros(poslen, dtype=np.float32)
        mask_negative = np.zeros(neglen, dtype=np.float32)

        mask_positive[offset_pos::acceleration] = 1
        mask_negative[offset_neg::acceleration] = 1
        # Reverse so the negative part is counted from the end of the array.
        mask_negative = np.flip(mask_negative)

        mask = np.concatenate((mask_positive, mask_negative))

        return np.fft.fftshift(mask)  # shift mask and return
class MagicMaskFractionFunc(MagicMaskFunc):
    """
    Masking function for exploiting conjugate symmetry via offset-sampling.

    This function applies the mask described in the following paper:

    Defazio, A. (2019). Offset Sampling Improves Deep Learning based
    Accelerated MRI Reconstructions by Exploiting Symmetry. arXiv preprint,
    arXiv:1912.01101.

    It is essentially an equispaced mask with an offset for the opposite side
    of k-space. Since MRI images often exhibit approximate conjugate k-space
    symmetry, this mask is generally more efficient than a standard equispaced
    mask.

    Similarly to ``EquispacedMaskFractionFunc``, this method approximately
    matches the target acceleration by adjusting the spacing of the outer
    lines for the number of center lines.
    """

    def sample_mask(
        self,
        shape: Sequence[int],
        offset: Optional[int],
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Sample a new k-space mask.

        This function samples and returns two components of a k-space mask: 1)
        the center mask (e.g., for sensitivity map calculation) and 2) the
        acceleration mask (for the edge of k-space). Both of these masks, as
        well as the integer of low frequency samples, are returned.

        If the center lines alone already reach the target column count, the
        acceleration mask is all zeros. Previously this case passed an
        acceleration of 0 to ``calculate_acceleration_mask``, which raised.

        Args:
            shape: Shape of the k-space to subsample.
            offset: Offset from 0 to begin mask (for equispaced masks).

        Returns:
            A 3-tuple containing 1) the mask for the center of k-space, 2) the
            mask for the high frequencies of k-space, and 3) the integer count
            of low frequency samples.
        """
        num_cols = shape[-2]
        fraction_low_freqs, acceleration = self.choose_acceleration()
        num_low_frequencies = round(num_cols * fraction_low_freqs)

        # bound the number of low frequencies between 1 and target columns
        target_columns_to_sample = round(num_cols / acceleration)
        num_low_frequencies = max(
            min(num_low_frequencies, target_columns_to_sample), 1
        )

        # adjust acceleration rate based on target acceleration.
        adjusted_target_columns_to_sample = (
            target_columns_to_sample - num_low_frequencies
        )

        center_mask = self.reshape_mask(
            self.calculate_center_mask(shape, num_low_frequencies), shape
        )

        if adjusted_target_columns_to_sample > 0:
            adjusted_acceleration = round(
                num_cols / adjusted_target_columns_to_sample
            )
            accel_mask_1d = self.calculate_acceleration_mask(
                num_cols, adjusted_acceleration, offset, num_low_frequencies
            )
        else:
            # The center consumes the whole sampling budget; an acceleration
            # of 0 would be used as a slice step / randint bound and fail.
            accel_mask_1d = np.zeros(num_cols, dtype=np.float32)

        accel_mask = self.reshape_mask(accel_mask_1d, shape)

        return center_mask, accel_mask, num_low_frequencies
class Gaussian2DMaskFunc(MaskFunc):
    """
    Gaussian 2D masking.

    Samples individual k-space locations from an isotropic 2D normal
    distribution centred at ``(shape[-3] // 2, shape[-2] // 2)``, discarding
    draws that fall outside the grid. Unlike the line-based masks, this
    produces a full 2D mask.
    """

    def __init__(
        self,
        accelerations: Sequence[int],
        stds: Sequence[float],
        seed: Optional[int] = None,
    ):
        """
        Initialize a Gaussian 2D mask.

        Args:
            accelerations: Acceleration factors; one is chosen uniformly at
                random on each call.
            stds: Candidate ``scale`` values for the
                ``torch.distributions.Normal`` sampler (in grid points); one
                is chosen uniformly at random on each call.
            seed: Seed for the generator that selects the acceleration and the
                std. Defaults to None.
        """
        # MaskFunc.__init__ is intentionally not called: this mask has no
        # center fractions.
        self.rng = np.random.RandomState(seed)
        self.accelerations = accelerations
        self.stds = stds

    def __call__(
        self,
        shape: Sequence[int],
        offset: Optional[int] = None,
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Sample a Gaussian 2D mask.

        Args:
            shape: Shape of k-space; the mask covers ``shape[-3] x shape[-2]``.
            offset: Not used; accepted for interface compatibility.
            seed: Not used; the point locations are drawn with
                ``Normal.sample``, not with ``self.rng``.

        Returns:
            A 2-tuple of 1) the mask with shape ``(1, shape[-3], shape[-2], 2)``
            and 2) the std that was used. The std is returned in place of a
            low-frequency count because returning None cannot be stacked
            across batches.

        Raises:
            ValueError: If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        acceleration = self.rng.choice(self.accelerations)
        std = self.rng.choice(self.stds)
        x, y = shape[-3], shape[-2]

        dist = D.Normal(
            loc=torch.tensor([x // 2, y // 2], dtype=torch.float32),
            scale=std,
        )
        # Points are drawn with replacement, so duplicates are expected; the
        # extra constant is needed to get close to the desired sampling rate.
        num_points = int(1 / acceleration * x * y) + 10000

        # torch.long index tensors work with advanced indexing on all PyTorch
        # versions (older releases reject int32 index tensors).
        sample_x = torch.zeros(num_points, dtype=torch.long)
        sample_y = torch.zeros(num_points, dtype=torch.long)

        collected = 0
        while collected < num_points:
            samples = dist.sample((num_points,))  # type: ignore
            in_bounds = (
                (samples[:, 0] >= 0)
                & (samples[:, 0] < x)
                & (samples[:, 1] >= 0)
                & (samples[:, 1] < y)
            )
            valid_x = samples[in_bounds, 0].long()
            valid_y = samples[in_bounds, 1].long()

            take = min(num_points - collected, valid_x.size(0))
            sample_x[collected : collected + take] = valid_x[:take]
            sample_y[collected : collected + take] = valid_y[:take]
            collected += take

        mask = torch.zeros((x, y))
        mask[sample_x, sample_y] = 1.0

        # broadcasting mask (x, y) --> (1, x, y, 2)
        mask = mask.unsqueeze(-1).unsqueeze(0)  # (1, x, y, 1)
        mask = mask.expand((1, x, y, 2)).clone()

        return mask, std
class Poisson2DMaskFunc(MaskFunc):
    """
    Variable Density Poisson Disk Sampling.

    https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.poisson.html#sigpy.mri.poisson

    With ``use_cache`` (the default), precomputed masks named
    ``poisson_{acc}x.npy`` are loaded once per acceleration from a cache
    directory and returned as-is. Otherwise a mask is generated on each call
    with ``sigpy.mri.poisson``.
    """

    # Directory the cached masks were originally loaded from. Kept as the
    # fallback so existing setups keep working unchanged.
    DEFAULT_CACHE_DIR = (
        "/global/homes/p/peterwg/more/medical-imaging/fastmri/poisson_cache"
    )
    # Environment variable that overrides DEFAULT_CACHE_DIR when set.
    CACHE_DIR_ENV_VAR = "FASTMRI_POISSON_CACHE_DIR"

    def __init__(
        self,
        accelerations: Sequence[int],
        stds: Optional[Sequence[float]] = None,
        seed: Optional[int] = None,
        use_cache: bool = True,
        cache_dir: Optional[str] = None,
    ):
        """
        Initialize a VDPD (Poisson) mask.

        Args:
            accelerations: Acceleration factors to choose from.
            stds: Dummy param kept for interface compatibility. Do not pass a
                value. Defaults to None.
            seed: Seed for selecting mask params. Defaults to None.
            use_cache: Load precomputed masks instead of generating them.
            cache_dir: Directory holding ``poisson_{acc}x.npy`` files. If None,
                the ``FASTMRI_POISSON_CACHE_DIR`` environment variable is used,
                falling back to ``DEFAULT_CACHE_DIR``.

        Raises:
            FileNotFoundError: If ``use_cache`` is set and a cache file for one
                of the accelerations is missing.
        """
        self.rng = np.random.RandomState(seed)
        self.accelerations = accelerations
        self.use_cache = use_cache

        if use_cache:
            if cache_dir is None:
                cache_dir = os.environ.get(
                    self.CACHE_DIR_ENV_VAR, self.DEFAULT_CACHE_DIR
                )
            self.cache: Dict[int, np.ndarray] = {}
            for acc in accelerations:
                path = os.path.join(cache_dir, f"poisson_{acc}x.npy")
                if not os.path.isfile(path):
                    raise FileNotFoundError(
                        f"Poisson mask cache file not found: {path}. Pass "
                        f"cache_dir=..., set {self.CACHE_DIR_ENV_VAR}, or use "
                        "use_cache=False."
                    )
                self.cache[acc] = np.load(path)

    def __call__(
        self,
        shape: Sequence[int],
        offset: Optional[int] = None,
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Return a Poisson-disk mask.

        Args:
            shape: Shape of k-space. Ignored when a cached mask is returned.
            offset: Not used.
            seed: Not used.

        Returns:
            A 2-tuple of 1) the mask and 2) a placeholder value in place of
            a low-frequency count (1.0 for cached masks, 100.0 otherwise).
            Generated masks have shape ``(1, shape[-3], shape[-2], 2)``.

        Raises:
            ValueError: If no cache is used and ``shape`` has fewer than
                3 dimensions.
        """
        if self.use_cache:
            acceleration = self.rng.choice(self.accelerations)
            # Clone so in-place edits by the caller cannot corrupt the cache.
            cached = torch.from_numpy(self.cache[acceleration]).clone()
            return cached, 1.0  # type: ignore

        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        acceleration = self.rng.choice(self.accelerations)
        x, y = shape[-3], shape[-2]
        mask = poisson(img_shape=(x, y), accel=acceleration, dtype=np.float32)
        mask = torch.from_numpy(mask)

        # broadcasting mask (x, y) --> (1, x, y, 2)
        mask = mask.unsqueeze(-1)  # (x, y, 1)
        mask = mask.unsqueeze(0)  # (1, x, y, 1)
        mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()

        # num low freqs doesn't make sense here, so return an arbitrary value
        return mask, 100.0
class Radial2DMaskFunc(MaskFunc):
    """
    Radial trajectory MRI masking method.

    https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.radial.html#sigpy.mri.radial

    A golden-angle radial trajectory from ``sigpy.mri.radial`` is rasterised
    onto the ``shape[-3] x shape[-2]`` grid. The total number of trajectory
    points is about ``x * y / acceleration``.
    """

    def __init__(
        self,
        accelerations: Sequence[int],
        arms: Optional[Sequence[int]],
        seed: Optional[int] = None,
    ):
        """
        Initialize a radial mask.

        Args:
            accelerations: Acceleration factors to choose from.
            arms: Candidate numbers of radial arms. If empty or None, the arm
                count is derived from the image size on each call.
            seed: Seed for selecting mask params. Defaults to None.
        """
        self.rng = np.random.RandomState(seed)
        self.accelerations = accelerations
        self.arms = arms

    def __call__(
        self,
        shape: Sequence[int],
        offset: Optional[int] = None,
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Sample a radial mask.

        Args:
            shape: Shape of k-space; the mask covers ``shape[-3] x shape[-2]``.
            offset: Not used.
            seed: Not used.

        Returns:
            A 2-tuple of 1) the mask with shape ``(1, shape[-3], shape[-2], 2)``
            and 2) the placeholder value 100.0 in place of a low-frequency
            count.

        Raises:
            ValueError: If ``shape`` has fewer than 3 dimensions.
        """
        if len(shape) < 3:
            raise ValueError("Shape should have 3 or more dimensions")

        acceleration = self.rng.choice(self.accelerations)
        x, y = shape[-3], shape[-2]
        npoints = int(x * y * (1 / acceleration))

        if self.arms:
            arms = self.rng.choice(self.arms)
        else:
            # Default to arms of about x // 3 points each. At least 1 point
            # avoids dividing by zero for tiny images.
            points_per_arm = max(x // 3, 1)
            arms = npoints // points_per_arm
        # At least one arm and one point per arm; zero values previously
        # caused ZeroDivisionError here or inside sigpy.
        arms = max(arms, 1)

        # calculate radial parameters to satisfy acceleration factor
        ntr = arms  # num radial lines
        nro = max(npoints // arms, 1)  # num points on each radial line
        ndim = 2  # 2D

        # gen trajectory w/ shape (ntr, nro, ndim)
        traj = radial(
            coord_shape=[ntr, nro, ndim],
            img_shape=(x, y),
            golden=True,
            dtype=int,
        )

        # Trajectory coordinates are centred on 0; shift to array indices.
        mask = torch.zeros(x, y, dtype=torch.float32)
        x_coords = traj[..., 0].flatten() + (x // 2)
        y_coords = traj[..., 1].flatten() + (y // 2)
        mask[x_coords, y_coords] = 1.0

        # broadcasting mask (x, y) --> (1, x, y, 2)
        mask = mask.unsqueeze(-1)  # (x, y, 1)
        mask = mask.unsqueeze(0)  # (1, x, y, 1)
        mask = mask.expand((1, mask.shape[1], mask.shape[2], 2)).clone()

        # num low freqs doesn't make sense here, so return an arbitrary value
        return mask, 100.0
class Spiral2DMaskFunc(MaskFunc):
    """
    Spiral trajectory MRI masking method (not implemented yet).

    https://sigpy.readthedocs.io/en/latest/generated/sigpy.mri.spiral.html#sigpy.mri.spiral
    """

    def __init__(
        self,
        accelerations: Sequence[int],
        arms: Sequence[int],
        seed: Optional[int] = None,
    ):
        """
        Initialize a spiral mask.

        Args:
            accelerations: Acceleration factors to choose from.
            arms: Candidate numbers of spiral arms (interleaves).
            seed: Seed for selecting mask params. Defaults to None.
        """
        self.rng = np.random.RandomState(seed)
        self.accelerations = accelerations
        self.arms = arms

    def __call__(
        self,
        shape: Sequence[int],
        offset: Optional[int] = None,
        seed: Optional[Union[int, Tuple[int, ...]]] = None,
    ) -> Tuple[torch.Tensor, float]:
        """
        Sample a spiral mask. Always raises ``NotImplementedError``.

        TODO: implement. An earlier unreachable draft followed this raise. It
        called ``sigpy.mri.spiral`` with ``N=``/``img_shape=``/``golden=``
        keywords, which appear not to match that function's signature
        (NOTE(review): confirm against the sigpy docs), so it was removed
        rather than kept as dead code.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Spiral2D not implemented")
def create_mask_for_mask_type(
    mask_type_str: str,
    center_fractions: Optional[Sequence],
    accelerations: Sequence[int],
) -> MaskFunc:
    """
    Build a mask function from its string name.

    Args:
        mask_type_str: One of ``"random"``, ``"equispaced"``,
            ``"equispaced_fraction"``, ``"magic"``, ``"magic_fraction"``,
            ``"gaussian_2d"``, ``"poisson_2d"``, ``"radial_2d"`` or
            ``"spiral_2d"``.
        center_fractions: For line masks, the fraction of the k-space center
            to include. Reused as the std choices for ``"gaussian_2d"`` and,
            cast to int, as the arm counts for ``"radial_2d"``.
        accelerations: What accelerations to apply.

    Returns:
        A mask func for the target mask type.

    Raises:
        NotImplementedError: For ``"spiral_2d"``.
        ValueError: For any unknown ``mask_type_str``.
    """
    if mask_type_str == "spiral_2d":
        raise NotImplementedError("spiral_2d not implemented")

    # Line masks all share the (center_fractions, accelerations) constructor.
    line_mask_classes = {
        "random": RandomMaskFunc,
        "equispaced": EquiSpacedMaskFunc,
        "equispaced_fraction": EquispacedMaskFractionFunc,
        "magic": MagicMaskFunc,
        "magic_fraction": MagicMaskFractionFunc,
    }
    if mask_type_str in line_mask_classes:
        return line_mask_classes[mask_type_str](center_fractions, accelerations)

    if mask_type_str == "gaussian_2d":
        return Gaussian2DMaskFunc(
            stds=center_fractions,
            accelerations=accelerations,
        )

    if mask_type_str == "poisson_2d":
        return Poisson2DMaskFunc(accelerations=accelerations, stds=None)

    if mask_type_str == "radial_2d":
        arm_counts = (
            [int(arm) for arm in center_fractions] if center_fractions else None
        )
        return Radial2DMaskFunc(accelerations=accelerations, arms=arm_counts)

    raise ValueError(f"{mask_type_str} not supported")