from torch.autograd import Function


class RevGrad(Function):
    @staticmethod
    def forward(ctx, input_, alpha_):
        # Identity in the forward pass; stash alpha_ for use in backward.
        ctx.save_for_backward(input_, alpha_)
        output = input_
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        grad_input = None
        _, alpha_ = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            # Reverse the gradient and scale it by alpha_.
            grad_input = -grad_output * alpha_
        # No gradient is returned for alpha_.
        return grad_input, None


revgrad = RevGrad.apply
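

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of how revgrad might be used: wrapping it in an
# nn.Module so the gradient-reversal step can sit inside a model, e.g. in front
# of a domain classifier. The module name GradientReversal and the alpha values
# below are illustrative assumptions, not part of the source above.
import torch
from torch import nn


class GradientReversal(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        # alpha_ must be a tensor so it can be passed through save_for_backward.
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self.alpha)


if __name__ == "__main__":
    # Forward pass is the identity; backward multiplies gradients by -alpha.
    layer = GradientReversal(alpha=0.5)
    x = torch.randn(4, 3, requires_grad=True)
    y = layer(x)
    y.sum().backward()
    print(x.grad)  # every entry is -0.5: the unit gradient, reversed and scaled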