import torch


class BinarizeSTE(torch.autograd.Function):
    """Sign binarization with a straight-through (hardtanh) gradient estimator.

    Forward maps each element to +1 if it is strictly positive and to -1
    otherwise (so exactly 0 maps to -1). Backward passes the incoming
    gradient through unchanged where |x| <= 1 and zeroes it where |x| > 1,
    i.e. it uses the derivative of hardtanh as a surrogate for the
    (almost-everywhere zero) derivative of the sign function.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a tensor of -1/+1 values with the same shape as ``x``.

        The result keeps ``x``'s floating dtype (e.g. float16, bfloat16,
        float64). Non-floating inputs yield float32, matching the previous
        ``.float()`` behavior.
        """
        ctx.save_for_backward(x)
        # Previously `.float()` forced float32 for every input, silently
        # upcasting half-precision activations and downcasting float64.
        out_dtype = x.dtype if x.is_floating_point() else torch.float32
        return (x > 0).to(out_dtype) * 2 - 1

    @staticmethod
    def backward(ctx, grad_output):
        """Zero the gradient where |x| > 1; pass it through elsewhere."""
        (x,) = ctx.saved_tensors
        # hardtanh STE: gradient flows only where x lies in [-1, 1].
        # masked_fill is out-of-place, so grad_output is not mutated.
        return grad_output.masked_fill(x.abs() > 1, 0)


def binarize_ste(x):
    """Binarize ``x`` to {-1, +1} with a straight-through (hardtanh) gradient."""
    return BinarizeSTE.apply(x)