Source code for scimba_torch.neural_nets.structure_preserving_nets.split_layer
"""Split layer for invertible networks."""from__future__importannotationsimporttorch
class SplittingLayer:
    """A layer that splits a tensor into multiple parts along the last dimension.

    This layer divides the input tensor into 2, 3, or 4 parts. Every part gets
    ``size // num_splits`` features; if ``size`` is not evenly divisible, the
    remainder is added to the last part. If ``size < num_splits``, the leading
    parts have size 0.

    Args:
        size: total dimension of the input data
        conditional_size: dimension of the conditional input data (stored on
            the instance; not used by the splitting logic itself)
        num_splits: number of splits (2, 3, or 4)

    Raises:
        ValueError: If num_splits is not 2, 3, or 4.

    Example:
        >>> layer = SplittingLayer(size=10, conditional_size=0, num_splits=3)
        >>> layer.split_sizes  # remainder goes to the last part
        [3, 3, 4]
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        num_splits: int = 2,
    ):
        if num_splits not in (2, 3, 4):
            raise ValueError("num_splits must be 2, 3, or 4")

        self.size = size
        self.conditional_size = conditional_size
        self.num_splits = num_splits

        # All parts get base_size; the last part also absorbs the remainder,
        # so the sizes always sum exactly to ``size``.
        base_size, remainder = divmod(size, num_splits)
        self.split_sizes = [base_size] * num_splits
        self.split_sizes[-1] += remainder

    def split(self, y: torch.Tensor) -> list[torch.Tensor]:
        """Splits the input tensor into multiple parts.

        Args:
            y: the input tensor, e.g. of shape ``(batch_size, size)``; its
                last dimension must equal ``size`` (i.e. ``sum(split_sizes)``).

        Returns:
            A list of ``num_splits`` tensors split along the last dimension,
            whose last-dimension sizes are given by ``split_sizes``.
        """
        # torch.split returns a tuple; convert it so the result matches the
        # annotated ``list`` return type.
        return list(torch.split(y, self.split_sizes, dim=-1))

    def unsplit(self, inputs: list[torch.Tensor]) -> torch.Tensor:
        """Concatenates the split tensors back together.

        This is the inverse of :meth:`split`:
        ``unsplit(split(y))`` reproduces ``y``.

        Args:
            inputs: a list of tensors to concatenate along the last dimension

        Returns:
            The concatenated tensor, e.g. of shape ``(batch_size, size)``.
        """
        return torch.cat(inputs, dim=-1)