scimba_torch.neural_nets.structure_preserving_nets.sympnet

Defines the SympNet class for symplectic neural networks.

Classes

SympLayer(y_dim, p_dim, **kwargs)

A layer of a symplectic neural network.

SympNet(dim, p_dim[, widths])

A symplectic neural network composed of multiple SympLayer layers.

class SympLayer(y_dim, p_dim, **kwargs)

Bases: Module

A layer of a symplectic neural network.

It transforms the input tensors x and y using two GradPotential modules.

Parameters:
  • y_dim (int) – Dimension of the input tensors x and y.

  • p_dim (int) – Dimension of the input tensor p.

  • **kwargs – Additional keyword arguments. Options for parameter scaling (see the parameters_scaling and parameters_scaling_number attributes) can be passed.

grad_potential1: GradPotential

The first GradPotential module used for transformations.

grad_potential2: GradPotential

The second GradPotential module used for transformations.

parameters_scaling: bool

A flag indicating whether parameter scaling is applied.

parameters_scaling_number: int

The index of the scaling parameter, presumably its position within p.

forward(x, y, p, sign=1)

Performs the forward pass of the symplectic layer.

Applies transformations using GradPotential layers, with optional scaling for the p tensor.

Parameters:
  • x (Tensor) – The input tensor x of shape (batch_size, y_dim).

  • y (Tensor) – The input tensor y of shape (batch_size, y_dim).

  • p (Tensor) – The input tensor p of shape (batch_size, p_dim).

  • sign (int) – The sign used to apply the transformations. Default is 1.

Return type:

tuple[Tensor, Tensor]

Returns:

The output tensors (x, y) after transformations.
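As a rough illustration of how a layer of this kind can work, the sketch below applies two shear updates driven by gradients of scalar potentials, then undoes them with sign=-1. It is not the library's implementation: the functions grad_v1, grad_v2 and shear_layer, the scale argument, the choice of potentials, and the scalar state (y_dim = 1) are all assumptions made for illustration, and plain floats stand in for tensors.

```python
import math


def grad_v1(y, p):
    # Gradient in y of the hypothetical potential V1(y, p) = p * log(cosh(y)).
    return p * math.tanh(y)


def grad_v2(x, p):
    # Gradient in x of the hypothetical potential V2(x, p) = -p * cos(x).
    return p * math.sin(x)


def shear_layer(x, y, p, sign=1, scale=1.0):
    """Apply two shear updates (sign=1) or undo them (sign=-1)."""
    if sign == 1:
        x = x + scale * grad_v1(y, p)  # update x using y only
        y = y - scale * grad_v2(x, p)  # update y using the new x only
    else:
        y = y + scale * grad_v2(x, p)  # undo the second update first
        x = x - scale * grad_v1(y, p)  # then undo the first update
    return x, y


x0, y0, p0 = 0.3, -1.2, 0.5
x1, y1 = shear_layer(x0, y0, p0, sign=1)
x2, y2 = shear_layer(x1, y1, p0, sign=-1)

# The sign=-1 pass recovers the original state up to rounding error.
assert math.isclose(x2, x0, abs_tol=1e-12)
assert math.isclose(y2, y0, abs_tol=1e-12)
```

Each update changes one variable using only the other, so each step can be inverted exactly by subtracting the same term. Whether SympLayer orders its two sub-steps this way for a negative sign is not stated in this reference.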

class SympNet(dim, p_dim, widths=[20, 20, 20, 20, 20], **kwargs)

Bases: ScimbaModule

A symplectic neural network composed of multiple SympLayer layers.

The network processes input tensors x, y, and p, applying transformations through each layer.

Parameters:
  • dim (int) – The dimension of the state space.

  • p_dim (int) – The dimension of the parameter space.

  • widths (list[int]) – The widths of the SympLayer layers. Default is [20] * 5.

  • **kwargs – Additional keyword arguments for the layers.

layers: nn.ModuleList

List of SympLayer layers that form the network.

forward(inputs, with_last_layer=True)

Applies the forward pass of the symplectic network.

Parameters:
  • inputs (Tensor) – The input tensor of shape (batch_size, dim + p_dim).

  • with_last_layer (bool) – Whether to use the last layer of the network. Default is True.

Returns:

The output tensor of shape (batch_size, dim + p_dim) after applying all layers.
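To make "symplectic" concrete, the sketch below composes a few shear layers and checks numerically that the resulting map preserves area, which in a two-dimensional state space is equivalent to being symplectic. This is a minimal stand-in, not the SympNet code: the WEIGHTS table, the layer and jacobian_det helpers, the row layout [x, y, p] (dim = 2, p_dim = 1) and the pass-through of p are assumptions for illustration.

```python
import math

# Hypothetical per-layer weights; a real network would learn these.
WEIGHTS = [(0.8, -0.5), (0.3, 1.1), (-0.6, 0.4)]


def layer(x, y, p, a, b):
    # Two shear updates; each changes one variable using only the other.
    x = x + a * p * math.tanh(y)
    y = y - b * p * math.sin(x)
    return x, y


def forward(row):
    """Map one input row [x, y, p] to [x', y', p]; p passes through unchanged."""
    x, y, p = row
    for a, b in WEIGHTS:
        x, y = layer(x, y, p, a, b)
    return [x, y, p]


def jacobian_det(row, h=1e-6):
    """Central-difference determinant of d(x', y') / d(x, y) at fixed p."""
    x, y, p = row
    fxp, fxm = forward([x + h, y, p]), forward([x - h, y, p])
    fyp, fym = forward([x, y + h, p]), forward([x, y - h, p])
    dxx, dyx = (fxp[0] - fxm[0]) / (2 * h), (fxp[1] - fxm[1]) / (2 * h)
    dxy, dyy = (fyp[0] - fym[0]) / (2 * h), (fyp[1] - fym[1]) / (2 * h)
    return dxx * dyy - dxy * dyx


det = jacobian_det([0.3, -1.2, 0.5])

# An area-preserving map has Jacobian determinant 1 (up to numerical error).
assert abs(det - 1.0) < 1e-6
```

Every shear has a unit-determinant Jacobian, so the determinant stays at 1 however many layers are stacked and whatever the weights are.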

inverse(inputs, with_last_layer=True)

Applies the inverse pass of the symplectic network.

Parameters:
  • inputs (Tensor) – The input tensor of shape (batch_size, dim + p_dim).

  • with_last_layer (bool) – Whether to use the last layer of the network. Default is True.

Returns:

The output tensor of shape (batch_size, dim + p_dim) after applying all layers in reverse order.
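The relationship between the forward and inverse passes can be sketched as follows: the inverse walks the same layers in reverse order, undoing each one. The code is a simplified stand-in rather than the library's implementation. The WEIGHTS table, the layer helper, the row layout [x, y, p], and the reading of with_last_layer as "skip the final layer in both directions" are assumptions for illustration.

```python
import math

# Hypothetical per-layer weights; a real network would learn these.
WEIGHTS = [(0.8, -0.5), (0.3, 1.1), (-0.6, 0.4)]


def layer(x, y, p, a, b, sign=1):
    """Apply one layer (sign=1) or its exact inverse (sign=-1)."""
    if sign == 1:
        x = x + a * p * math.tanh(y)
        y = y - b * p * math.sin(x)
    else:
        y = y + b * p * math.sin(x)
        x = x - a * p * math.tanh(y)
    return x, y


def forward(row, with_last_layer=True):
    x, y, p = row
    weights = WEIGHTS if with_last_layer else WEIGHTS[:-1]
    for a, b in weights:
        x, y = layer(x, y, p, a, b, sign=1)
    return [x, y, p]


def inverse(row, with_last_layer=True):
    x, y, p = row
    weights = WEIGHTS if with_last_layer else WEIGHTS[:-1]
    for a, b in reversed(weights):  # undo the layers last to first
        x, y = layer(x, y, p, a, b, sign=-1)
    return [x, y, p]


start = [0.3, -1.2, 0.5]
back = inverse(forward(start))

# inverse(forward(row)) recovers the input up to rounding error.
assert all(math.isclose(u, v, abs_tol=1e-12) for u, v in zip(back, start))
```

In this sketch the round trip only holds when forward and inverse are called with the same with_last_layer value, since both must traverse the same set of layers.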