scimba_torch.neural_nets.coordinates_based_nets.pirate_net
PirateNet architecture implementation.
Classes
| PirateNet | A PirateNet neural network implementation. |
| PirateNetBlock | Implements a block of the PirateNet. |
- class PirateNetBlock(dim=1, **kwargs)
Bases: ScimbaModule
Implements a block of the PirateNet.
Each block applies three linear transformations, each followed by the activation, and uses their outputs to weight the matrices U and V computed by the network. It then updates the input x with a residual scheme that mixes the old and new values of x.
- Parameters:
dim (int) – Input and output dimension of the block, default is 1
**kwargs – Additional parameters for block configuration
- W_f
Linear layer for the f_l transformation
- W_g
Linear layer for the g_l transformation
- W_h
Linear layer for the h_l transformation
- alpha
Trainable parameter for mixing the old and new value of x
- activation
Activation function used in the block
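The block update can be sketched in plain Python as follows. This is a minimal, hypothetical illustration, not the scimba_torch code: the class `PirateBlockSketch`, the helpers `linear` and `gate_mix`, the tanh activation, the random initialization, and the exact gating (the outputs of W_f and W_g mix U and V, the output of W_h is the new value of x) are assumptions borrowed from the original PirateNet formulation. alpha is assumed to start at 0, so the block initially returns its input unchanged.

```python
import math
import random
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def linear(weight: Matrix, bias: Vector, x: Vector) -> Vector:
    """Dense layer W x + b on plain Python lists."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def gate_mix(gate: Vector, u: Vector, v: Vector) -> Vector:
    """Elementwise gate * U + (1 - gate) * V."""
    return [s * ui + (1.0 - s) * vi for s, ui, vi in zip(gate, u, v)]


class PirateBlockSketch:
    """Hypothetical stand-in for PirateNetBlock; not the scimba_torch API."""

    def __init__(self, dim: int, activation: Callable[[float], float] = math.tanh,
                 seed: int = 0) -> None:
        rng = random.Random(seed)
        scale = 1.0 / math.sqrt(dim)

        def init() -> Matrix:
            return [[rng.gauss(0.0, scale) for _ in range(dim)] for _ in range(dim)]

        # Square weights: the block keeps the dimension dim unchanged.
        self.W_f, self.W_g, self.W_h = init(), init(), init()
        self.b_f, self.b_g, self.b_h = [0.0] * dim, [0.0] * dim, [0.0] * dim
        # Trainable in the real block; assumed to start at 0 (block = identity).
        self.alpha = 0.0
        self.activation = activation

    def _act(self, z: Vector) -> Vector:
        return [self.activation(t) for t in z]

    def __call__(self, x: Vector, u: Vector, v: Vector) -> Vector:
        f = self._act(linear(self.W_f, self.b_f, x))   # weights for mixing U and V
        z1 = gate_mix(f, u, v)
        g = self._act(linear(self.W_g, self.b_g, z1))  # second set of mixing weights
        z2 = gate_mix(g, u, v)
        h = self._act(linear(self.W_h, self.b_h, z2))  # candidate new value of x
        # Residual scheme: alpha mixes the new value h with the old value x.
        return [self.alpha * hi + (1.0 - self.alpha) * xi for hi, xi in zip(h, x)]


block = PirateBlockSketch(dim=3)
x = [0.1, -0.2, 0.3]
u = [0.5, 0.5, 0.5]
v = [-0.5, -0.5, -0.5]
print(block(x, u, v) == x)  # True: alpha = 0 returns the input unchanged
block.alpha = 1.0
print(block(x, u, v))  # h_l only; values depend on the random weights
```

With alpha = 1 the block returns h_l alone; intermediate values blend h_l with x, which is the residual mixing that the trainable alpha controls.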
- class PirateNet(in_size=1, out_size=1, nb_features=1, nb_blocks=1, **kwargs)
Bases: ScimbaModule
A PirateNet neural network implementation.
- Parameters:
in_size (int) – Input dimension, default is 1
out_size (int) – Output dimension, default is 1
nb_features (int) – Number of features used for encoding, default is 1
nb_blocks (int) – Number of stacked PirateNetBlock layers, default is 1
**kwargs – Additional parameters for network configuration
- in_size
Input dimension
- out_size
Output dimension
- nb_blocks
Number of residual blocks in the network
Dimension of the latent space after encoding
- embedding
Input encoding network
- embedding_1
Linear layer to compute U
- embedding_2
Linear layer to compute V
- activation
Main activation function
- blocks
List of PirateNetBlock blocks
- output_layer
Output layer
- activation_output
Final activation function applied to the output
- forward(inputs, with_last_layer=True)
Applies the network transformation to the inputs.
- Parameters:
inputs (Tensor) – Input of the network
with_last_layer (bool) – If True, applies the output layer and final activation, default is True
- Return type:
Tensor
- Returns:
Output of the network after transformation
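The attributes listed for PirateNet suggest how forward chains its components. The sketch below is hypothetical and written in plain Python rather than against the scimba_torch API: the function `pirate_net_forward`, the assumption that U and V are obtained once by applying `activation` to the outputs of `embedding_1` and `embedding_2` and then shared by every block, and the assumption that `with_last_layer=False` returns the output of the last block are not taken from the library source. The toy layers only make the data flow easy to check by hand.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]
Layer = Callable[[Vector], Vector]
Block = Callable[[Vector, Vector, Vector], Vector]


def pirate_net_forward(
    inputs: Vector,
    embedding: Layer,
    embedding_1: Layer,
    embedding_2: Layer,
    activation: Callable[[float], float],
    blocks: Sequence[Block],
    output_layer: Layer,
    activation_output: Callable[[float], float],
    with_last_layer: bool = True,
) -> Vector:
    """Hypothetical sketch of the PirateNet data flow (not the scimba_torch code)."""
    features = embedding(inputs)  # encode the inputs into the latent space
    # Assumption: U and V are computed once from the encoding and shared by all blocks.
    u = [activation(t) for t in embedding_1(features)]
    v = [activation(t) for t in embedding_2(features)]
    x = features
    for block in blocks:
        x = block(x, u, v)
    if not with_last_layer:
        return x  # assumption: latent output of the last block
    return [activation_output(t) for t in output_layer(x)]


# Toy components, chosen only so that the result is easy to check by hand.
def identity_layer(z: Vector) -> Vector:
    return list(z)


def sum_layer(z: Vector) -> Vector:
    return [sum(z)]


def pass_through_block(x: Vector, u: Vector, v: Vector) -> Vector:
    return list(x)  # a real block would mix U and V here


def identity_activation(t: float) -> float:
    return t


args = (identity_layer, identity_layer, identity_layer, math.tanh,
        [pass_through_block], sum_layer, identity_activation)
latent = pirate_net_forward([0.5, -0.5], *args, with_last_layer=False)
output = pirate_net_forward([0.5, -0.5], *args)
print(latent)  # [0.5, -0.5]
print(output)  # [0.0]
```

Computing U and V outside the loop reflects the attribute list, where embedding_1 and embedding_2 belong to the network rather than to each block.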