# Source code for scimba_torch.neural_nets.structure_preserving_nets.split_layer

"""Split layer for invertible networks."""

from __future__ import annotations

import torch


class SplittingLayer:
    """A layer that splits a tensor into multiple parts along the last dimension.

    This layer divides the input tensor into 2, 3, or 4 parts. Every part gets
    ``size // num_splits`` features; if the size is not evenly divisible, the
    remainder is added to the last part.

    Args:
        size: total dimension of the input data
        conditional_size: dimension of the conditional input data (stored on
            the instance; not used by the splitting logic of this class)
        num_splits: number of splits (2, 3, or 4)

    Raises:
        ValueError: If num_splits is not 2, 3, or 4.

    Note:
        If ``size < num_splits``, the leading parts have size 0 and only the
        last part is non-empty.

    Example:
        >>> layer = SplittingLayer(size=10, conditional_size=0, num_splits=3)
        >>> layer.split_sizes  # remainder goes to the last part
        [3, 3, 4]
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        num_splits: int = 2,
    ):
        if num_splits not in (2, 3, 4):
            raise ValueError("num_splits must be 2, 3, or 4")

        self.size = size
        self.conditional_size = conditional_size
        self.num_splits = num_splits

        # All parts share the floor-divided size; the last one also absorbs
        # the remainder so that the part sizes always sum to ``size``.
        base_size, remainder = divmod(size, num_splits)
        self.split_sizes = [base_size] * num_splits
        self.split_sizes[-1] += remainder

    def split(self, y: torch.Tensor) -> list[torch.Tensor]:
        """Splits the input tensor into multiple parts.

        Args:
            y: the input tensor of shape `(batch_size, size)`; the split is
                taken along the last dimension (``dim=-1``), whose length
                must equal ``sum(self.split_sizes)`` (torch raises otherwise).

        Returns:
            A list of ``num_splits`` tensors split along the last dimension,
            whose last-dimension sizes are ``self.split_sizes``.
        """
        # torch.split returns a tuple; convert it so the result actually
        # matches the declared ``list[torch.Tensor]`` return type.
        return list(torch.split(y, self.split_sizes, dim=-1))

    def unsplit(self, inputs: list[torch.Tensor]) -> torch.Tensor:
        """Concatenates the split tensors back together.

        This is the inverse of :meth:`split`: ``unsplit(split(y))`` rebuilds
        a tensor equal to ``y``.

        Args:
            inputs: a list of tensors to concatenate along the last dimension

        Returns:
            The concatenated tensor of shape `(batch_size, size)`
        """
        return torch.cat(inputs, dim=-1)