File size: 487 Bytes
de15dc5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | import torch
class BinarizeSTE(torch.autograd.Function):
    """Sign-style binarization with a straight-through gradient estimator.

    Forward: elements ``> 0`` map to ``+1`` and all other elements (including
    exactly ``0``) map to ``-1``. The result keeps the input's floating-point
    dtype; non-floating inputs produce float32, matching the previous behavior
    for those inputs.

    Backward: the upstream gradient is passed through unchanged where
    ``|input| <= 1`` and zeroed elsewhere (the derivative of hardtanh).
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Preserve the caller's floating dtype (fp16/bf16/fp64) instead of
        # forcing float32, so downstream layers see a matching dtype.
        out_dtype = input.dtype if input.is_floating_point() else torch.float32
        return (input > 0).to(out_dtype) * 2 - 1

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # hardtanh STE: gradient flows only where |x| <= 1 (inclusive bound).
        # masked_fill is out-of-place, so grad_output itself is not modified.
        return grad_output.masked_fill(x.abs() > 1, 0)
def binarize_ste(x):
    """Binarize ``x`` to +/-1 using the straight-through estimator.

    Functional shorthand for ``BinarizeSTE.apply(x)``; the result and its
    gradient behavior are exactly those of ``BinarizeSTE``.
    """
    binarize = BinarizeSTE.apply
    return binarize(x)
|