| """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
|
|
| import torch
|
|
|
|
|
|
|
| def fma(a, b, c):
|
| return _FusedMultiplyAdd.apply(a, b, c)
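
# Usage sketch (illustrative shapes, not from the original module):
#   fma(torch.randn(4, 3), torch.randn(4, 3), torch.randn(1, 3))
# broadcasts exactly like the expression a * b + c and returns a (4, 3) tensor.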


class _FusedMultiplyAdd(torch.autograd.Function):
    """Computes a * b + c with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, a, b, c):
        out = torch.addcmul(c, a, b)  # out = c + a * b
        ctx.save_for_backward(a, b)   # c itself is not needed for backward;
        ctx.c_shape = c.shape         # only its shape, to un-broadcast dc
        return out

    @staticmethod
    def backward(ctx, dout):
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        # d(a*b + c)/da = b, reduced back onto a's shape if a was broadcast.
        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        # d(a*b + c)/db = a, reduced back onto b's shape if b was broadcast.
        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        # d(a*b + c)/dc = 1, so dc is dout reduced back onto c's shape.
        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc


def _unbroadcast(x, shape):
    """Sum `x` over its broadcast dimensions so the result has shape `shape`."""
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Broadcast dims are those prepended in front of `shape` (i < extra_dims)
    # or those expanded from size 1 in `shape`.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # Drop the prepended size-1 dims by folding them into the first kept dim.
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x
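
# For instance (illustrative, not from the original file):
#   _unbroadcast(torch.ones(2, 4, 3), (4, 1))  ->  a (4, 1) tensor of 6.0,
# since the prepended dim 0 and the expanded last dim are both summed out.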
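

# A minimal self-check sketch (an assumed addition, not part of the original
# module): it verifies that fma() matches the eager expression a * b + c in
# value and in all three gradients, including a broadcast `c`, then runs
# gradcheck on the custom backward. Shapes below are illustrative assumptions.
if __name__ == "__main__":
    a = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    b = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
    c = torch.randn(1, 3, dtype=torch.double, requires_grad=True)  # broadcast over rows

    out = fma(a, b, c)
    ref = a * b + c
    assert torch.allclose(out, ref)

    # Gradients of the custom op should agree with autograd on the reference.
    g = torch.randn_like(out)
    da, db, dc = torch.autograd.grad(out, (a, b, c), g)
    ra, rb, rc = torch.autograd.grad(ref, (a, b, c), g)
    assert torch.allclose(da, ra) and torch.allclose(db, rb) and torch.allclose(dc, rc)

    # gradcheck exercises forward/backward numerically (needs double precision).
    torch.autograd.gradcheck(fma, (a, b, c))
    print("fma self-check passed")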