scimba_torch.neural_nets.structure_preserving_nets.ode_splitted_layer

ODE-based layer operating on split tensors, for invertible networks.

Classes

ODESplittedLayer(size, conditional_size, ...)

Abstract class for ODE-based invertible layers that operate on split tensors.

class ODESplittedLayer(size, conditional_size, split_index, other_indices, **kwargs)

Bases: InvertibleLayer

Abstract class for ODE-based invertible layers that operate on split tensors.

This layer processes one part of a split tensor while being conditioned on:
  • the conditional input (mu);
  • other parts of the split tensor (specified by other_indices).

Parameters:
  • size (int) – dimension of the input part to process

  • conditional_size (int) – dimension of the conditional input data

  • split_index (int) – index of the part this layer processes (0-based)

  • other_indices (list[int]) – list of indices of other parts to use as conditioning

  • **kwargs – additional keyword arguments for the InvertibleLayer base class
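A minimal sketch of how split_index and other_indices might select the tensors a layer receives, assuming the full input is split along the feature dimension with torch.split. The helper gather_parts and the part sizes are hypothetical; the library's own splitting logic is not documented on this page.

```python
import torch

def gather_parts(x, sizes, split_index, other_indices):
    # Hypothetical helper: split x along the feature dimension and return
    # the part this layer transforms plus the parts it is conditioned on.
    parts = torch.split(x, sizes, dim=1)
    y = parts[split_index]
    other_parts = [parts[j] for j in other_indices]
    return y, other_parts

x = torch.arange(12.0).reshape(2, 6)  # batch_size = 2, total dimension 6
y, other_parts = gather_parts(x, [2, 3, 1], split_index=1, other_indices=[0, 2])
print(y.shape)                            # torch.Size([2, 3])
print([p.shape[1] for p in other_parts])  # [2, 1]
```

Under these assumptions, a layer built with split_index=1 would have size=3 and would see the parts of width 2 and 1 as its other_parts.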

abstract forward(y, mu, other_parts, with_last_layer=True)

Forward pass of the ODE-based layer.

Parameters:
  • y (Tensor) – the input tensor part to transform, shape (batch_size, size)

  • mu (Tensor) – the conditional input tensor, shape (batch_size, conditional_size)

  • other_parts (list[Tensor]) – list of other tensor parts used for conditioning

  • with_last_layer (bool) – whether to use the last layer of the network

Return type:

Tensor

Returns:

The transformed tensor of shape (batch_size, size)
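To illustrate the forward contract, the following is a minimal sketch of a hypothetical layer, not one from scimba_torch. It assumes a vector field that is linear in y, dy/dt = s(mu, other_parts) ⊙ y, whose rate s depends only on the conditioning, so the flow from t = 0 to t = 1 has the closed form y·exp(s). The names rate and toy_forward and the weight matrix W are illustrative, and with_last_layer is left out because its role is not specified here.

```python
import torch

def rate(mu, other_parts, W):
    # Hypothetical conditioning network: concatenate mu with the other parts
    # and map them to a bounded rate s of shape (batch_size, size).
    cond = torch.cat([mu, *other_parts], dim=1)
    return torch.tanh(cond @ W)

def toy_forward(y, mu, other_parts, W):
    # Closed-form solution at t = 1 of dy/dt = s * y (s is constant in t).
    return y * torch.exp(rate(mu, other_parts, W))

torch.manual_seed(0)
batch_size, size, conditional_size, other_size = 4, 3, 2, 5
y = torch.randn(batch_size, size)
mu = torch.randn(batch_size, conditional_size)
other_parts = [torch.randn(batch_size, other_size)]
W = torch.randn(conditional_size + other_size, size)

out = toy_forward(y, mu, other_parts, W)
print(out.shape)  # torch.Size([4, 3])
```

Because s does not depend on y, this toy layer is the ODE view of the scaling half of an affine coupling layer; a general ODE-based layer would instead integrate its vector field numerically.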

abstract backward(y, mu, other_parts, with_last_layer=True)

Backward pass (inverse) of the ODE-based layer.

Parameters:
  • y (Tensor) – the input tensor part to transform, shape (batch_size, size)

  • mu (Tensor) – the conditional input tensor, shape (batch_size, conditional_size)

  • other_parts (list[Tensor]) – list of other tensor parts used for conditioning

  • with_last_layer (bool) – whether to use the last layer of the network

Return type:

Tensor

Returns:

The inverse-transformed tensor of shape (batch_size, size)
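The backward pass must undo forward for the same mu and other_parts. A minimal sketch for a hypothetical toy layer whose flow is dy/dt = s ⊙ y, with s computed from the conditioning only: running the flow in reverse time divides by exp(s), and a round trip recovers the input up to floating-point error. All names (rate, toy_forward, toy_backward, W) are illustrative assumptions.

```python
import torch

def rate(mu, other_parts, W):
    # Hypothetical conditioning network producing s of shape (batch_size, size).
    cond = torch.cat([mu, *other_parts], dim=1)
    return torch.tanh(cond @ W)

def toy_forward(y, mu, other_parts, W):
    # Flow of dy/dt = s * y from t = 0 to t = 1.
    return y * torch.exp(rate(mu, other_parts, W))

def toy_backward(z, mu, other_parts, W):
    # Reverse-time flow from t = 1 to t = 0. Because s depends only on the
    # conditioning, it can be recomputed without knowing the original y.
    return z * torch.exp(-rate(mu, other_parts, W))

torch.manual_seed(0)
y = torch.randn(4, 3)
mu = torch.randn(4, 2)
other_parts = [torch.randn(4, 5)]
W = torch.randn(2 + 5, 3)

z = toy_forward(y, mu, other_parts, W)
y_rec = toy_backward(z, mu, other_parts, W)
print(torch.allclose(y_rec, y, atol=1e-6))  # True
```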

abstract log_abs_det_jacobian(y, mu, other_parts)

Computes the log absolute determinant of the Jacobian.

Parameters:
  • y (Tensor) – the input tensor part, shape (batch_size, size)

  • mu (Tensor) – the conditional input tensor, shape (batch_size, conditional_size)

  • other_parts (list[Tensor]) – list of other tensor parts used for conditioning

Return type:

Tensor

Returns:

The log absolute determinant as a tensor of shape (batch_size,)
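For a hypothetical toy layer whose forward map is y ↦ y·exp(s), with s computed from mu and other_parts only, the Jacobian with respect to y is diag(exp(s)), so log|det J| = Σ_i s_i for each sample. The sketch below, whose names are illustrative assumptions, computes this closed form and checks one sample against a dense autograd Jacobian.

```python
import torch

def rate(mu, other_parts, W):
    # Hypothetical conditioning network producing s of shape (batch_size, size).
    cond = torch.cat([mu, *other_parts], dim=1)
    return torch.tanh(cond @ W)

def toy_forward(y, mu, other_parts, W):
    # Forward map y -> y * exp(s), the time-1 flow of dy/dt = s * y.
    return y * torch.exp(rate(mu, other_parts, W))

def toy_log_abs_det_jacobian(y, mu, other_parts, W):
    # d(forward)/dy = diag(exp(s)), so log|det| = sum_i s_i; shape (batch_size,).
    # y is unused here but kept to mirror the documented signature.
    return rate(mu, other_parts, W).sum(dim=1)

torch.manual_seed(0)
y = torch.randn(4, 3)
mu = torch.randn(4, 2)
other_parts = [torch.randn(4, 5)]
W = torch.randn(2 + 5, 3)

logdet = toy_log_abs_det_jacobian(y, mu, other_parts, W)

# Cross-check the first sample against an autograd Jacobian of shape (3, 3).
J = torch.autograd.functional.jacobian(
    lambda y0: toy_forward(y0.unsqueeze(0), mu[:1], [p[:1] for p in other_parts], W).squeeze(0),
    y[0],
)
ref = torch.linalg.slogdet(J).logabsdet
print(logdet.shape)  # torch.Size([4])
print(torch.allclose(logdet[0], ref, atol=1e-5))  # True
```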