import torch | |
from torch import nn | |
from .helper_funcs import exists | |
class Noise(nn.Module):
    """Add learnably-scaled Gaussian noise to a 4-D input tensor.

    A single learnable scalar ``weight`` multiplies the noise before it is
    added to the input. ``weight`` is initialised to zero, so at
    initialisation the noise term contributes nothing.
    """

    def __init__(self):
        super().__init__()
        # One scalar shared by every element; zero init means the module adds
        # no noise until training moves the weight away from zero.
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        """Return ``x + weight * noise``.

        Args:
            x: 4-D tensor, unpacked as ``(b, _, h, w)`` — presumably
                (batch, channels, height, width); confirm against callers.
            noise: Optional tensor that must broadcast against ``x``. When
                not supplied, standard-normal noise of shape ``(b, 1, h, w)``
                is drawn on ``x``'s device, so a single noise map is shared
                (broadcast) across dimension 1. If ``x`` is floating point the
                noise is drawn in ``x``'s dtype.

        Returns:
            Tensor equal to ``x + self.weight * noise`` (broadcast shape).

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(
                f"Noise expects a 4-D input (b, c, h, w), got shape {tuple(x.shape)}"
            )
        b, _, h, w = x.shape
        if not exists(noise):
            # Draw in x's floating dtype: a default (float32) draw would
            # type-promote half/bfloat16 activations to float32. Non-floating
            # inputs keep the default dtype, since randn needs a float type.
            dtype = x.dtype if x.is_floating_point() else None
            noise = torch.randn(b, 1, h, w, device=x.device, dtype=dtype)
        return x + self.weight * noise