OpenSound's picture
Upload 518 files
dd9600d verified
raw
history blame contribute delete
490 Bytes
import torch
from torch.autograd import Function
class RevGrad(Function):
    """Gradient reversal: identity in the forward pass, ``-alpha_ * grad`` in backward.

    Use through ``RevGrad.apply(input_, alpha_)``.
    """

    @staticmethod
    def forward(ctx, input_, alpha_):
        """Return ``input_`` unchanged and remember ``alpha_`` for the backward pass.

        Args:
            ctx: Autograd context used to carry state into ``backward``.
            input_: Tensor passed through as the output.
            alpha_: Scale applied to the negated gradient in ``backward``.
                May be a tensor or a plain Python number.

        Returns:
            ``input_`` itself.
        """
        # Only alpha_ is read in backward, so input_ is deliberately NOT saved:
        # saving it would keep the activation alive in the graph for nothing.
        if torch.is_tensor(alpha_):
            ctx.save_for_backward(alpha_)
            ctx.alpha_scalar = None
        else:
            # save_for_backward() only accepts tensors; non-tensor scalars are
            # stashed directly on the context instead.
            ctx.alpha_scalar = alpha_
        return input_

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        """Return the reversed, scaled gradient for ``input_`` and ``None`` for ``alpha_``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient flowing back into the output of ``forward``.

        Returns:
            A pair ``(grad_input, None)``. ``grad_input`` is
            ``-grad_output * alpha_`` when ``input_`` requires a gradient and
            ``None`` otherwise; ``alpha_`` never receives a gradient.
        """
        if not ctx.needs_input_grad[0]:
            return None, None
        if ctx.alpha_scalar is None:
            (alpha,) = ctx.saved_tensors
        else:
            alpha = ctx.alpha_scalar
        return -grad_output * alpha, None
# Functional entry point: ``revgrad(input_, alpha_)`` is shorthand for
# ``RevGrad.apply(input_, alpha_)``.
revgrad = RevGrad.apply