scimba_torch.neural_nets.structure_preserving_nets.split_layer

Split layer for invertible networks.

Classes

SplittingLayer(size, conditional_size[, ...])

A layer that splits a tensor into multiple parts along the last dimension.

class SplittingLayer(size, conditional_size, num_splits=2)

Bases: object

A layer that splits a tensor into multiple parts along the last dimension.

This layer divides the input tensor into 2, 3, or 4 parts of size size // num_splits. If size is not evenly divisible by num_splits, the remainder is added to the last part.

Parameters:
  • size (int) – total dimension of the input data

  • conditional_size (int) – dimension of the conditional input data

  • num_splits (int) – number of splits (2, 3, or 4)

Raises:

ValueError – If num_splits is not 2, 3, or 4.
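The sizing rule and the validation described above can be sketched in plain Python. This is a minimal sketch, not the library's code. The helper name split_sizes is hypothetical and is not part of scimba_torch. It assumes that each part receives size // num_splits entries and that the last part absorbs the remainder, as the description states.

```python
def split_sizes(size: int, num_splits: int = 2) -> list[int]:
    """Hypothetical helper: part sizes for a SplittingLayer-style split.

    Each part gets size // num_splits entries; the remainder goes to the last part.
    """
    # Mirror the documented constraint on num_splits
    if num_splits not in (2, 3, 4):
        raise ValueError(f"num_splits must be 2, 3, or 4, got {num_splits}")
    base = size // num_splits
    sizes = [base] * num_splits
    sizes[-1] += size - base * num_splits  # remainder lands in the last part
    return sizes


print(split_sizes(10, 3))  # [3, 3, 4]
print(split_sizes(8, 4))   # [2, 2, 2, 2]
```

When size is evenly divisible, the remainder is zero and all parts have the same size.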

Example

>>> layer = SplittingLayer(size=10, conditional_size=0, num_splits=3)
>>> # Splits into sizes [3, 3, 4] (remainder in the last part)

split(y)

Splits the input tensor into multiple parts.

Parameters:

y (Tensor) – the input tensor of shape (batch_size, size)

Return type:

list[Tensor]

Returns:

A list of num_splits tensors, split along the last dimension
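The splitting step can be sketched as follows, assuming PyTorch tensors and the sizing rule stated in the class description. split_last_dim is a hypothetical stand-in, not the library's implementation. It relies on torch.split, which accepts a list of section sizes along a given dimension.

```python
import torch


def split_last_dim(y: torch.Tensor, num_splits: int = 2) -> list[torch.Tensor]:
    """Hypothetical stand-in for SplittingLayer.split (not the library's code)."""
    size = y.shape[-1]
    base = size // num_splits
    # Equal parts of size // num_splits, with the remainder added to the last part
    sizes = [base] * (num_splits - 1) + [size - base * (num_splits - 1)]
    # torch.split accepts a list of section sizes and returns a tuple of views
    return list(torch.split(y, sizes, dim=-1))


y = torch.randn(5, 10)  # (batch_size, size)
parts = split_last_dim(y, num_splits=3)
print([p.shape[-1] for p in parts])  # [3, 3, 4]
```

The batch dimension is left untouched. Only the last dimension is partitioned.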

unsplit(inputs)

Concatenates the split tensors back together.

Parameters:

inputs (list[Tensor]) – a list of tensors to concatenate

Return type:

Tensor

Returns:

The concatenated tensor of shape (batch_size, size)
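Concatenating the parts in order along the last dimension inverts the split, so unsplit applied to the output of split should recover the original tensor. The sketch below assumes PyTorch. unsplit_last_dim is a hypothetical helper built on torch.cat, not the library's implementation.

```python
import torch


def unsplit_last_dim(inputs: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical stand-in for SplittingLayer.unsplit: concatenate on the last dim."""
    return torch.cat(inputs, dim=-1)


y = torch.arange(20.0).reshape(2, 10)  # (batch_size, size)
parts = list(torch.split(y, [3, 3, 4], dim=-1))  # sizes for size=10, num_splits=3
restored = unsplit_last_dim(parts)
print(torch.equal(restored, y))  # True
```

This round trip only holds if the parts are passed back in the order in which split produced them.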