| """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" |
|
|
import torch
|
|
|
|
|
|
def fma(a, b, c): # => a * b + c
    return _FusedMultiplyAdd.apply(a, b, c)
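
# A minimal usage sketch (illustrative; `_demo_fma` is an assumed helper, not
# part of the original module): `fma()` matches `a * b + c` under broadcasting,
# just like `torch.addcmul()`.
def _demo_fma():
    a = torch.randn(4, 3)
    b = torch.randn(4, 3)
    c = torch.randn(3)  # broadcast against a * b
    out = fma(a, b, c)
    assert torch.allclose(out, a * b + c)
    assert torch.allclose(out, torch.addcmul(c, a, b))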
|
|
|
|
|
|
|
|
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c):
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape # only the shape of c is needed for backward
        return out
|
|
    @staticmethod
    def backward(ctx, dout):
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        # d(a*b+c)/da = b, d/db = a, d/dc = 1; reduce each gradient back to
        # its input's shape in case broadcasting expanded it in the forward.
        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc
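
# A hedged sanity check (added for illustration; `_check_fma_gradients` is an
# assumed helper, not part of the original file): torch.autograd.gradcheck()
# compares the custom backward above against numerical gradients. It requires
# float64 inputs with requires_grad=True; `c` is broadcast here to exercise
# _unbroadcast() on the backward path.
def _check_fma_gradients():
    a = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
    b = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
    c = torch.randn(3, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(fma, (a, b, c))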
|
|
|
|
|
|
|
|
def _unbroadcast(x, shape):
    # Reduce `x` to `shape` by summing over the dimensions that broadcasting
    # expanded: the extra leading dims, plus any dim that is size 1 in `shape`.
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [
        i for i in range(x.ndim)
        if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)
    ]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        x = x.reshape(-1, *x.shape[extra_dims + 1:])
    assert x.shape == shape
    return x
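
# Illustrative example (an assumed helper, not in the original): reducing a
# gradient of shape (2, 3) back to a broadcast operand of shape (3,) sums
# over the leading broadcast dimension.
def _demo_unbroadcast():
    g = torch.ones(2, 3)
    reduced = _unbroadcast(g, (3,))
    assert reduced.shape == (3,)
    assert torch.equal(reduced, torch.full((3,), 2.0))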
|
|
|
|
|
|