Source code for scimba_torch.neural_nets.coordinates_based_nets.activation

"""Differents activation layers and adaptive activation layers.

All activation functions accept `**kwargs` at initialization so that they
share a common signature.
"""

from typing import Any

import torch


class AdaptativeTanh(torch.nn.Module):
    """Class for a tanh activation function with an adaptive parameter.

    Args:
        **kwargs: Keyword arguments including:

            * `mu` (:code:`float`): mean of the Gaussian used to initialize the
              adaptive parameter. Defaults to 0.0.
            * `sigma` (:code:`float`): standard deviation of the Gaussian used to
              initialize the adaptive parameter. Defaults to 0.1.
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
        mu = kwargs.get("mu", 0.0)
        sigma = kwargs.get("sigma", 0.1)
        self.a = torch.nn.Parameter(
            torch.randn(()) * sigma + mu
        )  #: The adaptive parameter of the tanh.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the adaptive tanh function.
        """
        # tanh(a * x), written explicitly from exponentials.
        exp_p = torch.exp(self.a * x)
        exp_m = 1 / exp_p
        return (exp_p - exp_m) / (exp_p + exp_m)
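# A minimal usage sketch (illustrative, not part of the library API): the slope
# ``a`` is a learnable scalar, so an optimizer updates it together with the
# surrounding network's weights.
#
#     act = AdaptativeTanh(mu=0.0, sigma=0.1)
#     x = torch.linspace(-2.0, 2.0, 5)
#     y = act(x)  # equivalent to torch.tanh(act.a * x)
#     print(list(act.parameters()))  # a single scalar parameter `a`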
class Hat(torch.nn.Module):
    """Class for the hat activation function.

    Args:
        **kwargs: Keyword arguments (not used here).
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the hat function.
        """
        # Piecewise-linear hat: 1 + x on [-1, 0], 1 - x on (0, 1], 0 elsewhere.
        left_part = torch.relu(1 + x) * (x <= 0)
        right_part = torch.relu(1 - x) * (x > 0)
        return left_part + right_part
class RegularizedHat(torch.nn.Module):
    """Class for the regularized hat activation function.

    Args:
        **kwargs: Keyword arguments (not used here).
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the activation function.
        """
        # Smooth bump: equals 1 at x = 0 and decays towards exp(-12) away from 0.
        return torch.exp(-12 * torch.tanh(x**2 / 2))
class Sine(torch.nn.Module):
    """Class for the sine activation function.

    Args:
        **kwargs: Keyword arguments including:

            * `freq` (:code:`float`): The frequency of the sine. Defaults to 1.0.
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
        self.freq = kwargs.get("freq", 1.0)  #: The frequency of the sine.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the sine function.
        """
        return torch.sin(self.freq * x)
class Cosin(torch.nn.Module):
    """Class for the cosine activation function.

    Args:
        **kwargs: Keyword arguments including:

            * `freq` (:code:`float`): The frequency of the cosine. Defaults to 1.0.
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
        self.freq = kwargs.get("freq", 1.0)  #: The frequency of the cosine.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the cosine function.
        """
        return torch.cos(self.freq * x)
class Heaviside(torch.nn.Module):
    r"""Class for a regularized Heaviside activation function.

    .. math::

        H_k(x) = \frac{1}{1 + e^{-2 k x}},
        \qquad H_k(x) \xrightarrow[k \to +\infty]{} H(x)

    Args:
        **kwargs: Keyword arguments including:

            * `k` (:code:`float`): the regularization parameter. Defaults to 100.0.
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
        self.k = kwargs.get("k", 100.0)  #: The regularization parameter.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the regularized Heaviside function.
        """
        return 1.0 / (1.0 + torch.exp(-2.0 * self.k * x))
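# Illustrative behavior (not part of the library): the larger ``k``, the closer
# H_k is to the exact step function.
#
#     h = Heaviside(k=1000.0)
#     x = torch.tensor([-0.1, -0.01, 0.01, 0.1])
#     print(h(x))  # approximately [0., 0., 1., 1.]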
class Tanh(torch.nn.Module):
    """Tanh activation function.

    Args:
        **kwargs: Keyword arguments (not used here).
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the tanh function.
        """
        return torch.tanh(x)
class Id(torch.nn.Module):
    r"""Identity activation function."""

    def __init__(self):
        super().__init__()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The input tensor unchanged (identity function).
        """
        return x
class SiLU(torch.nn.Module):
    """SiLU activation function.

    Args:
        **kwargs: Keyword arguments (not used here).
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
        self.ac = torch.nn.SiLU()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the SiLU function.
        """
        return self.ac(x)
class Swish(torch.nn.Module):
    """Swish activation function.

    Args:
        **kwargs: Keyword arguments including:

            * `learnable` (:code:`bool`): Whether the beta parameter is learnable.
              Defaults to False.
            * `beta` (:code:`float`): The beta parameter. Defaults to 1.0. When
              ``learnable`` is True, it is replaced by a trainable parameter
              initialized near 1.
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
        self.learnable = kwargs.get("learnable", False)  #: Whether beta is learnable.
        self.beta = kwargs.get("beta", 1.0)  #: The beta parameter.
        if self.learnable:
            self.beta = torch.nn.Parameter(1.0 + torch.randn(()) * 0.1)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the Swish function.
        """
        # x * sigmoid(beta * x), written with an explicit exponential.
        return x / (1 + torch.exp(-self.beta * x))
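# A brief usage sketch (illustrative assumptions, not library code): with
# ``learnable=True`` the slope ``beta`` becomes a trainable parameter,
# initialized near 1, and is optimized together with the network weights.
#
#     swish = Swish(learnable=True)
#     x = torch.randn(4)
#     y = swish(x)  # equals x * sigmoid(beta * x)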
class Sigmoid(torch.nn.Module):
    """Sigmoid activation function.

    Args:
        **kwargs: Keyword arguments (not used here).
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the sigmoid function.
        """
        return torch.sigmoid(x)
class Wavelet(torch.nn.Module):  # noqa: D101
    pass
class RbfSinus(torch.nn.Module):  # noqa: D101
    pass
# Activation functions that are not local to each dimension (the same
# transformation is not applied to every input dimension).
class IsotropicRadial(torch.nn.Module):
    r"""Isotropic radial basis activation.

    It is of the form :math:`\phi(x, m, \sigma)`, with :math:`m` the center of the
    function and :math:`\sigma` the shape parameter. Currently implemented:

    - :math:`\phi(x, m, \sigma) = e^{-\|x - m\|_p^p / \sigma^2}`
    - :math:`\phi(x, m, \sigma) = 1 / \sqrt{1 + (\|x - m\|_p^p \, \sigma^2)^2}`

    where :math:`\|\cdot\|_p` is the Lp norm (p = 2 by default).

    Args:
        in_size: Size of the inputs.
        m: Center tensor for the radial basis function.
        **kwargs: Keyword arguments including:

            * `norm` (:code:`int`): Order p of the norm. Defaults to 2.
            * `type_rbf` (:code:`str`): Type of RBF ("gaussian" or other).
              Defaults to "gaussian".

    Learnable Parameters:
        m: The center of the radial basis function (size = in_size).
        sig: The shape parameter of the radial basis function.
    """

    def __init__(self, in_size: int, m: torch.Tensor, **kwargs):
        super().__init__()
        self.dim = in_size
        self.norm = kwargs.get("norm", 2)
        m_no_grad = m.detach()
        self.m = torch.nn.Parameter(m_no_grad)
        self.sig = torch.nn.Parameter(torch.abs(10 * torch.randn(()) * 0.1 + 0.01))
        self.type_rbf = kwargs.get("type_rbf", "gaussian")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the radial basis function.
        """
        # ||x - m||_p^p, kept as a column vector for broadcasting.
        norm = torch.norm(x - self.m, p=self.norm, dim=1) ** self.norm
        norm = norm[:, None]
        if self.type_rbf == "gaussian":
            exp_m = torch.exp(-norm / self.sig**2)
        else:
            # Inverse quadratic form: 1 / sqrt(1 + (||x - m||_p^p * sig^2)^2).
            exp_m = 1.0 / (1.0 + (norm * self.sig**2) ** 2.0) ** 0.5
        return exp_m
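# A minimal usage sketch (assumed shapes, following the constructor above): the
# center ``m`` lives in the input space and inputs are batched along dim 0.
#
#     center = torch.zeros(2)
#     rbf = IsotropicRadial(in_size=2, m=center, type_rbf="gaussian")
#     x = torch.randn(8, 2)  # batch of 8 points in R^2
#     y = rbf(x)             # shape (8, 1), values in (0, 1]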
class AnisotropicRadial(torch.nn.Module):
    r"""Anisotropic radial basis activation.

    It is of the form :math:`\phi(x, m, \Sigma)`, with :math:`m` the center of the
    function and :math:`\Sigma = A^t A + 0.01 I_d` the matrix shape parameter.
    Currently implemented:

    - :math:`\phi(x, m, \Sigma) = e^{-(x - m, \Sigma (x - m))}`
    - :math:`\phi(x, m, \Sigma) = 1 / \sqrt{1 + (x - m, \Sigma (x - m))^2}`

    Args:
        in_size: Size of the inputs.
        m: Center tensor for the radial basis function.
        **kwargs: Keyword arguments including `type_rbf` (:code:`str`): Type of RBF
            ("gaussian" or other). Defaults to "gaussian".

    Learnable Parameters:
        - :code:`m`: The center of the radial basis function (size = in_size).
        - :code:`A`: The shape matrix of the radial basis function
          (size = in_size * in_size).
    """

    def __init__(self, in_size: int, m: torch.Tensor, **kwargs):
        super().__init__()
        self.dim = in_size
        m_no_grad = m.detach()
        self.m = torch.nn.Parameter(m_no_grad)
        self.A = torch.nn.Parameter(torch.rand((self.dim, self.dim)))
        self.type_rbf = kwargs.get("type_rbf", "gaussian")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the anisotropic radial basis
            function.
        """
        # Sigma = A^t A + 0.01 I_d is symmetric positive definite by construction.
        sid = 0.01 * torch.eye(self.dim, self.dim)
        sig2 = torch.matmul(torch.transpose(self.A, 0, 1), self.A) + sid
        # Quadratic form (x - m, Sigma (x - m)), one value per batch entry.
        norm = torch.linalg.vecdot(torch.mm(x - self.m, sig2), x - self.m, dim=1)
        norm = norm[:, None]
        if self.type_rbf == "gaussian":
            exp_m = torch.exp(-norm)
        else:
            exp_m = 1.0 / (1.0 + norm**2) ** 0.5
        return exp_m
class Rational(torch.nn.Module):
    r"""Class for a rational activation function with adaptive parameters.

    The function takes the form :math:`P(x) / Q(x)`, with :math:`P` a degree 3
    polynomial and :math:`Q` a degree 2 polynomial. It is initialized as the best
    approximation of the ReLU function on :math:`[-1, 1]`.

    The polynomials take the form:

    - :math:`P(x) = p_0 + p_1 x + p_2 x^2 + p_3 x^3`
    - :math:`Q(x) = q_0 + q_1 x + q_2 x^2`

    ``p0``, ``p1``, ``p2``, ``p3``, ``q0``, ``q1``, ``q2`` are learnable parameters.

    Args:
        **kwargs: Additional keyword arguments (not used here).
    """

    def __init__(self, **kwargs: Any):
        super().__init__()
        # Use torch.tensor (not torch.Tensor) so the parameters are created with
        # the appropriate dtype and can be moved to the appropriate device.
        self.p0 = torch.nn.Parameter(
            torch.tensor([0.0218])
        )  #: Coefficient :math:`p_0` of the polynomial :math:`P`.
        self.p1 = torch.nn.Parameter(
            torch.tensor([0.5])
        )  #: Coefficient :math:`p_1` of the polynomial :math:`P`.
        self.p2 = torch.nn.Parameter(
            torch.tensor([1.5957])
        )  #: Coefficient :math:`p_2` of the polynomial :math:`P`.
        self.p3 = torch.nn.Parameter(
            torch.tensor([1.1915])
        )  #: Coefficient :math:`p_3` of the polynomial :math:`P`.
        self.q0 = torch.nn.Parameter(
            torch.tensor([1.0])
        )  #: Coefficient :math:`q_0` of the polynomial :math:`Q`.
        self.q1 = torch.nn.Parameter(
            torch.tensor([0.0])
        )  #: Coefficient :math:`q_1` of the polynomial :math:`Q`.
        self.q2 = torch.nn.Parameter(
            torch.tensor([2.3830])
        )  #: Coefficient :math:`q_2` of the polynomial :math:`Q`.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation function to a tensor x.

        Args:
            x: Input tensor.

        Returns:
            The tensor after the application of the rational function.
        """
        # Evaluate P and Q with Horner's scheme.
        P = self.p0 + x * (self.p1 + x * (self.p2 + x * self.p3))
        Q = self.q0 + x * (self.q1 + x * self.q2)
        return P / Q
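# Illustrative check (not part of the library): with the default coefficients,
# the rational function stays close to ReLU on [-1, 1].
#
#     act = Rational()
#     x = torch.linspace(-1.0, 1.0, 5)
#     print(act(x) - torch.relu(x))  # residuals of order 1e-2 at initialization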
def activation_function(ac_type: str, in_size: int = 1, **kwargs):
    r"""Function to choose the activation function.

    Args:
        ac_type: The name of the activation function.
        in_size: The dimension (useful for radial basis). Defaults to 1.
        **kwargs: Additional keyword arguments passed to the activation function.

    Returns:
        The activation function instance (:class:`Id` if ``ac_type`` is unknown).
    """
    if ac_type == "adaptative_tanh":
        return AdaptativeTanh(**kwargs)
    elif ac_type == "sine":
        return Sine(**kwargs)
    elif ac_type == "cosin":
        return Cosin(**kwargs)
    elif ac_type == "silu":
        return SiLU(**kwargs)
    elif ac_type == "swish":
        return Swish(**kwargs)
    elif ac_type == "tanh":
        return Tanh(**kwargs)
    elif ac_type == "isotropic_radial":
        return IsotropicRadial(in_size, **kwargs)
    elif ac_type == "anisotropic_radial":
        return AnisotropicRadial(in_size, **kwargs)
    elif ac_type == "sigmoid":
        return Sigmoid(**kwargs)
    elif ac_type == "rational":
        return Rational(**kwargs)
    elif ac_type == "hat":
        return Hat(**kwargs)
    elif ac_type == "regularized_hat":
        return RegularizedHat(**kwargs)
    elif ac_type == "heaviside":
        return Heaviside(**kwargs)
    else:
        return Id()
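# A minimal sketch of how the factory might be used (assumed call patterns; in
# particular, the radial activations receive their center `m` through **kwargs):
if __name__ == "__main__":
    x = torch.randn(8, 2)

    act = activation_function("sine", freq=2.0)
    print(act(x).shape)  # torch.Size([8, 2])

    rbf = activation_function("isotropic_radial", in_size=2, m=torch.zeros(2))
    print(rbf(x).shape)  # torch.Size([8, 1])

    fallback = activation_function("not_a_known_name")
    print(type(fallback).__name__)  # Id (unknown names fall back to the identity)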