Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copyright (c) Facebook, Inc. and its affiliates. | |
This source code is licensed under the MIT license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class SSIMLoss(nn.Module):
    """
    SSIM loss module.

    Returns ``1 - SSIM(X, Y)``, where local statistics are computed with a
    uniform ``win_size x win_size`` averaging window applied without padding.
    Each input channel is filtered independently.
    """

    def __init__(self, win_size: int = 7, k1: float = 0.01, k2: float = 0.03):
        """
        Initialize the SSIMLoss module.

        Parameters
        ----------
        win_size : int, optional
            Side length of the square averaging window. Must be at least 2,
            because the covariance correction divides by ``win_size**2 - 1``.
        k1 : float, optional
            Luminance stabilizer; the constant used is ``(k1 * data_range) ** 2``.
        k2 : float, optional
            Contrast/structure stabilizer; the constant used is
            ``(k2 * data_range) ** 2``.

        Raises
        ------
        ValueError
            If ``win_size`` is smaller than 2.
        """
        super().__init__()
        if win_size < 2:
            # win_size == 1 would make cov_norm divide by zero below.
            raise ValueError(f"win_size must be at least 2, got {win_size}")
        self.win_size = win_size
        self.k1, self.k2 = k1, k2
        # Single-channel box filter (weights sum to 1); forward() replicates it
        # once per input channel.
        self.register_buffer("w", torch.ones(1, 1, win_size, win_size) / win_size**2)
        NP = win_size**2
        # Rescales the windowed (biased) variance/covariance to the unbiased
        # sample estimate over the NP pixels in each window.
        self.cov_norm = NP / (NP - 1)

    def forward(
        self,
        X: torch.Tensor,
        Y: torch.Tensor,
        data_range: torch.Tensor,
        reduced: bool = True,
    ) -> torch.Tensor:
        """
        Compute the SSIM loss between ``X`` and ``Y``.

        Parameters
        ----------
        X, Y : torch.Tensor
            Image batches of shape ``(batch, channels, height, width)``.
            ``height`` and ``width`` must be at least ``win_size``.
        data_range : torch.Tensor
            Dynamic range of the images: either one value per batch element
            (shape ``(batch,)``) or a single value (0-d tensor or Python
            number) shared by the whole batch.
        reduced : bool, optional
            If True, return the mean loss over all windows, channels and
            batch elements as a scalar. Otherwise return the per-window loss
            map of shape
            ``(batch, channels, height - win_size + 1, width - win_size + 1)``.

        Returns
        -------
        torch.Tensor
            ``1 - SSIM``, reduced or per-window as described above.
        """
        assert isinstance(self.w, torch.Tensor)

        # Accept (batch,), 0-d or Python-number ranges and shape them to
        # broadcast against the (batch, channels, H', W') statistic maps.
        data_range = torch.as_tensor(data_range, device=X.device).reshape(-1, 1, 1, 1)
        C1 = (self.k1 * data_range) ** 2
        C2 = (self.k2 * data_range) ** 2

        # Replicate the box filter per channel and use a grouped convolution so
        # channels are averaged independently. For single-channel input this is
        # exactly the registered (1, 1, k, k) filter with groups=1.
        channels = X.shape[1]
        w = self.w.to(device=X.device, dtype=X.dtype).repeat(channels, 1, 1, 1)

        def window_mean(t: torch.Tensor) -> torch.Tensor:
            # Mean of `t` over every valid (unpadded) window position.
            return F.conv2d(t, w, groups=channels)

        # Local means
        ux = window_mean(X)
        uy = window_mean(Y)
        # Local second moments E[X^2], E[Y^2], E[XY]
        uxx = window_mean(X * X)
        uyy = window_mean(Y * Y)
        uxy = window_mean(X * Y)
        # Local (unbiased) variances and covariance
        vx = self.cov_norm * (uxx - ux * ux)
        vy = self.cov_norm * (uyy - uy * uy)
        vxy = self.cov_norm * (uxy - ux * uy)

        # SSIM = (2*ux*uy + C1)(2*vxy + C2) / ((ux^2 + uy^2 + C1)(vx + vy + C2))
        A1, A2 = 2 * ux * uy + C1, 2 * vxy + C2
        B1, B2 = ux**2 + uy**2 + C1, vx + vy + C2
        D = B1 * B2
        S = (A1 * A2) / D

        if reduced:
            return 1 - S.mean()
        return 1 - S
if __name__ == "__main__":
    # Example usage: compare two random image batches with SSIMLoss.
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the loss module on the selected device (GPU if available, else CPU).
    ssim_loss = SSIMLoss().to(device)

    # Example batch of 4 single-channel 256x256 images, allocated on `device`.
    X = torch.randn(4, 1, 256, 256, device=device)
    Y = torch.randn(4, 1, 256, 256, device=device)

    # SSIM's stabilizing constants scale with the images' dynamic range, so
    # derive it from the reference images (one value per batch element)
    # instead of using an unrelated random number.
    data_range = Y.amax(dim=(1, 2, 3)) - Y.amin(dim=(1, 2, 3))

    # Compute the (mean-reduced) loss.
    loss = ssim_loss(X, Y, data_range)
    print(loss)