# test/modules/binarize_ste.py
# Initial upload by jaewooo (commit de15dc5).
import torch
class BinarizeSTE(torch.autograd.Function):
    """Sign binarization with a straight-through estimator (STE) gradient.

    Forward maps each element to +1 where it is strictly positive and to -1
    everywhere else (so an exact 0 maps to -1, unlike ``torch.sign``).

    Backward uses the derivative of ``hardtanh`` as a surrogate. The incoming
    gradient passes through unchanged where ``|input| <= 1`` and is zeroed
    where ``|input| > 1``.
    """

    @staticmethod
    def forward(ctx, input):
        """Binarize ``input`` to {-1, +1}, saving it for the backward pass.

        Floating-point inputs keep their dtype (float16/bfloat16/float64 are
        no longer forced to float32). Non-floating inputs yield float32, as
        before.
        """
        ctx.save_for_backward(input)
        out_dtype = input.dtype if input.is_floating_point() else torch.float32
        return (input > 0).to(out_dtype) * 2 - 1

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` with entries where ``|input| > 1`` set to 0."""
        (saved_input,) = ctx.saved_tensors
        # hardtanh surrogate: only pass gradient for inputs inside [-1, 1].
        # Out-of-place masked_fill leaves grad_output untouched (no clone needed).
        return grad_output.masked_fill(saved_input.abs() > 1, 0)
def binarize_ste(x):
    """Binarize ``x`` to {-1, +1} using a straight-through gradient.

    Functional shorthand for ``BinarizeSTE.apply(x)``.
    """
    binarized = BinarizeSTE.apply(x)
    return binarized