import random
import xml.etree.ElementTree as etree
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import h5py
import lmdb
import numpy as np
import pandas as pd
import sigpy as sp
import torch
import yaml

import fastmri
import fastmri.transforms as T

class RawSample(NamedTuple):
    fname: Path
    slice_num: int
    metadata: Dict[str, Any]


class SliceSample(NamedTuple):
    masked_kspace: torch.Tensor
    mask: torch.Tensor
    num_low_frequencies: int
    target: torch.Tensor
    max_value: float
    # attrs: Dict[str, Any]
    fname: str
    slice_num: int


class SliceSampleMVUE(NamedTuple):
    masked_kspace: torch.Tensor
    mask: torch.Tensor
    num_low_frequencies: int
    target: torch.Tensor
    rss: torch.Tensor
    max_value: float
    # attrs: Dict[str, Any]
    fname: str
    slice_num: int
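
# Rough shape conventions used by the datasets below (descriptive, not enforced here):
# masked_kspace is real-valued (coils, H, W, 2) when `complex` is False, mask is a
# boolean tensor produced by apply_mask (or all-ones when no mask_fns are given), and
# target / rss are real-valued (H, W) magnitude images.
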
def et_query(
    root: etree.Element,
    qlist: Sequence[str],
    namespace: str = "http://www.ismrm.org/ISMRMRD",
) -> str:
    """
    Query an XML document using ElementTree.

    This function allows querying an XML document by specifying a root and a
    list of nested queries. It supports optional XML namespaces.

    Parameters
    ----------
    root : ElementTree.Element
        The root element of the XML to search through.
    qlist : list of str
        A list of strings for nested searches, e.g., ["Encoding", "matrixSize"].
    namespace : str, optional
        The XML namespace to prepend to the query (default is the ISMRMRD
        namespace).

    Returns
    -------
    str
        The retrieved data as a string.
    """
    s = "."
    prefix = "ismrmrd_namespace"
    ns = {prefix: namespace}
    for el in qlist:
        s = s + f"//{prefix}:{el}"
    value = root.find(s, ns)
    if value is None:
        raise RuntimeError("Element not found")
    return str(value.text)
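
# Usage sketch for et_query: pull the encoded matrix size out of a fastMRI volume's
# ISMRMRD header, mirroring _retrieve_metadata below. The file path is a placeholder.
#
#     with h5py.File("some_volume.h5", "r") as hf:
#         et_root = etree.fromstring(hf["ismrmrd_header"][()])
#         nx = int(et_query(et_root, ["encoding", "encodedSpace", "matrixSize", "x"]))
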
class SliceDataset(torch.utils.data.Dataset):
    """
    A simplified PyTorch Dataset that provides access to multicoil MR image
    slices from the fastMRI dataset.
    """

    def __init__(
        self,
        # root: Optional[Path | str],
        body_part: Literal["knee", "brain"],
        partition: Literal["train", "val", "test"],
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        complex: bool = False,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        contrast: Optional[Literal["T1", "T2"]] = None,
        coils: Optional[int] = None,
    ):
        """
        Initializes the fastMRI multi-coil challenge dataset.

        Samples are individual 2D slices taken from k-space volume data.

        Parameters
        ----------
        body_part : {'knee', 'brain'}
            The body part to analyze.
        partition : {'train', 'val', 'test'}
            The data partition type.
        mask_fns : list of callable, optional
            A list of masking functions to apply to samples.
            If multiple are given, a mask is randomly chosen for each sample.
        sample_rate : float, optional
            Fraction of data to sample, by default 1.0.
        complex : bool, optional
            Whether to return the k-space data as complex-valued tensors,
            by default False. If True, k-space values are complex. If False,
            k-space values are real with a trailing dimension of size 2.
        crop_shape : tuple of two ints, optional
            The shape to center crop the k-space data, by default (320, 320).
        slug : str, optional
            Dataset slug name.
        contrast : {'T1', 'T2'}, optional
            If body_part is 'brain', the contrast of images to use.
        coils : int, optional
            If given, only the first `coils` coils are kept; samples from
            volumes with fewer coils return None.
        """
        with open("fastmri.yaml", "r") as file:
            config = yaml.safe_load(file)

        self.contrast = contrast
        self.slug = slug
        self.partition = partition
        self.body_part = body_part
        self.root = (
            Path(config.get(f"{body_part}_path")) / f"multicoil_{partition}"
        )
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        self.raw_samples: List[RawSample] = self._load_samples()
        self.complex = complex
        self.crop_shape = crop_shape
        self.coils = coils

    def _load_samples(self):
        # Gather all files in the root directory
        if self.body_part == "brain" and self.contrast:
            files = list(self.root.glob(f"*{self.contrast}*.h5"))
        else:
            files = list(self.root.glob("*.h5"))

        raw_samples = []
        # Load and process metadata from each file
        for fname in sorted(files):
            with h5py.File(fname, "r") as hf:
                metadata, num_slices = self._retrieve_metadata(fname)

                # Collect samples for each slice, discarding the first c and last c slices
                c = 6
                for slice_num in range(num_slices):
                    if c <= slice_num <= num_slices - c - 1:
                        raw_samples.append(
                            RawSample(fname, slice_num, metadata)
                        )

        # Subsample if desired
        if self.sample_rate < 1.0:
            raw_samples = random.sample(
                raw_samples, int(len(raw_samples) * self.sample_rate)
            )

        return raw_samples

    def _retrieve_metadata(self, fname):
        with h5py.File(fname, "r") as hf:
            et_root = etree.fromstring(hf["ismrmrd_header"][()])

            enc = ["encoding", "encodedSpace", "matrixSize"]
            enc_size = (
                int(et_query(et_root, enc + ["x"])),
                int(et_query(et_root, enc + ["y"])),
                int(et_query(et_root, enc + ["z"])),
            )
            rec = ["encoding", "reconSpace", "matrixSize"]
            recon_size = (
                int(et_query(et_root, rec + ["x"])),
                int(et_query(et_root, rec + ["y"])),
                int(et_query(et_root, rec + ["z"])),
            )

            lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
            enc_limits_center = int(et_query(et_root, lims + ["center"]))
            enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1

            padding_left = enc_size[1] // 2 - enc_limits_center
            padding_right = padding_left + enc_limits_max

            num_slices = hf["kspace"].shape[0]

            metadata = {
                "padding_left": padding_left,
                "padding_right": padding_right,
                "encoding_size": enc_size,
                "recon_size": recon_size,
                **hf.attrs,
            }

        return metadata, num_slices
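
    # Note: the metadata dict merges the derived padding / size fields with the HDF5
    # file attributes; __getitem__ below relies on the fastMRI "max" attribute being
    # included via **hf.attrs.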

    def __len__(self):
        return len(self.raw_samples)

    def __getitem__(self, idx) -> SliceSample:
        try:
            raw_sample: RawSample = self.raw_samples[idx]
            fname, slice_num, metadata = raw_sample

            # load kspace and target
            with h5py.File(fname, "r") as hf:
                kspace = torch.tensor(hf["kspace"][()][slice_num])
                if not self.complex:
                    kspace = torch.view_as_real(kspace)

                if self.coils:
                    if kspace.shape[0] < self.coils:
                        return None
                    kspace = kspace[: self.coils, :, :, :]

                target_key = (
                    "reconstruction_rss"
                    if self.partition in ["train", "val"]
                    else "reconstruction_esc"
                )
                target = hf.get(target_key, None)
                if target is not None:
                    target = torch.tensor(target[()][slice_num])
                    if self.body_part == "brain":
                        target = T.center_crop(target, self.crop_shape)

            # center crop to enable collating for batching
            if self.complex:
                # if complex, crop across dims: -2 and -1 (last 2)
                raise NotImplementedError("Not implemented for complex native")
            else:
                # crop in image space, to not lose high-frequency information
                image = fastmri.ifft2c(kspace)
                image_cropped = T.complex_center_crop(image, self.crop_shape)
                kspace = fastmri.fft2c(image_cropped)

            # apply transform mask if there is one
            if self.mask_fns:
                # choose a random mask
                mask_fn = random.choice(self.mask_fns)
                kspace, mask, num_low_frequencies = T.apply_mask(
                    kspace,
                    mask_fn,
                    # seed=seed,
                )
                mask = mask.bool()
            else:
                mask = torch.ones_like(kspace, dtype=torch.bool)
                num_low_frequencies = 0

            sample = SliceSample(
                kspace,
                mask,
                num_low_frequencies,
                target,
                metadata["max"],
                fname.name,
                slice_num,
            )
            return sample
        except Exception:
            # skip unreadable or malformed samples
            return None
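
# Usage sketch: because __getitem__ above may return None (coil-count mismatch or a
# failed read), a DataLoader around SliceDataset needs a collate_fn that drops those
# entries. The batch size and sample_rate below are illustrative assumptions, and this
# assumes every kept sample has a target (train/val partitions).
#
#     from torch.utils.data import DataLoader
#     from torch.utils.data.dataloader import default_collate
#
#     def collate_skip_none(batch):
#         batch = [b for b in batch if b is not None]
#         return default_collate(batch) if batch else None
#
#     train_ds = SliceDataset("knee", "train", mask_fns=None, sample_rate=0.1)
#     loader = DataLoader(train_ds, batch_size=2, collate_fn=collate_skip_none)
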

class SliceDatasetLMDB(torch.utils.data.Dataset):
    """
    A simplified PyTorch Dataset that provides access to multicoil MR image
    slices from the fastMRI dataset. Loads samples saved in LMDB.
    """

    def __init__(
        self,
        body_part: Literal["knee", "brain"],
        partition: Literal["train", "val", "test"],
        root: Optional[Path | str] = None,
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        complex: bool = False,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        coils: int = 15,
    ):
        """
        Initializes the fastMRI multi-coil challenge dataset.

        Samples are individual 2D slices taken from k-space volume data.

        Parameters
        ----------
        body_part : {'knee', 'brain'}
            The body part to analyze.
        partition : {'train', 'val', 'test'}
            The data partition type.
        root : Path or str, optional
            Root of the LMDB dataset. If not provided, the root is loaded
            from the fastmri.yaml config.
        mask_fns : list of callable, optional
            A list of masking functions to apply to samples.
            If multiple are given, a mask is randomly chosen for each sample.
        sample_rate : float, optional
            Fraction of data to sample, by default 1.0.
        complex : bool, optional
            Whether to return the k-space data as complex-valued tensors,
            by default False. If True, k-space values are complex. If False,
            k-space values are real with a trailing dimension of size 2.
        crop_shape : tuple of two ints, optional
            The shape to center crop the k-space data, by default (320, 320).
        slug : str, optional
            Dataset slug name.
        coils : int, optional
            Number of coils in each stored k-space sample, by default 15.
        """
        # set attrs
        self.coils = coils
        self.slug = slug
        self.partition = partition
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        self.complex = complex
        self.crop_shape = crop_shape

        # load lmdb info
        if root:
            if isinstance(root, str):
                root = Path(root)
            assert root.exists(), "Provided root doesn't exist."
            self.root = root
        else:
            with open("fastmri.yaml", "r") as file:
                config = yaml.safe_load(file)
            self.root = Path(config["lmdb"][f"{body_part}_{partition}_path"])

        self.meta = np.load(self.root / "meta.npy")
        self.kspace_env = lmdb.open(
            str(self.root / "kspace"),
            readonly=True,
            lock=False,
            create=False,
        )
        self.kspace_txn = self.kspace_env.begin(write=False)
        self.rss_env = lmdb.open(
            str(self.root / "rss"),
            readonly=True,
            lock=False,
            create=False,
        )
        self.rss_txn = self.rss_env.begin(write=False)
        self.length = self.kspace_txn.stat()["entries"]

    def __len__(self):
        return int(self.sample_rate * self.length)

    def __getitem__(self, idx) -> SliceSample:
        idx_key = str(idx).encode("utf-8")

        # load sample data
        kspace = torch.from_numpy(
            np.frombuffer(self.kspace_txn.get(idx_key), dtype=np.float32)
            .reshape(self.coils, 320, 320, 2)
            .copy()
        )
        rss = torch.from_numpy(
            np.frombuffer(self.rss_txn.get(idx_key), dtype=np.float32)
            .reshape(320, 320)
            .copy()
        )

        # crop in image space, to not lose high-frequency information
        if self.crop_shape and self.crop_shape != (320, 320):
            image = fastmri.ifft2c(kspace)
            image_cropped = T.complex_center_crop(image, self.crop_shape)
            kspace = fastmri.fft2c(image_cropped)
            rss = T.center_crop(rss, self.crop_shape)

        # load and apply mask
        if self.mask_fns:
            # choose a random mask
            mask_fn = random.choice(self.mask_fns)
            kspace, mask, num_low_frequencies = T.apply_mask(
                kspace,
                mask_fn,  # type: ignore
            )
            mask = mask.bool()
        else:
            mask = torch.ones_like(kspace, dtype=torch.bool)
            num_low_frequencies = 0

        # load metadata
        fname, slice_num, max_value = self.meta[idx]
        fname = str(fname)
        slice_num = int(slice_num)
        max_value = float(max_value)

        return SliceSample(
            kspace,
            mask,
            num_low_frequencies,
            rss,
            max_value,
            fname,
            slice_num,
        )
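
# Write-side sketch (not the actual export script): the reader above assumes each
# kspace entry is raw float32 bytes of shape (coils, 320, 320, 2) stored under the
# string-encoded sample index, and each rss entry is float32 bytes of shape (320, 320).
# The map_size and path below are placeholders.
#
#     env = lmdb.open("/path/to/lmdb/kspace", map_size=2**40)
#     with env.begin(write=True) as txn:
#         txn.put(str(i).encode("utf-8"), kspace_i.astype(np.float32).tobytes())
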

class SliceDatasetLMDB_MVUE(torch.utils.data.Dataset):
    """
    Loads brain samples saved in LMDB.
    Modified to have MVUE targets.
    """

    def __init__(
        self,
        root: Path | str,
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        coils: int = 15,
    ):
        # set attrs
        self.coils = coils
        self.slug = slug
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        self.complex = False
        self.crop_shape = crop_shape

        # load lmdb info
        if isinstance(root, str):
            root = Path(root)
        assert root.exists(), "Provided root doesn't exist."
        self.root = root

        self.meta = np.load(self.root / "meta.npy")
        self.mapping = pd.read_csv("brain_mvue_map.csv")
        self.kspace_env = lmdb.open(
            str(self.root / "kspace"),
            readonly=True,
            lock=False,
            create=False,
        )
        self.kspace_txn = self.kspace_env.begin(write=False)
        self.rss_env = lmdb.open(
            str(self.root / "rss"),
            readonly=True,
            lock=False,
            create=False,
        )
        self.rss_txn = self.rss_env.begin(write=False)
        # ray mvue dataset
        self.mvue_env = lmdb.open(
            "/pscratch/sd/p/peterwg/datasets/raytemp",
            readonly=True,
            lock=False,
            create=False,
        )
        self.mvue_txn = self.mvue_env.begin(write=False)
        self.length = len(self.mapping)
        # self.length = self.kspace_txn.stat()["entries"]

    def __len__(self):
        return int(self.sample_rate * self.length)

    def __getitem__(self, idx) -> SliceSampleMVUE:
        # ray's index: 0-n
        ray_idx = idx
        # my index: lookup(ray index)
        idx = int(self.mapping.iloc[ray_idx].my_index)
        ray_idx_key = str(ray_idx).encode("utf-8")
        idx_key = str(idx).encode("utf-8")

        # load sample data
        kspace = torch.from_numpy(
            np.frombuffer(self.kspace_txn.get(idx_key), dtype=np.float32)
            .reshape(self.coils, 320, 320, 2)
            .copy()
        )
        # mvue_target = np.sum(
        #     sp.ifft(kspace, axes=(-1, -2)) * np.conj(s_maps), axis=1
        # ) / np.sqrt(np.sum(np.square(np.abs(s_maps)), axis=1))
        rss = torch.from_numpy(
            np.frombuffer(self.rss_txn.get(idx_key), dtype=np.float32)
            .reshape(320, 320)
            .copy()
        )
        # load mvue from ray dataset
        mvue = torch.from_numpy(
            np.frombuffer(self.mvue_txn.get(ray_idx_key), dtype=np.complex64)
            .reshape(320, 320)
            .copy()
        )
        mvue = torch.abs(mvue)

        # crop in image space, to not lose high-frequency information
        if self.crop_shape and self.crop_shape != (320, 320):
            image = fastmri.ifft2c(kspace)
            image_cropped = T.complex_center_crop(image, self.crop_shape)
            kspace = fastmri.fft2c(image_cropped)
            rss = T.center_crop(rss, self.crop_shape)

        # load and apply mask
        if self.mask_fns:
            # choose a random mask
            mask_fn = random.choice(self.mask_fns)
            kspace, mask, num_low_frequencies = T.apply_mask(
                kspace,
                mask_fn,  # type: ignore
            )
            mask = mask.bool()
        else:
            mask = torch.ones_like(kspace, dtype=torch.bool)
            num_low_frequencies = 0

        # load metadata
        fname, slice_num, max_value = self.meta[idx]
        fname = str(fname)
        slice_num = int(slice_num)
        max_value = float(max_value)

        return SliceSampleMVUE(
            kspace,
            mask,
            num_low_frequencies,
            mvue,
            rss,
            max_value,
            fname,
            slice_num,
        )
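
# MVUE sketch, expanding the commented-out formula inside __getitem__ above: given
# complex multicoil k-space of shape (coils, H, W) and precomputed sensitivity maps
# s_maps of the same shape (their estimation is not shown here), the minimum-variance
# unbiased estimate is the sensitivity-weighted coil combination. The reduction is over
# axis 0 here because there is no batch dimension, unlike the commented formula.
#
#     def compute_mvue(kspace: np.ndarray, s_maps: np.ndarray) -> np.ndarray:
#         coil_images = sp.ifft(kspace, axes=(-2, -1))
#         num = np.sum(coil_images * np.conj(s_maps), axis=0)
#         den = np.sqrt(np.sum(np.abs(s_maps) ** 2, axis=0))
#         return num / den
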

# d = SliceDatasetLMDB("knee", "val", None, 1, True, (320, 320), "testdataset")
# print(len(d))
# breakpoint()

# ds = SuperSliceDatasetLMDB(
#     "brain",  # body_part
#     "val",  # partition
#     None,  # root
#     None,  # mask_fns
#     1.0,  # sample_rate
#     True,  # complex
#     (320, 320),  # crop_shape
#     "test-superres",  # slug
#     coils=16,  # coils
# )
# breakpoint()

# d = SliceDataset("brain", "train", None, contrast="T2")

# # TESTING MVUE
# d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/mri_brain_train_lmdb", coils=16)
# x = d[0]
# d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/raytemp/", coils=16)