michaelriedl's picture
Initial dump
002ca81
raw
history blame contribute delete
429 Bytes
import torch
from torch import nn
from .helper_funcs import exists
class Noise(nn.Module):
    """Add random noise, scaled by a single learnable scalar, to a tensor.

    The scale ``weight`` is a one-element parameter initialised to zero, so the
    noise term contributes nothing until the weight is trained away from zero.
    """

    def __init__(self):
        super().__init__()
        # Learnable noise strength (shape (1,)), broadcast over the whole input.
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        """Return ``x + self.weight * noise``.

        Args:
            x: Input tensor laid out as ``(batch, channels, *spatial)``; the
                4-D ``(b, c, h, w)`` case is the one this module was written for.
            noise: Optional noise tensor that must broadcast against ``x``.
                When ``exists(noise)`` is false (presumably ``noise is None`` --
                see ``helper_funcs.exists``), standard-normal noise of shape
                ``(batch, 1, *spatial)`` is drawn on ``x``'s device, so all
                channels of a sample share one noise map.

        Returns:
            The broadcast sum ``x + self.weight * noise``.

        Raises:
            ValueError: If noise must be generated and ``x`` has fewer than
                2 dimensions (there is no channel axis to broadcast over).
        """
        if not exists(noise):
            if x.dim() < 2:
                raise ValueError(
                    f"Noise expects input of shape (batch, channels, *spatial), "
                    f"got shape {tuple(x.shape)}"
                )
            batch, spatial = x.shape[0], x.shape[2:]
            # Draw in the default dtype so seeded runs reproduce the same
            # values, then cast to x's dtype so half/bfloat16 inputs are not
            # promoted to float32 by the arithmetic below.
            noise = torch.randn(batch, 1, *spatial, device=x.device).to(x.dtype)
        return x + self.weight * noise