| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" |
| |
|
| | import torch |
| |
|
| | |
| |
|
| |
|
def fma(a, b, c):
    """Compute ``a * b + c`` (with broadcasting) via ``_FusedMultiplyAdd``.

    The forward value comes from ``torch.addcmul``. The backward pass is
    hand-written: each gradient is reduced back to the shape of its input.
    """
    fused = _FusedMultiplyAdd.apply
    return fused(a, b, c)
| |
|
| |
|
| | |
| |
|
| |
|
class _FusedMultiplyAdd(torch.autograd.Function):
    """Autograd function for ``a * b + c`` with hand-written gradients.

    The forward pass delegates to ``torch.addcmul``. The backward pass
    multiplies the incoming gradient by the other factor and sums each
    gradient back to its input's shape with ``_unbroadcast``.
    """

    @staticmethod
    def forward(ctx, a, b, c):
        """Return ``c + a * b``; stash ``a``, ``b`` and ``c.shape`` for backward."""
        result = torch.addcmul(c, a, b)
        # The product-rule gradients need both factors. ``c`` only matters
        # through its shape, so keep the shape rather than the tensor.
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return result

    @staticmethod
    def backward(ctx, dout):
        """Return gradients for ``(a, b, c)``; ``None`` where none is needed."""
        a, b = ctx.saved_tensors
        # One flag per forward input, in argument order (a, b, c).
        want_a, want_b, want_c = ctx.needs_input_grad
        grad_a = _unbroadcast(dout * b, a.shape) if want_a else None
        grad_b = _unbroadcast(dout * a, b.shape) if want_b else None
        grad_c = _unbroadcast(dout, ctx.c_shape) if want_c else None
        return grad_a, grad_b, grad_c
| |
|
| |
|
| | |
| |
|
| |
|
def _unbroadcast(x, shape):
    """Sum ``x`` down to ``shape``, undoing PyTorch broadcasting.

    Used to reduce a gradient that has the broadcast result's shape back to
    the shape of one operand. Two kinds of dims are summed out: the leading
    dims that ``shape`` lacks, and dims where ``shape`` has size 1 but ``x``
    does not.

    Args:
        x: Tensor whose shape is the broadcast of ``shape`` with other operands.
        shape: Target shape (``torch.Size`` or tuple). It must have no more
            dims than ``x`` and be broadcast-compatible with ``x.shape``.
            A 0-d shape ``()`` is supported.

    Returns:
        A tensor of shape exactly ``shape`` holding the reduced sums.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Every broadcast dim gets summed: all leading extra dims, plus any dim
    # whose target size is 1. Test ``!= 1`` rather than ``> 1`` so a size-0
    # dim that was broadcast from size 1 also reduces (to a size-1 zero).
    dim = [
        i for i in range(x.ndim)
        if x.shape[i] != 1 and (i < extra_dims or shape[i - extra_dims] == 1)
    ]
    if dim:
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # After the keepdim sum, every leading extra dim has size 1, so drop
        # them. Reshaping to the remaining dims (instead of folding them into
        # a -1 dim) also yields a 0-d result when ``shape`` is ``()``.
        x = x.reshape(x.shape[extra_dims:])
    assert x.shape == shape
    return x
| |
|
| |
|
| | |
| |
|