""" | |
Modified for use in <TODO: paper name> | |
- minified and removed extraneous abstractions | |
- updated to latest version of lightning | |
Copyright (c) Facebook, Inc. and its affiliates. | |
This source code is licensed under the MIT license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
from argparse import ArgumentParser
from collections import defaultdict
from io import BytesIO

import numpy as np
import torch
import lightning as L
from torchmetrics.metric import Metric
from PIL import Image

import matplotlib

matplotlib.use("Agg")  # non-interactive backend; select it before pyplot is used

import matplotlib.pyplot as plt

from fastmri import evaluate
class DistributedMetricSum(Metric):
    """Accumulates a running sum that is summed across processes under DDP."""

    def __init__(self, dist_sync_on_step=True):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state(
            "quantity", default=torch.tensor(0.0), dist_reduce_fx="sum"
        )

    def update(self, batch: torch.Tensor):  # type: ignore
        self.quantity += batch

    def compute(self):
        return self.quantity
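
# Illustrative sketch of the metric's behavior (single-process case, shown
# here only for clarity; under DDP the ``dist_reduce_fx="sum"`` state above is
# additionally summed across ranks when ``compute()`` synchronizes):
#
#     m = DistributedMetricSum()
#     m.update(torch.tensor(2.0))
#     m.update(torch.tensor(3.0))
#     m.compute()  # tensor(5.)
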
class MriModule(L.LightningModule):
    """
    Abstract superclass for deep learning reconstruction models.

    This is a subclass of the LightningModule class from lightning,
    with some additional functionality specific to fastMRI:
        - Evaluating reconstructions
        - Visualization

    To implement a new reconstruction model, inherit from this class and
    implement the following methods (a minimal illustrative sketch appears
    at the end of this module):
        - training_step: Define what happens in one step of training
        - validation_step: Define what happens in one step of validation
        - test_step: Define what happens in one step of testing
        - configure_optimizers: Create and return the optimizers

    Other methods from LightningModule can be overridden as needed.
    """
    def __init__(self, num_log_images: int = 16):
        """
        Initialize the MRI module.

        Parameters
        ----------
        num_log_images : int, optional
            Number of images to log. Defaults to 16.
        """
        super().__init__()
        self.num_log_images = num_log_images
        # Validation batch indices whose first sample is logged as an image.
        self.val_log_indices = [1, 2, 3, 4, 5]
        self.val_batch_results = []
        self.NMSE = DistributedMetricSum()
        self.SSIM = DistributedMetricSum()
        self.PSNR = DistributedMetricSum()
        self.ValLoss = DistributedMetricSum()
        self.TotExamples = DistributedMetricSum()
        self.TotSliceExamples = DistributedMetricSum()
    def log_image(self, name, image):
        # Assumes a logger that exposes ``log_image`` (e.g. Lightning's
        # WandbLogger); skipped when no logger is attached.
        if self.logger is not None:
            self.logger.log_image(
                key=name, images=[image], caption=[str(self.global_step)]
            )
    def on_validation_batch_end(
        self, outputs, batch, batch_idx, dataloader_idx=0
    ):
        # `outputs` is the dict returned by the subclass's validation_step;
        # the expected keys are sketched in the note after this method.
        val_logs = outputs
        mse_vals = defaultdict(dict)
        target_norms = defaultdict(dict)
        ssim_vals = defaultdict(dict)
        max_vals = dict()
        for i, fname in enumerate(val_logs["fname"]):
            # Log target/reconstruction/error images for the first sample of
            # selected validation batches.
            if i == 0 and batch_idx in self.val_log_indices:
                key = f"val_images_idx_{batch_idx}"
                target = val_logs["target"][i].unsqueeze(0)
                output = val_logs["output"][i].unsqueeze(0)
                error = torch.abs(target - output)
                output = output / output.max()
                target = target / target.max()
                error = error / error.max()
                self.log_image(f"{key}/target", target)
                self.log_image(f"{key}/reconstruction", output)
                self.log_image(f"{key}/error", error)
            slice_num = int(val_logs["slice_num"][i].cpu())
            maxval = val_logs["max_value"][i].cpu().numpy()
            output = val_logs["output"][i].cpu().numpy()
            target = val_logs["target"][i].cpu().numpy()
            mse_vals[fname][slice_num] = torch.tensor(
                evaluate.mse(target, output)
            ).view(1)
            # MSE against zero is the squared target norm, used later as the
            # NMSE denominator.
            target_norms[fname][slice_num] = torch.tensor(
                evaluate.mse(target, np.zeros_like(target))
            ).view(1)
            ssim_vals[fname][slice_num] = torch.tensor(
                evaluate.ssim(
                    target[None, ...], output[None, ...], maxval=maxval
                )
            ).view(1)
            max_vals[fname] = maxval
        self.val_batch_results.append(
            {
                "slug": val_logs["slug"],
                "val_loss": val_logs["val_loss"],
                "mse_vals": dict(mse_vals),
                "target_norms": dict(target_norms),
                "ssim_vals": dict(ssim_vals),
                "max_vals": max_vals,
            }
        )
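
    # Note: inferred from the keys read above, a subclass's validation_step is
    # expected to return a dict shaped roughly like the following (field names
    # and values are placeholders, not a prescribed API):
    #
    #     return {
    #         "slug": "dataset_name",      # dataset identifier
    #         "fname": fnames,             # per-sample file names
    #         "slice_num": slice_nums,     # per-sample slice indices
    #         "max_value": max_values,     # per-sample max values for PSNR/SSIM
    #         "output": reconstructions,   # model outputs, one 2D image each
    #         "target": targets,           # ground-truth images
    #         "val_loss": loss,            # batch validation loss
    #     }
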
    def on_validation_epoch_end(self):
        val_logs = self.val_batch_results
        dataset_metrics = defaultdict(
            lambda: {
                "losses": [],
                "mse_vals": defaultdict(dict),
                "target_norms": defaultdict(dict),
                "ssim_vals": defaultdict(dict),
                "max_vals": dict(),
            }
        )
        # use dict updates to handle duplicate slices
        for val_log in val_logs:
            slug = val_log["slug"]
            dataset_metrics[slug]["losses"].append(val_log["val_loss"].view(-1))
            for k in val_log["mse_vals"].keys():
                dataset_metrics[slug]["mse_vals"][k].update(
                    val_log["mse_vals"][k]
                )
            for k in val_log["target_norms"].keys():
                dataset_metrics[slug]["target_norms"][k].update(
                    val_log["target_norms"][k]
                )
            for k in val_log["ssim_vals"].keys():
                dataset_metrics[slug]["ssim_vals"][k].update(
                    val_log["ssim_vals"][k]
                )
            for k in val_log["max_vals"]:
                dataset_metrics[slug]["max_vals"][k] = val_log["max_vals"][k]
        metrics_to_plot = {"psnr": [], "ssim": [], "nmse": []}
        slugs = []
        for slug, metrics_data in dataset_metrics.items():
            mse_vals, target_norms, ssim_vals, max_vals, losses = (
                metrics_data["mse_vals"],
                metrics_data["target_norms"],
                metrics_data["ssim_vals"],
                metrics_data["max_vals"],
                metrics_data["losses"],
            )
            # check to make sure we have all files in all metrics
            assert (
                mse_vals.keys()
                == target_norms.keys()
                == ssim_vals.keys()
                == max_vals.keys()
            )
            # apply means across image volumes
            metrics = {"nmse": 0, "ssim": 0, "psnr": 0}
            metric_values = {
                "nmse": [],
                "ssim": [],
                "psnr": [],
            }  # per-volume values, kept for the std computation below
            local_examples = 0
            for fname in mse_vals.keys():
                local_examples = local_examples + 1
                mse_val = torch.mean(
                    torch.cat([v.view(-1) for _, v in mse_vals[fname].items()])
                )
                target_norm = torch.mean(
                    torch.cat(
                        [v.view(-1) for _, v in target_norms[fname].items()]
                    )
                )
                # NMSE = MSE(target, output) / MSE(target, 0)
                nmse = mse_val / target_norm
                # PSNR = 20 * log10(max_value) - 10 * log10(MSE)
                psnr = 20 * torch.log10(
                    torch.tensor(
                        max_vals[fname],
                        dtype=mse_val.dtype,
                        device=mse_val.device,
                    )
                ) - 10 * torch.log10(mse_val)
                ssim = torch.mean(
                    torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()])
                )
                # Accumulate metric values
                metrics["nmse"] += nmse
                metrics["psnr"] += psnr
                metrics["ssim"] += ssim
                # Store individual metric values for std calculation
                metric_values["nmse"].append(nmse)
                metric_values["psnr"].append(psnr)
                metric_values["ssim"].append(ssim)
            # reduce across ddp via sum
            metrics["nmse"] = self.NMSE(metrics["nmse"])
            metrics["ssim"] = self.SSIM(metrics["ssim"])
            metrics["psnr"] = self.PSNR(metrics["psnr"])
            tot_examples = self.TotExamples(torch.tensor(local_examples))
            val_loss = self.ValLoss(torch.sum(torch.cat(losses)))  # type: ignore
            tot_slice_examples = self.TotSliceExamples(
                torch.tensor(len(losses), dtype=torch.float)
            )
            # Note: torch.std is NaN when a slug contains a single volume.
            metrics_to_plot["nmse"].append(
                (
                    (metrics["nmse"] / tot_examples).item(),
                    torch.std(torch.stack(metric_values["nmse"])).item(),
                )
            )
            metrics_to_plot["psnr"].append(
                (
                    (metrics["psnr"] / tot_examples).item(),
                    torch.std(torch.stack(metric_values["psnr"])).item(),
                )
            )
            metrics_to_plot["ssim"].append(
                (
                    (metrics["ssim"] / tot_examples).item(),
                    torch.std(torch.stack(metric_values["ssim"])).item(),
                )
            )
            slugs.append(slug)
            # Log the mean values
            self.log(
                f"{slug}--validation_loss",
                val_loss / tot_slice_examples,
                prog_bar=True,
            )
            for metric, value in metrics.items():
                self.log(f"{slug}--val_metrics_{metric}", value / tot_examples)
            # Calculate and log the standard deviation for each metric
            for metric, values in metric_values.items():
                std_value = torch.std(torch.stack(values))
                self.log(f"{slug}--val_metrics_{metric}_std", std_value)
        # generate one bar chart per metric, with per-dataset error bars
        for metric_name, values in metrics_to_plot.items():
            scores = [val[0] for val in values]
            std_devs = [val[1] for val in values]
            plt.figure(figsize=(10, 6))
            plt.bar(slugs, scores, yerr=std_devs, capsize=5)
            plt.xlabel("Dataset Slug")
            plt.ylabel(f"{metric_name.upper()} Score")
            plt.title(
                f"{metric_name.upper()} per Dataset with Standard Deviation"
            )
            plt.xticks(rotation=45)
            plt.tight_layout()
            # Render the figure to an array and hand it to the image logger
            buf = BytesIO()
            plt.savefig(buf, format="png")
            buf.seek(0)
            image = Image.open(buf)
            image_array = np.array(image)
            self.log_image(f"summary_plot_{metric_name}", image_array)
            buf.close()
            plt.close()
        # Reset the per-batch cache so results do not leak into the next epoch.
        self.val_batch_results.clear()
    # Legacy epoch-end hook that aggregates over all datasets at once;
    # superseded by the per-slug version above and kept for reference.
    def OLD_on_validation_epoch_end(self):
        val_logs = self.val_batch_results
        # aggregate losses
        losses = []
        mse_vals = defaultdict(dict)
        target_norms = defaultdict(dict)
        ssim_vals = defaultdict(dict)
        max_vals = dict()
        # use dict updates to handle duplicate slices
        for val_log in val_logs:
            losses.append(val_log["val_loss"].view(-1))
            for k in val_log["mse_vals"].keys():
                mse_vals[k].update(val_log["mse_vals"][k])
            for k in val_log["target_norms"].keys():
                target_norms[k].update(val_log["target_norms"][k])
            for k in val_log["ssim_vals"].keys():
                ssim_vals[k].update(val_log["ssim_vals"][k])
            for k in val_log["max_vals"]:
                max_vals[k] = val_log["max_vals"][k]
        # check to make sure we have all files in all metrics
        assert (
            mse_vals.keys()
            == target_norms.keys()
            == ssim_vals.keys()
            == max_vals.keys()
        )
        # apply means across image volumes
        metrics = {"nmse": 0, "ssim": 0, "psnr": 0}
        local_examples = 0
        for fname in mse_vals.keys():
            local_examples = local_examples + 1
            mse_val = torch.mean(
                torch.cat([v.view(-1) for _, v in mse_vals[fname].items()])
            )
            target_norm = torch.mean(
                torch.cat([v.view(-1) for _, v in target_norms[fname].items()])
            )
            metrics["nmse"] = metrics["nmse"] + mse_val / target_norm
            metrics["psnr"] = (
                metrics["psnr"]
                + 20
                * torch.log10(
                    torch.tensor(
                        max_vals[fname],
                        dtype=mse_val.dtype,
                        device=mse_val.device,
                    )
                )
                - 10 * torch.log10(mse_val)
            )
            metrics["ssim"] = metrics["ssim"] + torch.mean(
                torch.cat([v.view(-1) for _, v in ssim_vals[fname].items()])
            )
        # reduce across ddp via sum
        metrics["nmse"] = self.NMSE(metrics["nmse"])
        metrics["ssim"] = self.SSIM(metrics["ssim"])
        metrics["psnr"] = self.PSNR(metrics["psnr"])
        tot_examples = self.TotExamples(torch.tensor(local_examples))
        val_loss = self.ValLoss(torch.sum(torch.cat(losses)))
        tot_slice_examples = self.TotSliceExamples(
            torch.tensor(len(losses), dtype=torch.float)
        )
        self.log(
            "validation_loss", val_loss / tot_slice_examples, prog_bar=True
        )
        for metric, value in metrics.items():
            self.log(f"val_metrics_{metric}", value / tot_examples)
    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """
        Define parameters that only apply to this model.
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # logging params
        parser.add_argument(
            "--num_log_images",
            default=16,
            type=int,
            help="Number of images to log",
        )
        return parser
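
# Illustrative only: a minimal MriModule subclass sketch, kept commented out.
# It assumes batches with `image`/`target`/`fname`/`slice_num`/`max_value`
# fields and returns the validation keys consumed above; the model, loss, and
# field names are placeholders, not part of this codebase.
#
# class ExampleReconModule(MriModule):
#     def __init__(self, lr: float = 1e-3, **kwargs):
#         super().__init__(**kwargs)
#         self.lr = lr
#         self.model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)
#
#     def training_step(self, batch, batch_idx):
#         output = self.model(batch.image)
#         loss = torch.nn.functional.l1_loss(output, batch.target)
#         self.log("train_loss", loss)
#         return loss
#
#     def validation_step(self, batch, batch_idx):
#         output = self.model(batch.image)
#         loss = torch.nn.functional.l1_loss(output, batch.target)
#         return {
#             "slug": "example_dataset",
#             "fname": batch.fname,
#             "slice_num": batch.slice_num,
#             "max_value": batch.max_value,
#             "output": output.squeeze(1),
#             "target": batch.target.squeeze(1),
#             "val_loss": loss,
#         }
#
#     def test_step(self, batch, batch_idx):
#         return self.model(batch.image)
#
#     def configure_optimizers(self):
#         return torch.optim.Adam(self.parameters(), lr=self.lr)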