# Source code for scimba_torch.neural_nets.structure_preserving_nets.ode_splitted_layer

"""ODE-based splitted layer for invertible networks."""

from __future__ import annotations

from abc import abstractmethod

import torch

from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
    InvertibleLayer,
)


class ODESplittedLayer(InvertibleLayer):
    """Abstract base for ODE-based invertible layers acting on split tensors.

    A concrete subclass transforms a single part of a tensor that has been
    split into several parts. That transformation is conditioned on:

    - the conditional input ``mu``;
    - the remaining parts of the split tensor, selected via ``other_indices``.

    Args:
        size: Dimension of the tensor part handled by this layer.
        conditional_size: Dimension of the conditional input ``mu``.
        split_index: Zero-based index of the part transformed by this layer.
        other_indices: Indices of the other parts used as conditioning.
        **kwargs: Extra keyword arguments forwarded to ``InvertibleLayer``.
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        split_index: int,
        other_indices: list[int],
        **kwargs,
    ):
        super().__init__(size, conditional_size, **kwargs)
        # Record which part this layer transforms and which parts condition it.
        # NOTE(review): `other_indices` is stored by reference (not copied), so
        # later mutation of the caller's list is visible here.
        self.split_index, self.other_indices = split_index, other_indices

    @abstractmethod
    def forward(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
        with_last_layer: bool = True,
    ) -> torch.Tensor:
        """Apply the layer's transformation to its part of the split tensor.

        Args:
            y: Tensor part to transform, of shape ``(batch_size, size)``.
            mu: Conditional input, of shape ``(batch_size, conditional_size)``.
            other_parts: The other tensor parts used as conditioning.
            with_last_layer: Whether the last layer of the network is used.

        Returns:
            The transformed part, of shape ``(batch_size, size)``.
        """

    @abstractmethod
    def backward(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
        with_last_layer: bool = True,
    ) -> torch.Tensor:
        """Apply the inverse of the layer's transformation.

        Args:
            y: Tensor part to invert, of shape ``(batch_size, size)``.
            mu: Conditional input, of shape ``(batch_size, conditional_size)``.
            other_parts: The other tensor parts used as conditioning.
            with_last_layer: Whether the last layer of the network is used.

        Returns:
            The inverse-transformed part, of shape ``(batch_size, size)``.
        """

    @abstractmethod
    def log_abs_det_jacobian(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
    ) -> torch.Tensor:
        """Compute the log of the absolute Jacobian determinant.

        Args:
            y: Input tensor part, of shape ``(batch_size, size)``.
            mu: Conditional input, of shape ``(batch_size, conditional_size)``.
            other_parts: The other tensor parts used as conditioning.

        Returns:
            One log-absolute-determinant value per sample, as a tensor of
            shape ``(batch_size,)``.
        """