| import torch | |
class BinarizeSTE(torch.autograd.Function):
    """Binarize a tensor to +1/-1 with a straight-through gradient estimator.

    Forward maps each element to +1 where it is strictly positive and to -1
    otherwise (so an exact 0 maps to -1). Backward passes the incoming
    gradient through unchanged where ``|input| <= 1`` and zeroes it where
    ``|input| > 1``, i.e. it substitutes the derivative of ``hardtanh`` for
    the (zero almost everywhere) derivative of the step function.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``+1`` where ``input > 0`` and ``-1`` elsewhere.

        The raw input is saved on ``ctx`` so ``backward`` can build the
        gradient-clipping mask.
        """
        ctx.save_for_backward(input)
        # NOTE: ``.float()`` always yields float32, regardless of input dtype.
        return (input > 0).float() * 2 - 1

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` with entries zeroed where ``|input| > 1``."""
        (input,) = ctx.saved_tensors
        # hardtanh derivative: 1 inside [-1, 1], 0 outside. masked_fill is
        # out-of-place, so grad_output itself is never modified.
        return grad_output.masked_fill(input.abs() > 1, 0)
def binarize_ste(x):
    """Binarize ``x`` to +1/-1 using the straight-through ``BinarizeSTE`` op.

    Elements greater than 0 become +1 and all others become -1; gradients
    flow back only where ``|x| <= 1``.
    """
    binarized = BinarizeSTE.apply(x)
    return binarized