"""PyTorch datasets for multicoil fastMRI slices (HDF5- and LMDB-backed)."""

import logging
import random
import xml.etree.ElementTree as etree
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Literal,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import h5py
import lmdb
import numpy as np
import pandas as pd
import sigpy as sp  # only referenced by the commented-out MVUE formula below
import torch
import yaml

import fastmri
import fastmri.transforms as T

logger = logging.getLogger(__name__)

# YAML file (resolved against the current working directory) holding data paths.
_CONFIG_FILE = "fastmri.yaml"

# Spatial size of every slice stored in the LMDB datasets; the readers below
# reshape the raw buffers to this size.
_LMDB_SLICE_SHAPE: Tuple[int, int] = (320, 320)

# Number of slices dropped at each end of a volume by SliceDataset.
_EDGE_SLICES_DROPPED = 6

# Historical defaults of SliceDatasetLMDB_MVUE (previously hard-coded inline).
_DEFAULT_MVUE_ROOT = "/pscratch/sd/p/peterwg/datasets/raytemp"
_DEFAULT_MVUE_MAPPING = "brain_mvue_map.csv"


class RawSample(NamedTuple):
    """Location of one slice on disk plus the metadata of its volume."""

    fname: Path
    slice_num: int
    metadata: Dict[str, Any]


class SliceSample(NamedTuple):
    """One dataset item: (masked) k-space, its mask and the target image."""

    masked_kspace: torch.Tensor
    mask: torch.Tensor
    num_low_frequencies: int
    target: torch.Tensor
    max_value: float
    # attrs: Dict[str, Any]
    fname: str
    slice_num: int


class SliceSampleMVUE(NamedTuple):
    """Like ``SliceSample`` but with an MVUE target and the RSS image alongside."""

    masked_kspace: torch.Tensor
    mask: torch.Tensor
    num_low_frequencies: int
    target: torch.Tensor
    rss: torch.Tensor
    max_value: float
    # attrs: Dict[str, Any]
    fname: str
    slice_num: int


def et_query(
    root: etree.Element,
    qlist: Sequence[str],
    namespace: str = "http://www.ismrm.org/ISMRMRD",
) -> str:
    """
    Query an XML document using ElementTree.

    This function allows querying an XML document by specifying a root and a
    list of nested queries. Every query element is looked up inside the given
    XML namespace.

    Parameters
    ----------
    root : ElementTree.Element
        The root element of the XML to search through.
    qlist : list of str
        A list of strings for nested searches, e.g., ["Encoding", "matrixSize"].
    namespace : str, optional
        XML namespace prepended to every query element
        (default is "http://www.ismrm.org/ISMRMRD").

    Returns
    -------
    str
        The retrieved data as a string.

    Raises
    ------
    RuntimeError
        If no element matches the query.
    """
    prefix = "ismrmrd_namespace"
    # Builds ".//ns:a//ns:b//..." so each level may sit at any depth below the
    # previous one.
    query = "." + "".join(f"//{prefix}:{el}" for el in qlist)

    value = root.find(query, {prefix: namespace})
    if value is None:
        raise RuntimeError("Element not found")

    return str(value.text)


def _load_config(path: str = _CONFIG_FILE) -> Dict[str, Any]:
    """Read the YAML configuration file that stores the dataset locations."""
    with open(path, "r") as file:
        return yaml.safe_load(file)


def _crop_kspace_in_image_space(
    kspace: torch.Tensor, crop_shape: Tuple[int, int]
) -> torch.Tensor:
    """Center-crop k-space by going through image space and back.

    Cropping is done in image space, to not lose high-frequency information.
    """
    image = fastmri.ifft2c(kspace)
    return fastmri.fft2c(T.complex_center_crop(image, crop_shape))


def _apply_random_mask(
    kspace: torch.Tensor, mask_fns: Optional[List[Callable]]
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Apply one randomly chosen mask function, or an all-ones mask if none.

    Returns the (possibly masked) k-space, a boolean mask and the number of
    low frequencies reported by ``T.apply_mask`` (0 when no mask is applied).
    """
    if not mask_fns:
        return kspace, torch.ones_like(kspace, dtype=torch.bool), 0

    mask_fn = random.choice(mask_fns)
    masked_kspace, mask, num_low_frequencies = T.apply_mask(
        kspace,
        mask_fn,  # type: ignore
        # seed=seed,
    )
    return masked_kspace, mask.bool(), num_low_frequencies


def _open_readonly_lmdb(path: Union[Path, str]):
    """Open an existing LMDB environment read-only; return (env, read txn)."""
    env = lmdb.open(str(path), readonly=True, lock=False, create=False)
    return env, env.begin(write=False)


def _read_lmdb_tensor(
    txn, key: bytes, dtype, shape: Tuple[int, ...], what: str
) -> torch.Tensor:
    """Fetch ``key`` from ``txn`` and decode it as a tensor of ``shape``.

    Raises
    ------
    IndexError
        If the key is absent, so that an out-of-range dataset index surfaces
        as the usual sequence error rather than a failure inside
        ``np.frombuffer``.
    """
    buffer = txn.get(key)
    if buffer is None:
        raise IndexError(f"No {what} entry stored under key {key!r}")
    # Copy so the tensor owns writable memory instead of aliasing the buffer
    # returned by LMDB.
    return torch.from_numpy(np.frombuffer(buffer, dtype=dtype).reshape(shape).copy())


class SliceDataset(torch.utils.data.Dataset):
    """
    A simplified PyTorch Dataset that provides access to multicoil MR image
    slices from the fastMRI dataset.
    """

    def __init__(
        self,
        # root: Optional[Path | str],
        body_part: Literal["knee", "brain"],
        partition: Literal["train", "val", "test"],
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        complex: bool = False,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        contrast: Optional[Literal["T1", "T2"]] = None,
        coils: Optional[int] = None,
    ):
        """
        Initializes the fastMRI multi-coil challenge dataset.

        Samples are individual 2D slices taken from k-space volume data. The
        data directory is ``<body_part>_path`` from ``fastmri.yaml`` joined
        with ``multicoil_<partition>``.

        Parameters
        ----------
        body_part : {'knee', 'brain'}
            The body part to analyze.
        partition : {'train', 'val', 'test'}
            The data partition type.
        mask_fns : list of callable, optional
            A list of masking functions to apply to samples.
            If multiple are given, a mask is randomly chosen for each sample.
        sample_rate : float, optional
            Fraction of data to sample, by default 1.0.
        complex : bool, optional
            Whether the $k$-space data should return complex-valued, by default
            False. If False, kspace values will be real (shape, 2). Complex
            output is not implemented: ``__getitem__`` raises
            ``NotImplementedError`` when this is True.
        crop_shape : tuple of two ints, optional
            The shape to center crop the k-space data, by default (320, 320).
        slug : string
            dataset slug name
        contrast : {'T1', 'T2'}
            If body_part is brain, only files whose name contains this
            contrast are used.
        coils : int, optional
            If set, slices with fewer coils are skipped (``None`` is returned)
            and the remaining ones are truncated to the first ``coils`` coils.

        Raises
        ------
        KeyError
            If ``fastmri.yaml`` has no ``<body_part>_path`` entry.
        """
        config = _load_config()
        path_key = f"{body_part}_path"
        data_path = config.get(path_key)
        if data_path is None:
            raise KeyError(f"'{path_key}' is missing from {_CONFIG_FILE}")

        self.contrast = contrast
        self.slug = slug
        self.partition = partition
        self.body_part = body_part
        self.root = Path(data_path) / f"multicoil_{partition}"
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        self.complex = complex
        self.crop_shape = crop_shape
        self.coils = coils
        self.raw_samples: List[RawSample] = self._load_samples()

    def _load_samples(self) -> List[RawSample]:
        """Index every usable slice of every ``.h5`` file below ``self.root``."""
        # Gather all files in the root directory
        if self.body_part == "brain" and self.contrast:
            pattern = f"*{self.contrast}*.h5"
        else:
            pattern = "*.h5"

        raw_samples: List[RawSample] = []
        edge = _EDGE_SLICES_DROPPED
        # Load and process metadata from each file
        for fname in sorted(self.root.glob(pattern)):
            metadata, num_slices = self._retrieve_metadata(fname)
            # Discard the first `edge` and the last `edge` slices of the volume.
            raw_samples.extend(
                RawSample(fname, slice_num, metadata)
                for slice_num in range(edge, num_slices - edge)
            )

        # Subsample if desired
        if self.sample_rate < 1.0:
            raw_samples = random.sample(
                raw_samples, int(len(raw_samples) * self.sample_rate)
            )

        return raw_samples

    def _retrieve_metadata(self, fname) -> Tuple[Dict[str, Any], int]:
        """Parse the ISMRMRD header of ``fname``.

        Returns a metadata dict (padding, encoding/recon sizes and all HDF5
        file attributes) and the number of slices in the ``kspace`` dataset.
        """
        with h5py.File(fname, "r") as hf:
            et_root = etree.fromstring(hf["ismrmrd_header"][()])

            enc = ["encoding", "encodedSpace", "matrixSize"]
            enc_size = (
                int(et_query(et_root, enc + ["x"])),
                int(et_query(et_root, enc + ["y"])),
                int(et_query(et_root, enc + ["z"])),
            )
            rec = ["encoding", "reconSpace", "matrixSize"]
            recon_size = (
                int(et_query(et_root, rec + ["x"])),
                int(et_query(et_root, rec + ["y"])),
                int(et_query(et_root, rec + ["z"])),
            )

            lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
            enc_limits_center = int(et_query(et_root, lims + ["center"]))
            enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1

            padding_left = enc_size[1] // 2 - enc_limits_center
            padding_right = padding_left + enc_limits_max

            num_slices = hf["kspace"].shape[0]

            metadata = {
                "padding_left": padding_left,
                "padding_right": padding_right,
                "encoding_size": enc_size,
                "recon_size": recon_size,
                **hf.attrs,
            }

        return metadata, num_slices

    def __len__(self):
        return len(self.raw_samples)

    def __getitem__(self, idx) -> Optional[SliceSample]:
        """Load, crop and mask one slice.

        Returns
        -------
        SliceSample or None
            ``None`` when the slice has fewer coils than ``self.coils`` or when
            loading/processing it fails (the failure is logged).

        Raises
        ------
        IndexError
            If ``idx`` is out of range.
        NotImplementedError
            If the dataset was created with ``complex=True``.
        """
        # Looked up outside the try block so an out-of-range index raises
        # IndexError instead of being turned into None.
        fname, slice_num, metadata = self.raw_samples[idx]

        if self.complex:
            # if complex, crop across dims: -2 and -1 (last 2)
            raise NotImplementedError("Not implemented for complex native")

        try:
            # load kspace and target
            with h5py.File(fname, "r") as hf:
                # Index the dataset directly so only this slice is read.
                kspace = torch.view_as_real(torch.tensor(hf["kspace"][slice_num]))

                if self.coils:
                    if kspace.shape[0] < self.coils:
                        return None
                    kspace = kspace[: self.coils]

                target_key = (
                    "reconstruction_rss"
                    if self.partition in ["train", "val"]
                    else "reconstruction_esc"
                )
                target = hf.get(target_key, None)
                if target is not None:
                    target = torch.tensor(target[slice_num])
                    if self.body_part == "brain":
                        target = T.center_crop(target, self.crop_shape)

            # center crop to enable collating for batching
            kspace = _crop_kspace_in_image_space(kspace, self.crop_shape)

            # apply transform mask if there is one
            kspace, mask, num_low_frequencies = _apply_random_mask(
                kspace, self.mask_fns
            )

            return SliceSample(
                kspace,
                mask,
                num_low_frequencies,
                target,
                metadata["max"],
                fname.name,
                slice_num,
            )
        except Exception:
            # Best effort: a broken sample is skipped, but no longer silently.
            logger.warning(
                "Skipping slice %s of %s: failed to load",
                slice_num,
                fname,
                exc_info=True,
            )
            return None


class SliceDatasetLMDB(torch.utils.data.Dataset):
    """
    A simplified PyTorch Dataset that provides access to multicoil MR image
    slices from the fastMRI dataset. Loads from LMDB saved samples.
    """

    def __init__(
        self,
        body_part: Literal["knee", "brain"],
        partition: Literal["train", "val", "test"],
        root: Optional[Path | str] = None,
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        complex: bool = False,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        coils: int = 15,
    ):
        """
        Initializes the fastMRI multi-coil challenge dataset.

        Samples are individual 2D slices taken from k-space volume data.

        Parameters
        ----------
        body_part : {'knee', 'brain'}
            The body part to analyze.
        root : Path or str, optional
            Root to lmdb dataset. If not provided, the root is automatically
            loaded directly from fastmri.yaml config
        partition : {'train', 'val', 'test'}
            The data partition type.
        mask_fns : list of callable, optional
            A list of masking functions to apply to samples.
            If multiple are given, a mask is randomly chosen for each sample.
        sample_rate : float, optional
            Fraction of data to sample, by default 1.0.
        complex : bool, optional
            Stored on the instance only; samples are always decoded as real
            tensors whose last dimension has size 2.
        crop_shape : tuple of two ints, optional
            The shape to center crop the k-space data, by default (320, 320).
        slug : string
            dataset slug name
        coils : int, optional
            Number of coils of every stored k-space sample, by default 15.

        Raises
        ------
        FileNotFoundError
            If ``root`` is given but does not exist.
        """
        # set attrs
        self.body_part = body_part
        self.coils = coils
        self.slug = slug
        self.partition = partition
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        self.complex = complex
        self.crop_shape = crop_shape

        # load lmdb info
        if root:
            root = Path(root)
            if not root.exists():
                raise FileNotFoundError(f"Provided root doesn't exist: {root}")
            self.root = root
        else:
            config = _load_config()
            self.root = Path(config["lmdb"][f"{body_part}_{partition}_path"])

        self.meta = np.load(self.root / "meta.npy")
        self.kspace_env, self.kspace_txn = _open_readonly_lmdb(self.root / "kspace")
        self.rss_env, self.rss_txn = _open_readonly_lmdb(self.root / "rss")
        self.length = self.kspace_txn.stat()["entries"]

    def __len__(self):
        return int(self.sample_rate * self.length)

    def __getitem__(self, idx) -> SliceSample:
        """Read sample ``idx`` from LMDB, then crop and mask it.

        Raises
        ------
        IndexError
            If no sample is stored under ``idx``.
        """
        idx_key = str(idx).encode("utf-8")

        # load sample data
        kspace = _read_lmdb_tensor(
            self.kspace_txn,
            idx_key,
            np.float32,
            (self.coils, *_LMDB_SLICE_SHAPE, 2),
            "kspace",
        )
        rss = _read_lmdb_tensor(
            self.rss_txn, idx_key, np.float32, _LMDB_SLICE_SHAPE, "rss"
        )

        # Stored slices already have _LMDB_SLICE_SHAPE; only crop when a
        # different shape is requested.
        if self.crop_shape and tuple(self.crop_shape) != _LMDB_SLICE_SHAPE:
            kspace = _crop_kspace_in_image_space(kspace, self.crop_shape)
            rss = T.center_crop(rss, self.crop_shape)

        # load and apply mask
        kspace, mask, num_low_frequencies = _apply_random_mask(kspace, self.mask_fns)

        # load metadata
        fname, slice_num, max_value = self.meta[idx]

        return SliceSample(
            kspace,
            mask,
            num_low_frequencies,
            rss,
            float(max_value),
            str(fname),
            int(slice_num),
        )


class SliceDatasetLMDB_MVUE(torch.utils.data.Dataset):
    """
    Loads from LMDB brain saved samples. Modified to have MVUE targets
    """

    def __init__(
        self,
        root: Path | str,
        mask_fns: Optional[List[Callable]] = None,
        sample_rate: float = 1.0,
        crop_shape: Tuple[int, int] = (320, 320),
        slug: str = "",
        coils: int = 15,
        *,
        mvue_root: Path | str = _DEFAULT_MVUE_ROOT,
        mapping_csv: Path | str = _DEFAULT_MVUE_MAPPING,
    ):
        """
        Parameters
        ----------
        root : Path or str
            Root of the lmdb dataset (contains ``meta.npy``, ``kspace`` and
            ``rss``).
        mask_fns : list of callable, optional
            A list of masking functions to apply to samples.
            If multiple are given, a mask is randomly chosen for each sample.
        sample_rate : float, optional
            Fraction of data to sample, by default 1.0.
        crop_shape : tuple of two ints, optional
            The shape to center crop the data, by default (320, 320).
        slug : string
            dataset slug name
        coils : int, optional
            Number of coils of every stored k-space sample, by default 15.
        mvue_root : Path or str, optional
            LMDB environment holding the MVUE targets. Defaults to the
            location that used to be hard-coded.
        mapping_csv : Path or str, optional
            CSV file whose ``my_index`` column maps an MVUE row to the index
            used in ``root``. Defaults to "brain_mvue_map.csv".

        Raises
        ------
        FileNotFoundError
            If ``root`` does not exist.
        """
        # set attrs
        self.coils = coils
        self.slug = slug
        self.mask_fns = mask_fns
        self.sample_rate = sample_rate
        # There is no `complex` argument here; the previous code stored the
        # builtin `complex` type by accident. Samples are always real (..., 2).
        self.complex = False
        self.crop_shape = crop_shape

        # load lmdb info
        root = Path(root)
        if not root.exists():
            raise FileNotFoundError(f"Provided root doesn't exist: {root}")
        self.root = root

        self.meta = np.load(self.root / "meta.npy")
        self.mapping = pd.read_csv(mapping_csv)
        self.kspace_env, self.kspace_txn = _open_readonly_lmdb(self.root / "kspace")
        self.rss_env, self.rss_txn = _open_readonly_lmdb(self.root / "rss")

        # ray mvue dataset
        self.mvue_env, self.mvue_txn = _open_readonly_lmdb(mvue_root)

        self.length = len(self.mapping)
        # self.length = self.kspace_txn.stat()["entries"]

    def __len__(self):
        return int(self.sample_rate * self.length)

    def __getitem__(self, idx) -> SliceSampleMVUE:
        """Read sample ``idx`` (an index into the MVUE dataset).

        Raises
        ------
        IndexError
            If ``idx`` is outside the mapping table or a key is missing from
            one of the LMDB stores.
        """
        # ray's index: 0-n
        ray_idx = idx
        # my index: lookup(ray index)
        idx = int(self.mapping.iloc[ray_idx].my_index)

        ray_idx_key = str(ray_idx).encode("utf-8")
        idx_key = str(idx).encode("utf-8")

        # load sample data
        kspace = _read_lmdb_tensor(
            self.kspace_txn,
            idx_key,
            np.float32,
            (self.coils, *_LMDB_SLICE_SHAPE, 2),
            "kspace",
        )
        # mvue_target = np.sum(
        #     sp.ifft(kspace, axes=(-1, -2)) * np.conj(s_maps), axis=1
        # ) / np.sqrt(np.sum(np.square(np.abs(s_maps)), axis=1))
        rss = _read_lmdb_tensor(
            self.rss_txn, idx_key, np.float32, _LMDB_SLICE_SHAPE, "rss"
        )

        # load mvue from ray dataset; stored complex, the magnitude is used
        mvue = torch.abs(
            _read_lmdb_tensor(
                self.mvue_txn, ray_idx_key, np.complex64, _LMDB_SLICE_SHAPE, "mvue"
            )
        )

        if self.crop_shape and tuple(self.crop_shape) != _LMDB_SLICE_SHAPE:
            kspace = _crop_kspace_in_image_space(kspace, self.crop_shape)
            rss = T.center_crop(rss, self.crop_shape)
            # Crop the target too so it keeps the same size as k-space and RSS.
            mvue = T.center_crop(mvue, self.crop_shape)

        # load and apply mask
        kspace, mask, num_low_frequencies = _apply_random_mask(kspace, self.mask_fns)

        # load metadata
        fname, slice_num, max_value = self.meta[idx]

        return SliceSampleMVUE(
            kspace,
            mask,
            num_low_frequencies,
            mvue,
            rss,
            float(max_value),
            str(fname),
            int(slice_num),
        )


# Scratch usage examples kept from the original file:
#
# d = SliceDatasetLMDB("knee", "val", None, 1, True, (320, 320), "testdataset")
# print(len(d))
# breakpoint()
# ds = SuperSliceDatasetLMDB(
#     "brain",  # body_part
#     "val",  # partition
#     None,  # root
#     None,  # mask_fns
#     1.0,  # sample_rate
#     True,  # complex
#     (320, 320),  # crop_shape
#     "test-superres",  # slug
#     coils=16,  # coils
# )
# breakpoint()
# d = SliceDataset("brain", "train", None, contrast="T2")
#
# # TESTING MVUE
# d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/mri_brain_train_lmdb", coils=16)
# x = d[0]
# d = SliceDatasetLMDB_MVUE("/pscratch/sd/p/peterwg/datasets/raytemp/", coils=16)