# Source code for scimba_torch.neural_nets.structure_preserving_nets.coupling_layers

"""Coupling layer for invertible networks."""

from __future__ import annotations

import random

import torch

from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
    InvertibleLayer,
)
from scimba_torch.neural_nets.structure_preserving_nets.ode_splitted_layer import (
    ODESplittedLayer,
)
from scimba_torch.neural_nets.structure_preserving_nets.split_layer import (
    SplittingLayer,
)


class CouplingLayer(InvertibleLayer):
    """A coupling layer that splits input and applies ODE-based transformations.

    This layer:

    1. Splits the input ``y`` into K parts (K = 2, 3, or 4)
    2. Creates a random permutation of indices ``[0, 1, ..., K-1]``
    3. Applies K ``ODESplittedLayer`` transformations, each processing one part
       while being conditioned on the other parts

    Args:
        size: total dimension of the input data
        conditional_size: dimension of the conditional input data
        num_splits: number of splits (K = 2, 3, or 4)
        ode_layer_type: class type of ODE_splitted_layer to use
        seed: random seed for the permutation (optional, for reproducibility)
        **kwargs: other arguments, forwarded both to the invertible layer
            and to each ``ode_layer_type`` constructor

    Raises:
        ValueError: If num_splits is not 2, 3, or 4.

    Example:
        >>> layer = CouplingLayer(
        ...     size=10,
        ...     conditional_size=5,
        ...     num_splits=3,
        ...     ode_layer_type=MyODELayer,
        ...     seed=0,
        ... )
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        num_splits: int,
        ode_layer_type: type[ODESplittedLayer],
        seed: int | None = None,
        **kwargs,
    ):
        super().__init__(size, conditional_size, **kwargs)

        if num_splits not in (2, 3, 4):
            raise ValueError("num_splits must be 2, 3, or 4")

        self.num_splits = num_splits

        # Layer that partitions the state vector into `num_splits` parts.
        self.split_layer = SplittingLayer(
            size=size, conditional_size=conditional_size, num_splits=num_splits
        )
        self.net_type = kwargs.get("net_type", GenericMLP)

        # Random permutation of [0, ..., num_splits - 1]. A dedicated RNG is
        # used so a given seed is reproducible without touching the global
        # `random` state.
        indices = list(range(num_splits))
        rng = random.Random(seed)
        rng.shuffle(indices)
        self.permutation = indices

        # One ODE layer per split part, created in permutation order.
        self.ode_layers = torch.nn.ModuleList()
        for perm_idx in self.permutation:
            # All parts except `perm_idx` condition this layer's transform.
            other_indices = [idx for idx in self.permutation if idx != perm_idx]
            layer = ode_layer_type(
                size=self.split_layer.split_sizes[perm_idx],
                conditional_size=conditional_size,
                split_sizes=self.split_layer.split_sizes,
                split_index=perm_idx,
                other_indices=other_indices,
                **kwargs,
            )
            self.ode_layers.append(layer)

    def forward(
        self, y: torch.Tensor, mu: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Forward pass through the coupling layer.

        Args:
            y: the input tensor of shape `(batch_size, size)`
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`
            with_last_layer: whether to use the last layer

        Returns:
            The transformed tensor of shape `(batch_size, size)`
        """
        # Apply the ODE layers in permutation order; each layer consumes the
        # full state produced by the previous one.
        for ode_layer, perm_idx in zip(self.ode_layers, self.permutation):
            # Split the current state into its parts.
            split_parts = list(self.split_layer.split(y))

            # Parts used only for conditioning (not transformed here).
            other_parts = [split_parts[idx] for idx in ode_layer.other_indices]

            # The ODE layer returns the full recombined vector.
            y = ode_layer.forward(
                split_parts[perm_idx], mu, other_parts, with_last_layer
            )

        return y

    def backward(
        self, y: torch.Tensor, mu: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Backward pass (inverse) through the coupling layer.

        Args:
            y: the input tensor of shape `(batch_size, size)`
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`
            with_last_layer: whether to use the last layer

        Returns:
            The inverse-transformed tensor of shape `(batch_size, size)`
        """
        # Invert by applying the ODE layers in reverse order.
        for i in reversed(range(len(self.ode_layers))):
            # Split the current state into its parts.
            split_parts = list(self.split_layer.split(y))

            ode_layer = self.ode_layers[i]
            perm_idx = self.permutation[i]

            # Parts used only for conditioning (not transformed here).
            other_parts = [split_parts[idx] for idx in ode_layer.other_indices]

            # The ODE layer's inverse returns the full recombined vector.
            y = ode_layer.backward(
                split_parts[perm_idx], mu, other_parts, with_last_layer
            )

        return y

    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute determinant of the Jacobian.

        This is the sum of the per-layer log determinants, each evaluated at
        the state produced by the preceding layers.

        Args:
            y: the input tensor of shape `(batch_size, size)`
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`

        Returns:
            The log absolute determinant as a tensor of shape `(batch_size,)`
        """
        log_det = torch.zeros(y.shape[0], device=y.device, dtype=y.dtype)

        for ode_layer, perm_idx in zip(self.ode_layers, self.permutation):
            # Split the current state into its parts.
            split_parts = list(self.split_layer.split(y))

            # Parts used only for conditioning (not transformed here).
            other_parts = [split_parts[idx] for idx in ode_layer.other_indices]

            # Accumulate this layer's contribution.
            log_det += ode_layer.log_abs_det_jacobian(
                split_parts[perm_idx], mu, other_parts
            )

            # Advance the state for the next layer (full recombined vector).
            y = ode_layer.forward(split_parts[perm_idx], mu, other_parts)

        return log_det

    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the absolute determinant of the Jacobian.

        This is the product of the determinants from each ODE layer, each
        evaluated at the state produced by the preceding layers.

        Args:
            y: the input tensor of shape `(batch_size, size)`
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`

        Returns:
            The absolute determinant as a tensor of shape `(batch_size,)`
        """
        det = torch.ones(y.shape[0], device=y.device, dtype=y.dtype)

        for ode_layer, perm_idx in zip(self.ode_layers, self.permutation):
            # Split the current state into its parts.
            split_parts = list(self.split_layer.split(y))

            # Parts used only for conditioning (not transformed here).
            other_parts = [split_parts[idx] for idx in ode_layer.other_indices]

            # Multiply in this layer's determinant.
            det *= ode_layer.abs_det_jacobian(split_parts[perm_idx], mu, other_parts)

            # Advance the state for the next layer (full recombined vector).
            y = ode_layer.forward(split_parts[perm_idx], mu, other_parts)

        return det