"""Coupling layer for invertible networks."""
from __future__ import annotations
import random
import torch
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
InvertibleLayer,
)
from scimba_torch.neural_nets.structure_preserving_nets.ode_splitted_layer import (
ODESplittedLayer,
)
from scimba_torch.neural_nets.structure_preserving_nets.split_layer import (
SplittingLayer,
)
class CouplingLayer(InvertibleLayer):
"""A coupling layer that splits input and applies ODE-based transformations.
This layer:
1. Splits the input y into K parts (K = 2, 3, or 4)
2. Creates a random permutation of indices [0, 1, ..., K-1]
3. Applies K ODE_splitted_layer transformations, each processing one part
while being conditioned on the other parts
Args:
size: total dimension of the input data
conditional_size: dimension of the conditional input data
num_splits: number of splits (K = 2, 3, or 4)
ode_layer_type: class type of ODE_splitted_layer to use
seed: random seed for the permutation (optional, for reproducibility)
**kwargs: other arguments for the invertible layer
Raises:
ValueError: If num_splits is not 2, 3, or 4.
Example:
>>> layer = CouplingLayer(
... size=10,
... conditional_size=5,
... num_splits=3,
... ode_layer_type=MyODELayer,
... ode_layer_kwargs={'hidden_dim': 64}
... )
"""
    def __init__(
        self,
        size: int,
        conditional_size: int,
        num_splits: int,
        ode_layer_type: type[ODESplittedLayer],
        seed: int | None = None,
        **kwargs,
    ):
        super().__init__(size, conditional_size, **kwargs)
        if num_splits not in (2, 3, 4):
            raise ValueError(f"num_splits must be 2, 3, or 4, got {num_splits}")
        self.num_splits = num_splits

        # Create the splitting layer
        self.split_layer = SplittingLayer(
            size=size, conditional_size=conditional_size, num_splits=num_splits
        )
        self.net_type = kwargs.get("net_type", GenericMLP)

        # Create a random permutation of the split indices
        indices = list(range(num_splits))
        rng = random.Random(seed)
        rng.shuffle(indices)
        self.permutation = indices

        # Create one ODE layer per split part, in permutation order
        self.ode_layers = torch.nn.ModuleList()
        for perm_idx in self.permutation:
            # Other indices used for conditioning (all except the current one)
            other_indices = [idx for idx in self.permutation if idx != perm_idx]
            layer = ode_layer_type(
                size=self.split_layer.split_sizes[perm_idx],
                conditional_size=conditional_size,
                split_sizes=self.split_layer.split_sizes,
                split_index=perm_idx,
                other_indices=other_indices,
                **kwargs,
            )
            self.ode_layers.append(layer)
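    # Worked illustration (hypothetical permutation): with num_splits=3 and a
    # seed that yields permutation [2, 0, 1], the first ODE layer updates part
    # 2 conditioned on parts (0, 1), the second updates part 0 conditioned on
    # (2, 1), and the third updates part 1 conditioned on (2, 0), so every
    # part is transformed exactly once per pass.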
    def forward(
        self, y: torch.Tensor, mu: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Forward pass through the coupling layer.

        Args:
            y: the input tensor of shape `(batch_size, size)`
            mu: the conditional input tensor of shape `(batch_size, conditional_size)`
            with_last_layer: whether to use the last layer

        Returns:
            The transformed tensor of shape `(batch_size, size)`
        """
        # Apply the ODE layers in permutation order
        for ode_layer, perm_idx in zip(self.ode_layers, self.permutation):
            # Split the current state
            split_parts = list(self.split_layer.split(y))
            # Get the other parts for conditioning
            other_parts = [split_parts[idx] for idx in ode_layer.other_indices]
            # Transform: the ODE layer returns the full recombined vector
            y = ode_layer.forward(
                split_parts[perm_idx], mu, other_parts, with_last_layer
            )
        return y
    def backward(
        self, y: torch.Tensor, mu: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Backward pass (inverse) through the coupling layer.

        Args:
            y: the input tensor of shape `(batch_size, size)`
            mu: the conditional input tensor of shape `(batch_size, conditional_size)`
            with_last_layer: whether to use the last layer

        Returns:
            The inverse-transformed tensor of shape `(batch_size, size)`
        """
        # Apply the ODE layers in reverse permutation order
        for ode_layer, perm_idx in zip(
            reversed(self.ode_layers), reversed(self.permutation)
        ):
            # Split the current state
            split_parts = list(self.split_layer.split(y))
            # Get the other parts for conditioning
            other_parts = [split_parts[idx] for idx in ode_layer.other_indices]
            # Inverse transform: the ODE layer returns the full recombined vector
            y = ode_layer.backward(
                split_parts[perm_idx], mu, other_parts, with_last_layer
            )
        return y
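    # A minimal round-trip sketch (illustrative; `layer` is a CouplingLayer
    # built with size=10 and conditional_size=5 as in the class docstring):
    #
    #     y = torch.randn(8, 10)
    #     mu = torch.randn(8, 5)
    #     y_rec = layer.backward(layer.forward(y, mu), mu)
    #     assert torch.allclose(y, y_rec, atol=1e-5)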
    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute determinant of the Jacobian.

        By the chain rule, this is the sum of the log-determinants of the
        successive ODE layers, each evaluated at the partially transformed state.

        Args:
            y: the input tensor of shape `(batch_size, size)`
            mu: the conditional input tensor of shape `(batch_size, conditional_size)`

        Returns:
            The log absolute determinant as a tensor of shape `(batch_size,)`
        """
        # Accumulate the log determinants along the forward composition
        log_det = torch.zeros(y.shape[0], device=y.device, dtype=y.dtype)
        for ode_layer, perm_idx in zip(self.ode_layers, self.permutation):
            # Split the current state
            split_parts = list(self.split_layer.split(y))
            # Get the other parts for conditioning
            other_parts = [split_parts[idx] for idx in ode_layer.other_indices]
            # Accumulate the log determinant of this layer
            log_det += ode_layer.log_abs_det_jacobian(
                split_parts[perm_idx], mu, other_parts
            )
            # Transform for the next iteration: the ODE layer returns the full
            # recombined vector
            y = ode_layer.forward(split_parts[perm_idx], mu, other_parts)
        return log_det
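    # Consistency sketch (illustrative): both Jacobian methods walk the same
    # forward composition, so they should agree up to floating-point error:
    #
    #     assert torch.allclose(
    #         layer.abs_det_jacobian(y, mu),
    #         torch.exp(layer.log_abs_det_jacobian(y, mu)),
    #         rtol=1e-4,
    #     )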
    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the absolute determinant of the Jacobian.

        This is the product of the determinants from each ODE layer, each
        evaluated at the partially transformed state.

        Args:
            y: the input tensor of shape `(batch_size, size)`
            mu: the conditional input tensor of shape `(batch_size, conditional_size)`

        Returns:
            The absolute determinant as a tensor of shape `(batch_size,)`
        """
        # Multiply the determinants along the forward composition
        det = torch.ones(y.shape[0], device=y.device, dtype=y.dtype)
        for ode_layer, perm_idx in zip(self.ode_layers, self.permutation):
            # Split the current state
            split_parts = list(self.split_layer.split(y))
            # Get the other parts for conditioning
            other_parts = [split_parts[idx] for idx in ode_layer.other_indices]
            # Multiply by the determinant of this layer
            det *= ode_layer.abs_det_jacobian(split_parts[perm_idx], mu, other_parts)
            # Transform for the next iteration: the ODE layer returns the full
            # recombined vector
            y = ode_layer.forward(split_parts[perm_idx], mu, other_parts)
        return det
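# Reproducibility sketch (illustrative; `MyODELayer` stands in for any concrete
# ODESplittedLayer subclass and is not defined in this module). Two layers
# built with the same seed share the same permutation, which keeps checkpoints
# consistent across runs:
#
#     a = CouplingLayer(size=10, conditional_size=5, num_splits=3,
#                       ode_layer_type=MyODELayer, seed=0)
#     b = CouplingLayer(size=10, conditional_size=5, num_splits=3,
#                       ode_layer_type=MyODELayer, seed=0)
#     assert a.permutation == b.permutation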