| """Contains utility math functions. |
| |
| For licensing see accompanying LICENSE file. |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any, Callable, Literal, NamedTuple, Tuple, Union |
|
|
| import torch |
| from torch import autograd |
|
|
# Names of the activation functions supported by create_activation_pair.
ActivationType = Literal[
    "linear",
    "exp",
    "sigmoid",
    "softplus",
    "relu_with_pushback",
    "hard_sigmoid_with_pushback",
]
# An elementwise tensor-to-tensor map (activation or inverse activation).
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
|
|
|
|
class ActivationPair(NamedTuple):
    """A pair of forward and inverse activation functions."""

    # Forward activation, mapping pre-activations to activations (e.g. torch.exp).
    forward: ActivationFunction
    # Inverse (or pseudo-inverse, for the pushback variants) of `forward`,
    # mapping activations back to pre-activations (e.g. torch.log).
    inverse: ActivationFunction
|
|
|
|
def create_activation_pair(activation_type: ActivationType) -> ActivationPair:
    """Create activation function and corresponding inverse function.

    Args:
        activation_type: The activation type to create.

    Returns:
        The corresponding activation function and the corresponding inverse
        function.

    Raises:
        ValueError: If ``activation_type`` is not one of the supported names.
    """
    identity = lambda x: x  # noqa: E731 - shared identity map

    if activation_type == "linear":
        return ActivationPair(identity, identity)
    if activation_type == "exp":
        return ActivationPair(torch.exp, torch.log)
    if activation_type == "sigmoid":
        return ActivationPair(torch.sigmoid, inverse_sigmoid)
    if activation_type == "softplus":
        return ActivationPair(torch.nn.functional.softplus, inverse_softplus)
    if activation_type == "relu_with_pushback":
        # The pushback relu acts as the identity on its (non-negative) range,
        # so the identity serves as its pseudo-inverse.
        return ActivationPair(relu_with_pushback, identity)
    if activation_type == "hard_sigmoid_with_pushback":
        # Inverse of the unclamped hard sigmoid x / 6 + 0.5.
        return ActivationPair(hard_sigmoid_with_pushback, lambda x: 6.0 * x - 3.0)

    raise ValueError(f"Unsupported activation function: {activation_type}.")
|
|
|
|
def inverse_sigmoid(tensor: torch.Tensor) -> torch.Tensor:
    """Compute the logit function, the inverse of the sigmoid.

    Args:
        tensor: Values expected in the open interval (0, 1).

    Returns:
        ``log(tensor / (1 - tensor))`` computed elementwise.
    """
    odds = tensor / (1.0 - tensor)
    return torch.log(odds)
|
|
|
|
def inverse_softplus(tensor: torch.Tensor, eps: float = 1e-06) -> torch.Tensor:
    """Compute inverse softplus, i.e. ``y -> log(exp(y) - 1)``.

    Args:
        tensor: Softplus outputs; values are clamped to ``eps`` from below,
            since softplus is strictly positive.
        eps: Lower clamp bound guarding against taking log of a
            non-positive argument.

    Returns:
        Tensor ``x`` such that ``softplus(x)`` approximately equals ``tensor``.
    """
    tensor = tensor.clamp_min(eps)
    # log(exp(y) - 1) == y + log(1 - exp(-y)). Using expm1 avoids the
    # catastrophic cancellation that computing 1 - exp(-y) directly (as the
    # previous sigmoid-based formulation did) suffers for small y.
    return tensor + torch.log(-torch.expm1(-tensor))
|
|
|
|
| |
| |
# (threshold, target) pair for softclamp: clamping starts at the threshold
# and saturates toward the target value.
SoftClampRange = Tuple[Union[torch.Tensor, float], Union[torch.Tensor, float]]
|
|
|
|
| def softclamp( |
| tensor: torch.Tensor, |
| min: SoftClampRange | None = None, |
| max: SoftClampRange | None = None, |
| ) -> torch.Tensor: |
| """Clamp tensor to min/max in differentiable way. |
| |
| Args: |
| tensor: The tensor to clamp. |
| min: Pair of threshold to start clamping and value to clamp to. |
| The first value should be larger than the second. |
| max: Pair of threshold to start clamping and value to clamp to. |
| The first value should be smaller than the second. |
| |
| Returns: |
| The clamped tensor. |
| """ |
|
|
| def normalize(clamp_range: SoftClampRange) -> torch.Tensor: |
| value0, value1 = clamp_range |
| return value0 + (value1 - value0) * torch.tanh((tensor - value0) / (value1 - value0)) |
|
|
| tensor_clamped = tensor |
| if min is not None: |
| tensor_clamped = torch.maximum(tensor_clamped, normalize(min)) |
| if max is not None: |
| tensor_clamped = torch.minimum(tensor_clamped, normalize(max)) |
|
|
| return tensor_clamped |
|
|
|
|
| class ClampWithPushback(autograd.Function): |
| """Implementation of clamp_with_pushback function.""" |
|
|
| @staticmethod |
| def forward( |
| ctx: Any, |
| tensor: torch.Tensor, |
| min: float | None, |
| max: float | None, |
| pushback: float, |
| ) -> torch.Tensor: |
| """Apply clamp.""" |
| if min is not None and max is not None and min >= max: |
| raise ValueError("Only min < max is supported.") |
|
|
| ctx.save_for_backward(tensor) |
| ctx.min = min |
| ctx.max = max |
| ctx.pushback = pushback |
| return torch.clamp(tensor, min=min, max=max) |
|
|
| @staticmethod |
| def backward( |
| ctx: Any, grad_in: torch.Tensor |
| ) -> tuple[torch.Tensor, None, None, None]: |
| """Compute gradient of clamp with pushback.""" |
| grad_out = grad_in.clone() |
| (tensor,) = ctx.saved_tensors |
|
|
| if ctx.min is not None: |
| mask_min = tensor < ctx.min |
| grad_out[mask_min] = -ctx.pushback |
|
|
| if ctx.max is not None: |
| mask_max = tensor > ctx.max |
| grad_out[mask_max] = ctx.pushback |
|
|
| return grad_out, None, None, None |
|
|
|
|
def clamp_with_pushback(
    tensor: torch.Tensor,
    min: float | None = None,
    max: float | None = None,
    pushback: float = 1e-2,
) -> torch.Tensor:
    """Variant of the clamp function which avoids the vanishing gradient problem.

    This function is equivalent to adding a regularizer of the form

    pushback * sum_i (
        relu(min - preactivation_i) + relu(preactivation_i - max)
    )

    to the full loss function, which pushes clamped values back.

    When used in minimization problems, pushback should be greater than
    zero. In maximization problems, pushback should be smaller than zero.

    Args:
        tensor: The tensor to clamp.
        min: Lower clamp bound, or None for no lower bound.
        max: Upper clamp bound, or None for no upper bound.
        pushback: Magnitude of the gradient emitted for clamped entries.

    Returns:
        ``tensor`` clamped to ``[min, max]``.
    """
    result = ClampWithPushback.apply(tensor, min, max, pushback)
    # autograd.Function.apply is untyped; narrow the result for type checkers.
    assert isinstance(result, torch.Tensor)
    return result
|
|
|
|
def hard_sigmoid_with_pushback(x: torch.Tensor, slope: float = 1.0 / 6.0) -> torch.Tensor:
    """Apply hard sigmoid with pushback.

    For compatibility reasons, we follow the default PyTorch implementation with a
    default slope of 1/6:

    https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html

    Args:
        x: Pre-activation tensor.
        slope: Slope of the linear segment before clamping.

    Returns:
        ``slope * x + 0.5`` clamped to [0, 1] with pushback gradients.
    """
    pre_clamp = slope * x + 0.5
    return clamp_with_pushback(pre_clamp, min=0.0, max=1.0)
|
|
|
|
def relu_with_pushback(x: torch.Tensor) -> torch.Tensor:
    """Apply a ReLU whose backward pass pushes clamped (negative) values back.

    Equivalent to clamping at zero from below while emitting a constant
    pushback gradient for negative entries (see clamp_with_pushback).
    """
    return clamp_with_pushback(x, min=0.0)
|
|