Source code for scimba_torch.neural_nets.coordinates_based_nets.pirate_net

"""PirateNet architecture implementation."""

import torch
from torch import nn

from scimba_torch.neural_nets.coordinates_based_nets.features import EnhancedFeatureNet
from scimba_torch.neural_nets.coordinates_based_nets.scimba_module import ScimbaModule

from .activation import activation_function


class PirateNetBlock(ScimbaModule):
    """Implements a block of the PirateNet.

    Each block applies three gated linear transformations with activation,
    mixing their outputs with the shared weighting matrices U and V, then
    updates the input `x` through a residual scheme controlled by a trainable
    parameter `alpha`.

    Args:
        dim: Input and output dimension of the block, default is 1
        **kwargs: Additional parameters for block configuration
    """

    def __init__(self, dim: int = 1, **kwargs):
        super().__init__(dim, dim, **kwargs)
        self.in_size = dim
        self.out_size_embedded = dim

        #: Linear layer for the `f_l` transformation
        self.W_f = nn.Linear(dim, dim)
        #: Linear layer for the `g_l` transformation
        self.W_g = nn.Linear(dim, dim)
        #: Linear layer for the `h_l` transformation
        self.W_h = nn.Linear(dim, dim)

        #: Trainable parameter for mixing the old and new value of `x`
        self.alpha = nn.Parameter(torch.tensor(0.1))

        self.activation_type = kwargs.get("activation_type", "tanh")
        #: Activation function used in the block
        self.activation = activation_function(
            self.activation_type, in_size=dim, **kwargs
        )

    def forward(
        self, x: torch.Tensor, u: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        """Applies the block transformation to the input `x`.

        Args:
            x: Input of the block
            u: Weighting matrix U
            v: Weighting matrix V

        Returns:
            Output of the block after transformation
        """
        f_l = self.activation(self.W_f(x))
        z_l = f_l * u + (1 - f_l) * v
        g_l = self.activation(self.W_g(z_l))
        e_l = g_l * u + (1 - g_l) * v
        h_l = self.activation(self.W_h(e_l))
        x_next = self.alpha * h_l + (1 - self.alpha) * x
        return x_next
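

# Illustrative sketch (an assumption, not part of the original module): how a
# single `PirateNetBlock` can be driven on its own.  The sizes and the helper
# name `_example_block_forward` are hypothetical; `u` and `v` are random
# stand-ins here, while in `PirateNet` they come from the shared embeddings.
def _example_block_forward() -> torch.Tensor:
    block = PirateNetBlock(dim=8)
    x = torch.randn(32, 8)  # batch of 32 latent vectors of width `dim`
    u = torch.randn(32, 8)  # stand-in for the shared weighting matrix U
    v = torch.randn(32, 8)  # stand-in for the shared weighting matrix V
    return block(x, u, v)  # residual update, same shape as `x`: (32, 8)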


class PirateNet(ScimbaModule):
    """A PirateNet neural network implementation.

    Args:
        in_size: Input dimension, default is 1
        out_size: Output dimension, default is 1
        nb_features: Number of features used for encoding, default is 1
        nb_blocks: Number of stacked `PirateNetBlock` layers, default is 1
        **kwargs: Additional parameters for network configuration
    """

    def __init__(
        self,
        in_size: int = 1,
        out_size: int = 1,
        nb_features: int = 1,
        nb_blocks: int = 1,
        **kwargs,
    ):
        super().__init__(in_size=in_size, out_size=out_size, **kwargs)

        activation_type = kwargs.get("activation_type", "tanh")
        activation_output = kwargs.get("activation_output", "id")
        last_layer_has_bias = kwargs.get("last_layer_has_bias", False)

        #: Input dimension
        self.in_size = in_size
        #: Output dimension
        self.out_size = out_size
        #: Number of residual blocks in the network
        self.nb_blocks = nb_blocks
        #: Dimension of the latent space after encoding
        self.dim_hidden = 2 * nb_features

        #: Input encoding network
        self.embedding = EnhancedFeatureNet(
            in_size=in_size, nb_features=nb_features, **kwargs
        )

        #: Linear layer to compute `U`
        self.embedding_1 = nn.Linear(self.dim_hidden, self.dim_hidden)
        #: Linear layer to compute `V`
        self.embedding_2 = nn.Linear(self.dim_hidden, self.dim_hidden)

        #: Main activation function
        self.activation = activation_function(
            activation_type, in_size=in_size, **kwargs
        )

        #: List of `PirateNetBlock` blocks
        self.blocks = nn.ModuleList(
            [PirateNetBlock(self.dim_hidden, **kwargs) for _ in range(self.nb_blocks)]
        )

        #: Output layer
        self.output_layer = nn.Linear(
            self.dim_hidden, self.out_size, bias=last_layer_has_bias
        )

        #: Final activation function applied to the output
        self.activation_output = activation_function(
            activation_output, in_size=in_size, **kwargs
        )

    def forward(
        self, inputs: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Applies the network transformation to the inputs.

        Args:
            inputs: Input of the network
            with_last_layer: If `True`, applies the output layer and final
                activation, default is `True`

        Returns:
            Output of the network after transformation
        """
        inputs = self.embedding(inputs)
        U = self.activation(self.embedding_1(inputs))
        V = self.activation(self.embedding_2(inputs))
        for i in range(self.nb_blocks):
            inputs = self.blocks[i](inputs, U, V)
        if with_last_layer:
            inputs = self.activation_output(self.output_layer(inputs))
        return inputs
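

# Illustrative usage sketch (an assumption, not part of the original module):
# a quick smoke test of the full network.  The chosen sizes are hypothetical,
# and the exact keyword handling depends on `ScimbaModule` and
# `EnhancedFeatureNet`; the latent width is expected to be 2 * nb_features.
if __name__ == "__main__":
    net = PirateNet(in_size=2, out_size=1, nb_features=16, nb_blocks=3)
    x = torch.randn(64, 2)  # batch of 64 two-dimensional coordinates
    y = net(x)  # embedding -> (U, V) -> 3 residual blocks -> output layer
    print(y.shape)  # expected: torch.Size([64, 1])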