nomri / fastmri /datasets.py
samaonline
init
1b34a12
import logging
import random
import xml.etree.ElementTree as etree
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import h5py
import lmdb
import numpy as np
import pandas as pd
import sigpy as sp
import torch
import yaml

import fastmri
import fastmri.transforms as T
class RawSample(NamedTuple):
    """
    Pointer to a single slice inside one HDF5 volume.

    Carries what is needed to load the slice lazily later on: the file,
    the slice index within that file, and the per-volume metadata.
    """

    # HDF5 file holding the volume
    fname: Path
    # index of the slice within the volume
    slice_num: int
    # per-volume metadata dictionary
    metadata: Dict[str, Any]
class SliceSample(NamedTuple):
    """
    One dataset item: a (possibly masked) k-space slice plus its target.

    Field order matters because instances are built positionally.
    """

    # k-space after the sampling mask (if any) has been applied
    masked_kspace: torch.Tensor
    # boolean sampling mask
    mask: torch.Tensor
    # value reported by the mask function; 0 when no mask is applied
    num_low_frequencies: int
    # reconstruction target for this slice
    target: torch.Tensor
    # "max" value taken from the volume metadata
    max_value: float
    # attrs: Dict[str, Any]  (field currently disabled)
    # name of the source file
    fname: str
    # index of the slice within its volume
    slice_num: int
class SliceSampleMVUE(NamedTuple):
    """
    One dataset item whose target is an MVUE image, with the RSS image
    carried alongside it.

    Field order matters because instances are built positionally.
    """

    # k-space after the sampling mask (if any) has been applied
    masked_kspace: torch.Tensor
    # boolean sampling mask
    mask: torch.Tensor
    # value reported by the mask function; 0 when no mask is applied
    num_low_frequencies: int
    # reconstruction target for this slice
    target: torch.Tensor
    # RSS image for the same slice
    rss: torch.Tensor
    # "max" value taken from the stored metadata
    max_value: float
    # attrs: Dict[str, Any]  (field currently disabled)
    # name of the source file
    fname: str
    # index of the slice within its volume
    slice_num: int
def et_query(
    root: etree.Element,
    qlist: Sequence[str],
    namespace: str = "http://www.ismrm.org/ISMRMRD",
) -> str:
    """
    Query an XML document using ElementTree.

    Builds a namespaced ElementTree path from a list of nested element
    names and returns the text of the first match.

    Parameters
    ----------
    root : ElementTree.Element
        The root element of the XML to search through.
    qlist : list of str
        A list of strings for nested searches, e.g., ["Encoding", "matrixSize"].
    namespace : str, optional
        XML namespace every name in ``qlist`` is looked up in
        (default is "http://www.ismrm.org/ISMRMRD").

    Returns
    -------
    str
        The retrieved data as a string. The text is passed through
        ``str()``, so a matched element without text yields ``"None"``.

    Raises
    ------
    RuntimeError
        If no element matches the query.
    """
    prefix = "ismrmrd_namespace"
    ns = {prefix: namespace}

    # Each step uses "//", i.e. a descendant at any depth below the previous
    # match, so the names in qlist need not be direct children of each other.
    path = "." + "".join(f"//{prefix}:{el}" for el in qlist)

    value = root.find(path, ns)
    if value is None:
        raise RuntimeError(f"Element not found: {'/'.join(qlist)}")

    return str(value.text)
class SliceDataset(torch.utils.data.Dataset):
    """
    A simplified PyTorch Dataset that provides access to multicoil MR image
    slices from the fastMRI dataset.

    Slices are read lazily from the HDF5 volumes found under
    ``<{body_part}_path from fastmri.yaml>/multicoil_{partition}``.
    """

    # Number of slices dropped at the start and at the end of every volume.
    _EDGE_SLICES = 6

    def __init__(
        self,
        # root: Optional[Path | str],
        body_part: Literal["knee", "brain"],
        partition: Literal["train", "val", "test"],
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        complex: bool = False,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        contrast: Optional[Literal["T1", "T2"]] = None,
        coils: Optional[int] = None,
    ):
        """
        Initializes the fastMRI multi-coil challenge dataset.

        Samples are individual 2D slices taken from k-space volume data.

        Parameters
        ----------
        body_part : {'knee', 'brain'}
            The body part to analyze.
        partition : {'train', 'val', 'test'}
            The data partition type.
        mask_fns : list of callable, optional
            A list of masking functions to apply to samples.
            If multiple are given, a mask is randomly chosen for each sample.
        sample_rate : float, optional
            Fraction of data to sample, by default 1.0.
        complex : bool, optional
            Whether the $k$-space data should return complex-valued, by default False.
            If False, kspace values will be real (shape, 2).
            Complex output is not implemented: with True, indexing the
            dataset raises NotImplementedError.
        crop_shape : tuple of two ints, optional
            The shape to center crop the k-space data, by default (320, 320).
        slug : string
            dataset slug name
        contrast : {'T1', 'T2'}
            If body_part is brain, only files whose name contains this
            contrast are used.
        coils : int, optional
            If given, samples with fewer coils are skipped (None is returned
            for them) and the remaining ones keep only the first ``coils``
            coils.

        Raises
        ------
        KeyError
            If fastmri.yaml has no ``{body_part}_path`` entry.
        """
        with open("fastmri.yaml", "r") as file:
            config = yaml.safe_load(file)

        path_key = f"{body_part}_path"
        data_path = config.get(path_key)
        if data_path is None:
            # fail with the missing key instead of an opaque Path(None) TypeError
            raise KeyError(f"'{path_key}' is missing from fastmri.yaml")

        self.contrast = contrast
        self.slug = slug
        self.partition = partition
        self.body_part = body_part
        self.root = Path(data_path) / f"multicoil_{partition}"
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        self.complex = complex
        self.crop_shape = crop_shape
        self.coils = coils
        self.raw_samples: List[RawSample] = self._load_samples()

    def _load_samples(self) -> List["RawSample"]:
        """Scan the root directory and build one RawSample per usable slice."""
        # Brain data can be restricted to one contrast via the file name
        if self.body_part == "brain" and self.contrast:
            pattern = f"*{self.contrast}*.h5"
        else:
            pattern = "*.h5"

        edge = self._EDGE_SLICES
        raw_samples = []
        for fname in sorted(self.root.glob(pattern)):
            metadata, num_slices = self._retrieve_metadata(fname)
            # discard the first and the last `edge` slices of the volume
            raw_samples.extend(
                RawSample(fname, slice_num, metadata)
                for slice_num in range(edge, num_slices - edge)
            )

        # Subsample if desired
        if self.sample_rate < 1.0:
            raw_samples = random.sample(
                raw_samples, int(len(raw_samples) * self.sample_rate)
            )
        return raw_samples

    def _retrieve_metadata(self, fname):
        """
        Read the ISMRMRD header and file attributes of one HDF5 volume.

        Returns
        -------
        tuple
            ``(metadata, num_slices)`` where ``metadata`` holds the padding,
            encoding/recon sizes and all HDF5 file attributes, and
            ``num_slices`` is the first dimension of the "kspace" dataset.
        """
        with h5py.File(fname, "r") as hf:
            et_root = etree.fromstring(hf["ismrmrd_header"][()])

            enc = ["encoding", "encodedSpace", "matrixSize"]
            enc_size = (
                int(et_query(et_root, enc + ["x"])),
                int(et_query(et_root, enc + ["y"])),
                int(et_query(et_root, enc + ["z"])),
            )
            rec = ["encoding", "reconSpace", "matrixSize"]
            recon_size = (
                int(et_query(et_root, rec + ["x"])),
                int(et_query(et_root, rec + ["y"])),
                int(et_query(et_root, rec + ["z"])),
            )

            lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
            enc_limits_center = int(et_query(et_root, lims + ["center"]))
            enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1

            padding_left = enc_size[1] // 2 - enc_limits_center
            padding_right = padding_left + enc_limits_max

            num_slices = hf["kspace"].shape[0]

            # file attributes are unpacked last, so they win on a key clash
            metadata = {
                "padding_left": padding_left,
                "padding_right": padding_right,
                "encoding_size": enc_size,
                "recon_size": recon_size,
                **hf.attrs,
            }
        return metadata, num_slices

    def __len__(self):
        return len(self.raw_samples)

    def __getitem__(self, idx) -> Optional["SliceSample"]:
        """
        Load, crop and (optionally) mask one slice.

        Returns
        -------
        SliceSample or None
            None when the slice has fewer coils than ``self.coils`` or when
            loading/processing it fails; callers need a collate function
            that drops None entries.

        Raises
        ------
        NotImplementedError
            If the dataset was created with ``complex=True``.
        IndexError
            If ``idx`` is out of range.
        """
        # Programming errors are raised instead of being hidden behind the
        # best-effort "return None" path below.
        if self.complex:
            raise NotImplementedError("Not implemented for complex native")
        fname, slice_num, metadata = self.raw_samples[idx]

        try:
            with h5py.File(fname, "r") as hf:
                # index the HDF5 dataset directly so only this slice is read
                kspace = torch.view_as_real(
                    torch.tensor(hf["kspace"][slice_num])
                )

                if self.coils:
                    if kspace.shape[0] < self.coils:
                        return None
                    kspace = kspace[: self.coils]

                target_key = (
                    "reconstruction_rss"
                    if self.partition in ["train", "val"]
                    else "reconstruction_esc"
                )
                target = hf.get(target_key, None)
                if target is not None:
                    target = torch.tensor(target[slice_num])
                    if self.body_part == "brain":
                        target = T.center_crop(target, self.crop_shape)

            # center crop to enable collating for batching;
            # crop in image space, to not lose high-frequency information
            image = fastmri.ifft2c(kspace)
            image_cropped = T.complex_center_crop(image, self.crop_shape)
            kspace = fastmri.fft2c(image_cropped)

            # apply transform mask if there is one
            if self.mask_fns:
                # choose a random mask
                mask_fn = random.choice(self.mask_fns)
                kspace, mask, num_low_frequencies = T.apply_mask(
                    kspace,
                    mask_fn,
                    # seed=seed,
                )
                mask = mask.bool()
            else:
                mask = torch.ones_like(kspace, dtype=torch.bool)
                num_low_frequencies = 0

            return SliceSample(
                kspace,
                mask,
                num_low_frequencies,
                target,
                metadata["max"],
                fname.name,
                slice_num,
            )
        except Exception as exc:
            # Best effort: a bad file must not kill the whole run, but leave
            # a trace so that skipped slices can be diagnosed.
            logging.getLogger(__name__).warning(
                "Skipping slice %s of %s: %r", slice_num, fname, exc
            )
            return None
class SliceDatasetLMDB(torch.utils.data.Dataset):
    """
    A simplified PyTorch Dataset that provides access to multicoil MR image
    slices from the fastMRI dataset. Loads from LMDB saved samples.

    The root directory must contain ``meta.npy`` and two LMDB environments,
    ``kspace`` and ``rss``, keyed by the sample index written as a decimal
    string. Both environments are opened read-only in ``__init__`` and one
    read transaction per environment is kept on the instance.
    """

    # Spatial size every stored slice is reshaped to when it is read back.
    _STORED_SHAPE = (320, 320)

    def __init__(
        self,
        body_part: Literal["knee", "brain"],
        partition: Literal["train", "val", "test"],
        root: Optional[Union[Path, str]] = None,
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        complex: bool = False,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        coils: int = 15,
    ):
        """
        Initializes the fastMRI multi-coil challenge dataset.

        Samples are individual 2D slices taken from k-space volume data.

        Parameters
        ----------
        body_part : {'knee', 'brain'}
            The body part to analyze.
        partition : {'train', 'val', 'test'}
            The data partition type.
        root : Path or str, optional
            Root to lmdb dataset. If not provided, the root is automatically
            loaded directly from fastmri.yaml config
        mask_fns : list of callable, optional
            A list of masking functions to apply to samples.
            If multiple are given, a mask is randomly chosen for each sample.
        sample_rate : float, optional
            Fraction of data to sample, by default 1.0. Only affects the
            reported length.
        complex : bool, optional
            Stored on the instance but not used by this class: k-space is
            always returned real-valued with a trailing dimension of size 2.
        crop_shape : tuple of two ints, optional
            The shape to center crop the k-space data, by default (320, 320).
        slug : string
            dataset slug name
        coils : int, optional
            Number of coils each stored k-space entry is reshaped to,
            by default 15.

        Raises
        ------
        FileNotFoundError
            If ``root`` is given but does not exist.
        """
        # set attrs
        self.body_part = body_part
        self.coils = coils
        self.slug = slug
        self.partition = partition
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        self.complex = complex
        self.crop_shape = crop_shape

        # load lmdb info
        if root:
            root = Path(root)
            # explicit check instead of `assert`, which is stripped under -O
            if not root.exists():
                raise FileNotFoundError(f"Provided root doesn't exist: {root}")
            self.root = root
        else:
            with open("fastmri.yaml", "r") as file:
                config = yaml.safe_load(file)
            self.root = Path(config["lmdb"][f"{body_part}_{partition}_path"])

        self.meta = np.load(self.root / "meta.npy")
        self.kspace_env = lmdb.open(
            str(self.root / "kspace"),
            readonly=True,
            lock=False,
            create=False,
        )
        self.kspace_txn = self.kspace_env.begin(write=False)
        self.rss_env = lmdb.open(
            str(self.root / "rss"),
            readonly=True,
            lock=False,
            create=False,
        )
        self.rss_txn = self.rss_env.begin(write=False)
        self.length = self.kspace_txn.stat()["entries"]

    def __len__(self):
        return int(self.sample_rate * self.length)

    @staticmethod
    def _read_array(txn, key: bytes, shape) -> torch.Tensor:
        """Fetch one float32 entry from `txn` and return it as a tensor of `shape`."""
        buf = txn.get(key)
        if buf is None:
            # a missing key would otherwise surface as a TypeError in frombuffer
            raise IndexError(f"no LMDB entry for key {key!r}")
        # copy: frombuffer yields a read-only view of the LMDB buffer
        return torch.from_numpy(
            np.frombuffer(buf, dtype=np.float32).reshape(shape).copy()
        )

    def __getitem__(self, idx) -> "SliceSample":
        """
        Load, optionally crop, and optionally mask the sample stored at `idx`.

        Raises
        ------
        IndexError
            If the LMDB has no entry for `idx`.
        """
        idx_key = str(idx).encode("utf-8")
        height, width = self._STORED_SHAPE

        # load sample data
        kspace = self._read_array(
            self.kspace_txn, idx_key, (self.coils, height, width, 2)
        )
        rss = self._read_array(self.rss_txn, idx_key, (height, width))

        # crop in image space, to not lose high-frequency information
        if self.crop_shape and tuple(self.crop_shape) != self._STORED_SHAPE:
            image = fastmri.ifft2c(kspace)
            image_cropped = T.complex_center_crop(image, self.crop_shape)
            kspace = fastmri.fft2c(image_cropped)
            rss = T.center_crop(rss, self.crop_shape)

        # load and apply mask
        if self.mask_fns:
            # choose a random mask
            mask_fn = random.choice(self.mask_fns)
            kspace, mask, num_low_frequencies = T.apply_mask(
                kspace,
                mask_fn,  # type: ignore
            )
            mask = mask.bool()
        else:
            mask = torch.ones_like(kspace, dtype=torch.bool)
            num_low_frequencies = 0

        # load metadata
        fname, slice_num, max_value = self.meta[idx]

        return SliceSample(
            kspace,
            mask,
            num_low_frequencies,
            rss,
            float(max_value),
            str(fname),
            int(slice_num),
        )
class SliceDatasetLMDB_MVUE(torch.utils.data.Dataset):
    """
    Loads from LMDB brain saved samples.
    Modified to have MVUE targets

    Three LMDB environments are read: ``kspace`` and ``rss`` under `root`,
    and a separate MVUE environment. A mapping CSV links the two index
    spaces: row position ``i`` is the key in the MVUE LMDB, and that row's
    ``my_index`` column is the key in the kspace/rss LMDBs and the row in
    ``meta.npy``.
    """

    # Spatial size every stored slice is reshaped to when it is read back.
    _STORED_SHAPE = (320, 320)

    def __init__(
        self,
        root: Union[Path, str],
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        coils: int = 15,
        mvue_root: Union[Path, str] = "/pscratch/sd/p/peterwg/datasets/raytemp",
        mapping_path: Union[Path, str] = "brain_mvue_map.csv",
    ):
        """
        Parameters
        ----------
        root : Path or str
            Directory containing ``meta.npy`` and the ``kspace`` and ``rss``
            LMDB environments.
        mask_fns : list of callable, optional
            A list of masking functions to apply to samples.
            If multiple are given, a mask is randomly chosen for each sample.
        sample_rate : float, optional
            Fraction of data to sample, by default 1.0. Only affects the
            reported length.
        crop_shape : tuple of two ints, optional
            The shape to center crop the data, by default (320, 320).
        slug : string
            dataset slug name
        coils : int, optional
            Number of coils each stored k-space entry is reshaped to,
            by default 15.
        mvue_root : Path or str, optional
            LMDB environment holding the MVUE images.
        mapping_path : Path or str, optional
            CSV file with a ``my_index`` column (see class docstring).

        Raises
        ------
        FileNotFoundError
            If ``root`` does not exist.
        """
        # set attrs
        self.coils = coils
        self.slug = slug
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        # There is no `complex` parameter here; the previous code stored the
        # builtin `complex` type by accident. k-space is always real-valued
        # with a trailing dimension of size 2.
        self.complex = False
        self.crop_shape = crop_shape

        # load lmdb info
        root = Path(root)
        # explicit check instead of `assert`, which is stripped under -O
        if not root.exists():
            raise FileNotFoundError(f"Provided root doesn't exist: {root}")
        self.root = root

        self.meta = np.load(self.root / "meta.npy")
        self.mapping = pd.read_csv(mapping_path)
        self.kspace_env = lmdb.open(
            str(self.root / "kspace"),
            readonly=True,
            lock=False,
            create=False,
        )
        self.kspace_txn = self.kspace_env.begin(write=False)
        self.rss_env = lmdb.open(
            str(self.root / "rss"),
            readonly=True,
            lock=False,
            create=False,
        )
        self.rss_txn = self.rss_env.begin(write=False)

        # ray mvue dataset
        self.mvue_env = lmdb.open(
            str(mvue_root),
            readonly=True,
            lock=False,
            create=False,
        )
        self.mvue_txn = self.mvue_env.begin(write=False)

        # the mapping, not the kspace LMDB, defines how many samples exist
        self.length = len(self.mapping)

    def __len__(self):
        return int(self.sample_rate * self.length)

    @staticmethod
    def _read_array(txn, key: bytes, dtype, shape) -> torch.Tensor:
        """Fetch one entry from `txn` and return it as a tensor of `dtype` and `shape`."""
        buf = txn.get(key)
        if buf is None:
            # a missing key would otherwise surface as a TypeError in frombuffer
            raise IndexError(f"no LMDB entry for key {key!r}")
        # copy: frombuffer yields a read-only view of the LMDB buffer
        return torch.from_numpy(
            np.frombuffer(buf, dtype=dtype).reshape(shape).copy()
        )

    def __getitem__(self, idx) -> "SliceSampleMVUE":
        """
        Load the sample at mapping row `idx`, with the MVUE magnitude as target.

        Raises
        ------
        IndexError
            If `idx` is outside the mapping or an LMDB entry is missing.
        """
        # ray's index: 0-n
        ray_idx = idx
        # my index: lookup(ray index)
        idx = int(self.mapping.iloc[ray_idx].my_index)

        ray_idx_key = str(ray_idx).encode("utf-8")
        idx_key = str(idx).encode("utf-8")
        height, width = self._STORED_SHAPE

        # load sample data
        kspace = self._read_array(
            self.kspace_txn, idx_key, np.float32, (self.coils, height, width, 2)
        )
        rss = self._read_array(
            self.rss_txn, idx_key, np.float32, (height, width)
        )

        # load mvue from ray dataset (precomputed; stored complex, the
        # magnitude is used as target). For reference:
        # mvue_target = np.sum(
        #     sp.ifft(kspace, axes=(-1, -2)) * np.conj(s_maps), axis=1
        # ) / np.sqrt(np.sum(np.square(np.abs(s_maps)), axis=1))
        mvue = torch.abs(
            self._read_array(
                self.mvue_txn, ray_idx_key, np.complex64, (height, width)
            )
        )

        # crop in image space, to not lose high-frequency information
        if self.crop_shape and tuple(self.crop_shape) != self._STORED_SHAPE:
            image = fastmri.ifft2c(kspace)
            image_cropped = T.complex_center_crop(image, self.crop_shape)
            kspace = fastmri.fft2c(image_cropped)
            rss = T.center_crop(rss, self.crop_shape)
            # keep the target the same size as the cropped k-space and rss
            mvue = T.center_crop(mvue, self.crop_shape)

        # load and apply mask
        if self.mask_fns:
            # choose a random mask
            mask_fn = random.choice(self.mask_fns)
            kspace, mask, num_low_frequencies = T.apply_mask(
                kspace,
                mask_fn,  # type: ignore
            )
            mask = mask.bool()
        else:
            mask = torch.ones_like(kspace, dtype=torch.bool)
            num_low_frequencies = 0

        # load metadata
        fname, slice_num, max_value = self.meta[idx]

        return SliceSampleMVUE(
            kspace,
            mask,
            num_low_frequencies,
            mvue,
            rss,
            float(max_value),
            str(fname),
            int(slice_num),
        )
# d = SliceDatasetLMDB("knee", "val", None, 1, True, (320, 320), "testdataset")
# print(len(d))
# breakpoint()
# ds = SuperSliceDatasetLMDB(
# "brain", # body_part
# "val", # partition
# None, # root
# None, # mask_fns
# 1.0, # sample_rate
# True, # complex
# (320, 320), # crop_shape
# "test-superres", # slug
# coils=16, # coils
# )
# breakpoint()
# d = SliceDataset("brain", "train", None, contrast="T2")
# # TESTING MVUE
# d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/mri_brain_train_lmdb", coils=16)
# x = d[0]
# d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/raytemp/", coils=16)