nomri / fastmri /losses.py
samaonline
init
1b34a12
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class SSIMLoss(nn.Module):
"""
SSIM loss module.
"""
def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03):
"""
Initialize the Losses class.
Parameters
----------
win_size : int, optional
Window size for SSIM calculation.
k1 : float, optional
k1 parameter for SSIM calculation.
k2 : float, optional
k2 parameter for SSIM calculation.
"""
super().__init__()
self.win_size = win_size
self.k1, self.k2 = k1, k2
self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2)
NP = win_size**2
self.cov_norm = NP / (NP - 1)
def forward(
self,
X: torch.Tensor,
Y: torch.Tensor,
data_range: torch.Tensor,
reduced: bool = True,
):
assert isinstance(self.w, torch.Tensor)
data_range = data_range[:, None, None, None].to(X.device)
C1 = (self.k1 * data_range) ** 2
C2 = (self.k2 * data_range) ** 2
# Compute means
ux = F.conv2d(X, self.w)
uy = F.conv2d(Y, self.w)
# Compute variances
uxx = F.conv2d(X * X, self.w)
uyy = F.conv2d(Y * Y, self.w)
uxy = F.conv2d(X * Y, self.w)
# Compute covariances
vx = self.cov_norm * (uxx - ux * ux)
vy = self.cov_norm * (uyy - uy * uy)
vxy = self.cov_norm * (uxy - ux * uy)
# Compute SSIM components
A1, A2 = 2 * ux * uy + C1, 2 * vxy + C2
B1, B2 = ux**2 + uy**2 + C1, vx + vy + C2
D = B1 * B2
S = (A1 * A2) / D
if reduced:
return 1 - S.mean()
else:
return 1 - S
if __name__ == "__main__":
# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create the SSIMLoss module and move it to the GPU
ssim_loss = SSIMLoss().to(device)
# Create example tensors and move them to the GPU
X = torch.randn(4, 1, 256, 256).to(device)
Y = torch.randn(4, 1, 256, 256).to(device)
data_range = torch.rand(4).to(device)
# Compute the loss
loss = ssim_loss(X, Y, data_range)
print(loss)