import io
import os
import sys

import gradio as gr
import numpy as np
import spaces
import torch

# from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
from PIL import Image

# ZeroGPU idiom: a CUDA tensor created at import time; under `spaces`, such
# tensors land on a real GPU only inside @spaces.GPU-decorated functions.
zero = torch.Tensor([0]).cuda()

# Set the working directory to the root directory
# root_dir = os.path.abspath("..")
# os.chdir(root_dir)
# sys.path.insert(0, root_dir)

# download dataset & weights
snapshot_download(repo_id="armeet/fastmri-tiny", repo_type="dataset", local_dir=".")
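# The snapshot is expected to provide the local `dataset/` directory and the
# `vn.ckpt` / `no.ckpt` checkpoints loaded below.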


device = "cuda"
# dataset_path = "/global/homes/p/peterwg/pscratch/datasets/mri_knee_dummy"
dataset_path = "dataset"

import matplotlib.pyplot as plt

from fastmri.datasets import SliceDatasetLMDB
from fastmri.subsample import create_mask_for_mask_type
from models.lightning.no_varnet_module import NOVarnetModule
from models.lightning.varnet_module import VarNetModule

# Fraction of fully sampled low-frequency (center) k-space columns retained
# at each acceleration rate (roughly 0.32 / acceleration for rates >= 2).
acceleration_to_fractions = {
    1: 1,
    2: 0.16,
    4: 0.08,
    6: 0.06,
    8: 0.04,
    16: 0.02,
    32: 0.01,
}
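
# Worked example (illustrative): for a 320-column k-space at 4x acceleration,
# about 320 * 0.08 ~= 26 contiguous center columns stay fully sampled, and the
# remaining kept columns are spread equispaced so that roughly 320 / 4 = 80
# columns survive in total.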


def create_mask_fn(center_fraction, acceleration):
    mask_fn = create_mask_for_mask_type(
        "equispaced_fraction",
        [center_fraction],
        [acceleration],
    )
    return mask_fn
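
# Sketch (assumed fastmri-style MaskFunc API; not exercised directly in this
# app, where the datasets apply mask_fns internally):
#   mask_fn = create_mask_fn(0.08, 4)
#   mask, num_low_frequencies = mask_fn([1, 640, 372, 2], seed=0)
# The exact return signature (mask only vs. a tuple) varies across fastmri
# versions.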


ACCEL_RATES = (4, 6, 8, 16)

masks = {
    rate: create_mask_fn(acceleration_to_fractions[rate], rate)
    for rate in ACCEL_RATES
}

# One center-cropped (320x320), 15-coil knee validation dataset per
# acceleration rate, each applying its own undersampling mask.
val_datasets = {
    rate: SliceDatasetLMDB(
        "knee",
        partition="val",
        mask_fns=[masks[rate]],
        complex=False,
        root=dataset_path,
        crop_shape=(320, 320),
        coils=15,
    )
    for rate in ACCEL_RATES
}
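
# Sketch (field names inferred from their use in forward() below): indexing a
# dataset yields a sample whose tensors drive reconstruction, e.g.
#   s = val_datasets[4][0]
#   s.masked_kspace  # undersampled multi-coil k-space (real/imag channels)
#   s.mask           # the binary sampling mask that produced it
#   s.target         # (320, 320) ground-truth image crop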

# NO ("ours" in the demo) is the unified model from the paper; VN is the
# existing end-to-end VarNet baseline it is compared against.
vn = VarNetModule.load_from_checkpoint("vn.ckpt")
no = NOVarnetModule.load_from_checkpoint("no.ckpt")
no.eval()
vn.eval()
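
# Note (standard PyTorch Lightning behavior): load_from_checkpoint restores
# hyperparameters stored in the checkpoint; pass map_location explicitly if
# GPU-saved checkpoints ever need to load in a CPU-only context.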

# Hand-picked validation slice indices that display well in the viewer.
bright_samples = [42, 69, 80, 137, 139, 226, 229]


def v(x):
    """Detach a tensor and return it as a squeezed NumPy array."""
    return x.detach().cpu().numpy().squeeze()


def viz(x, cmap="gray", vmin=0, vmax=1):
    """Render a tensor with matplotlib and return the image as a NumPy array."""
    processed_data = v(x)
    fig, ax = plt.subplots()
    ax.imshow(processed_data, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.axis("off")  # turn off axes
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)  # remove margins
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    buf.seek(0)  # rewind the buffer to the beginning
    try:
        img = Image.open(buf)
        img_array = np.array(img)
    except Exception as e:
        print(f"Error converting image buffer to NumPy array: {e}")
        img_array = None
    finally:
        plt.close(fig)
        buf.close()
    return img_array
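
# Example: viz(torch.rand(320, 320)) returns an RGBA uint8 array that
# gr.Image can display directly.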


@spaces.GPU
def forward(model, idx, rate):
    """Reconstruct validation sample `idx` at the given acceleration rate."""
    dataset = val_datasets.get(rate)
    if dataset is None:
        raise ValueError(f"Invalid acceleration rate: {rate}")

    sample = dataset[idx]
    mask, k, target = (
        sample.mask.to(device),
        sample.masked_kspace.to(device),
        sample.target.to(device),
    )
    with torch.no_grad():  # inference only
        pred = model(k.unsqueeze(0), mask.unsqueeze(0), None)

    return mask, k, target, pred[0]
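
# Example call (see the dataset sketch above for tensor shapes):
#   mask, k, target, pred = forward(no, bright_samples[0], 16)
# where `pred` is the (320, 320) reconstruction of sample 42 from
# 16x-undersampled k-space.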


@spaces.GPU
def update_interface(sample_id, sample_rate):
    # Clear all six panels if the selection is incomplete or unknown.
    if sample_id is None or sample_rate is None or sample_id not in bright_samples:
        return [None] * 6

    mask, k, target, pred_vn = forward(vn, sample_id, sample_rate)
    _, _, _, pred_no = forward(no, sample_id, sample_rate)

    mask_res = viz(mask[0, :, :, 0], cmap="gray", vmin=0, vmax=1)
    target_res = viz(target, cmap="gray", vmin=None, vmax=None)

    pred_no_res = viz(pred_no, cmap="gray", vmin=None, vmax=None)
    pred_vn_res = viz(pred_vn, cmap="gray", vmin=None, vmax=None)

    # Absolute-error maps against the ground truth.
    diff_no_res = viz(torch.abs(pred_no - target), cmap=None, vmin=None, vmax=None)
    diff_vn_res = viz(torch.abs(pred_vn - target), cmap=None, vmin=None, vmax=None)

    return mask_res, target_res, pred_no_res, pred_vn_res, diff_no_res, diff_vn_res


with gr.Blocks(theme=gr.themes.Monochrome(), fill_width=True) as demo:
    gr.Markdown(
        "# A Unified Model for Compressed Sensing MRI Across Undersampling Patterns [CVPR 2025]"
    )
    gr.Markdown("""  
> Armeet Singh Jatyani, Jiayun Wang, Aditi Chandrashekar, Zihui Wu, Miguel Liu-Schiaffini, Bahareh Tolooshams, Anima Anandkumar  
                """)
    gr.Markdown(
        "[![arXiv](https://img.shields.io/badge/arXiv-2410.16290-b31b1b.svg?style=flat-square&logo=arxiv)](https://arxiv.org/abs/2410.16290)"
    )
    gr.Markdown(
        "[![](https://img.shields.io/badge/Blog-armeet.ca%2Fnomri-yellow?style=flat-square)](https://armeet.ca/nomri)"
    )

    gr.Markdown(
        "This demo showcases the performance of our unified model for compressed sensing MRI across different acceleration rates."
    )
    gr.Markdown(
        "We recommend starting with the 16x acceleration rate, where reconstruction differences are easiest to observe."
    )

    gr.Markdown(
        "At lower acceleration rates (4x or 6x), differences in reconstruction quality are difficult to discern. At higher rates, look for blurring, repetition artifacts, or distortion, especially near edges and in the background. The difference images below help localize reconstruction errors."
    )
    with gr.Row():
        dropdown_sample = gr.Dropdown(
            choices=bright_samples,
            label="Select a Sample",
            info="Choose one of the available samples.",
            filterable=False,
            value=229,
        )
    with gr.Row():
        dropdown_rate = gr.Radio(
            choices=[16, 8, 6, 4],
            value=16,
            label="Select an Acceleration Rate",
            info="Ex: 4x means the model is trained to reconstruct from 4x undersampled k-space data",
            # filterable=False,
        )

    with gr.Row():
        with gr.Column():
            gr.Label("Undersampling Mask")
            k = gr.Image(label=None, interactive=False)
        with gr.Column():
            gr.Label("Ground Truth")
            target = gr.Image(label=None, interactive=False)
        with gr.Column():
            gr.Label("NO (ours)")
            pred_no = gr.Image(label="Reconstruction (ours)", interactive=False)
        with gr.Column():
            gr.Label("VN (existing)")
            pred_vn = gr.Image(label="Reconstruction (existing)", interactive=False)
    with gr.Row():
        with gr.Column():
            pass
        with gr.Column():
            pass
        with gr.Column():
            diff_no = gr.Image(label="| Recon - GT | (ours)", interactive=False)
        with gr.Column():
            diff_vn = gr.Image(label="| Recon - GT | (existing)", interactive=False)

    gr.Markdown("""
```
@inproceedings{jatyani2025nomri,
  author    = {Armeet Singh Jatyani* and Jiayun Wang* and Aditi Chandrashekar and Zihui Wu and Miguel Liu-Schiaffini and Bahareh Tolooshams and Anima Anandkumar},
  title     = {A Unified Model for Compressed Sensing MRI Across Undersampling Patterns},
  booktitle = {Conference on Computer Vision and Pattern Recognition (CVPR) Proceedings},
  abbr      = {CVPR},
  year      = {2025}
}
```
                """)

    update_inputs = [dropdown_sample, dropdown_rate]
    update_outputs = [k, target, pred_no, pred_vn, diff_no, diff_vn]

    dropdown_sample.change(
        fn=update_interface, inputs=update_inputs, outputs=update_outputs
    )
    dropdown_rate.change(
        fn=update_interface, inputs=update_inputs, outputs=update_outputs
    )

if __name__ == "__main__":
    # demo.launch(share=True)
    demo.launch()