Spaces:
Running
on
Zero
Running
on
Zero
File size: 497 Bytes
dd9600d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
from revgrad_func import revgrad
from torch.nn import Module
from torch import tensor
class RevGrad(Module):
    """A gradient reversal layer.

    The forward pass delegates to ``revgrad(input_, alpha)`` from
    ``revgrad_func``; the gradient reversal itself is implemented there.
    The layer holds no trainable parameters, only the fixed ``alpha`` value.
    """

    def __init__(self, alpha=1., *args, **kwargs):
        """
        A gradient reversal layer.
        This layer has no parameters, and simply reverses the gradient
        in the backward pass.

        Args:
            alpha: Value handed to ``revgrad`` alongside the input. It is
                converted with ``torch.tensor``. Presumably it scales the
                reversed gradient; confirm in ``revgrad_func``.
            *args, **kwargs: Forwarded unchanged to ``Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        # A buffer (not a plain attribute) follows the module through
        # .to(device) / .cuda() / .double() like every other tensor it owns.
        # persistent=False keeps it out of state_dict(), so saved checkpoints
        # look exactly as they did when alpha was a plain attribute.
        self.register_buffer(
            "_alpha", tensor(alpha, requires_grad=False), persistent=False
        )

    def forward(self, input_):
        """Return ``revgrad(input_, alpha)`` using this layer's alpha tensor."""
        return revgrad(input_, self._alpha)

    def extra_repr(self):
        """Expose alpha in ``print(module)`` / ``repr(module)`` output."""
        # tolist() handles 0-dim and non-scalar alpha alike.
        return f"alpha={self._alpha.tolist()}"