File size: 487 Bytes
de15dc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch

class BinarizeSTE(torch.autograd.Function):
    """Sign binarization with a straight-through estimator (STE) gradient.

    Forward maps each element to +1 if it is strictly positive and to -1
    otherwise (so exact zeros become -1). Backward passes the incoming
    gradient through unchanged where ``|input| <= 1`` and zeroes it where
    ``|input| > 1``, i.e. it uses the derivative of hardtanh as a surrogate
    for the (almost-everywhere zero) derivative of the sign function.
    """

    @staticmethod
    def forward(ctx, input):
        """Return a tensor of +1/-1 values with the same shape as ``input``.

        Floating-point inputs keep their dtype, so float16/bfloat16/float64
        tensors are not silently cast to float32. Non-floating inputs
        produce float32, matching the previous behavior.
        """
        ctx.save_for_backward(input)
        out_dtype = input.dtype if input.is_floating_point() else torch.float32
        return (input > 0).to(out_dtype) * 2 - 1

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` straight through, zeroed where ``|input| > 1``."""
        input, = ctx.saved_tensors
        # hardtanh window: gradient flows only for inputs in [-1, 1].
        # masked_fill is out-of-place, so grad_output itself is never mutated.
        return grad_output.masked_fill(input.abs() > 1, 0)

def binarize_ste(x):
    """Binarize ``x`` to +1/-1 with a straight-through gradient.

    Functional alias for ``BinarizeSTE.apply`` so the op can be called like
    an ordinary tensor function.
    """
    binarized = BinarizeSTE.apply(x)
    return binarized