Spaces:
Running
on
Zero
Running
on
Zero
File size: 490 Bytes
dd9600d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from torch.autograd import Function
class RevGrad(Function):
    """Gradient reversal: identity in the forward pass, ``-alpha * grad`` backward.

    Use through ``RevGrad.apply(input_, alpha_)``. Only ``input_`` receives a
    gradient; ``alpha_`` is treated as a constant (its gradient is ``None``).
    """

    @staticmethod
    def forward(ctx, input_, alpha_):
        """Pass ``input_`` through unchanged and remember ``alpha_`` for backward.

        Args:
            ctx: Autograd context supplied by ``Function.apply``.
            input_: Tensor returned unchanged.
            alpha_: Scale applied to the reversed gradient. Either a tensor
                (must broadcast against the incoming gradient) or a plain
                Python number.

        Returns:
            A view of ``input_`` with identical values.
        """
        ctx.alpha_is_tensor = isinstance(alpha_, torch.Tensor)
        if ctx.alpha_is_tensor:
            # Tensors go through save_for_backward so autograd can track them
            # (e.g. detect in-place modification before backward runs).
            ctx.save_for_backward(alpha_)
        else:
            # save_for_backward only accepts tensors; plain numbers live on ctx.
            ctx.alpha = alpha_
        # input_ is deliberately NOT saved: backward never reads it, and saving it
        # would keep the activation alive until the backward pass.
        return input_.view_as(input_)

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        """Negate and scale the incoming gradient.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            ``(grad_input, None)``. ``grad_input`` is ``-grad_output * alpha``
            when the first input needs a gradient, otherwise ``None``. ``alpha``
            never receives a gradient.
        """
        grad_input = None
        if ctx.needs_input_grad[0]:
            alpha_ = ctx.saved_tensors[0] if ctx.alpha_is_tensor else ctx.alpha
            grad_input = -grad_output * alpha_
        return grad_input, None
# Functional alias: ``revgrad(x, alpha)`` is the identity in the forward pass and
# multiplies the incoming gradient by ``-alpha`` in the backward pass.
revgrad = RevGrad.apply