Source code for scimba_torch.neural_nets.structure_preserving_nets.affine_ode_layers

"""Affine and constant flow layers for invertible networks."""

from __future__ import annotations

import torch
from torch import nn

from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.neural_nets.structure_preserving_nets.ode_splitted_layer import (
    ODESplittedLayer,
)


class ConstantFlowLayer(ODESplittedLayer):
    """Constant flow layer for NICE-style transformations.

    This layer creates len(other_indices) neural networks that progressively
    incorporate information from the other split parts.

    Args:
        size: dimension of the input part to process (split_sizes[split_index])
        conditional_size: dimension of the conditional input data
        split_sizes: list of sizes for all split parts
        split_index: index of the part this layer processes
        other_indices: list of indices of other parts to use as conditioning
        **kwargs: other arguments for the neural networks
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        split_sizes: list[int],
        split_index: int,
        other_indices: list[int],
        **kwargs,
    ):
        super().__init__(size, conditional_size, split_index, other_indices, **kwargs)
        self.split_sizes = split_sizes
        self.networks = nn.ModuleList()
        self.net_type = kwargs.get("net_type", GenericMLP)

        # Create len(other_indices) networks
        for i, other_idx in enumerate(other_indices):
            # Input size accumulates: split_sizes[split_index] + conditional_size
            # + sum of previous other parts
            input_size = split_sizes[split_index] + conditional_size
            for j in range(i):
                input_size += split_sizes[other_indices[j]]

            # Output size is the size of the current other part
            output_size = split_sizes[other_idx]

            # Create the network
            network = self.net_type(in_size=input_size, out_size=output_size, **kwargs)
            self.networks.append(network)
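
    # Illustration of the size accumulation above (values are hypothetical and
    # chosen only for this example): with split_sizes = [2, 3, 4],
    # split_index = 0, other_indices = [1, 2] and conditional_size = 1, the
    # layer builds two networks,
    #   networks[0]: in_size = 2 + 1     = 3, out_size = 3
    #   networks[1]: in_size = 2 + 1 + 3 = 6, out_size = 4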

    def forward(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
        with_last_layer: bool = True,
    ) -> torch.Tensor:
        """Forward pass: x_a stays unchanged, others are shifted.

        For K=3: y_a = x_a, y_b = x_b + t(x_a, mu), y_c = x_c + t(x_a, y_b, mu)

        Args:
            y: the input tensor part (x_a), shape `(batch_size, size)`
            mu: the conditional input, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts [x_b, x_c, ...]
            with_last_layer: whether to use the last layer

        Returns:
            Recombined tensor with y and transformed other_parts
        """
        # x_a = y stays unchanged
        # We only modify other_parts in-place

        # Transform each other part progressively
        for i, network in enumerate(self.networks):
            # Build input: [x_a, mu, transformed_parts[0:i]]
            network_input = [y, mu]
            for j in range(i):
                network_input.append(other_parts[j])
            network_input = torch.cat(network_input, dim=-1)

            # Apply transformation: y_i = x_i + t_theta(inputs)
            translation = network(network_input)
            other_parts[i] = other_parts[i] + translation

        # Recombine all parts in correct order
        all_parts = [None] * (len(other_parts) + 1)
        all_parts[self.split_index] = y
        for i, idx in enumerate(self.other_indices):
            all_parts[idx] = other_parts[i]

        return torch.cat(all_parts, dim=-1)

    def backward(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
        with_last_layer: bool = True,
    ) -> torch.Tensor:
        """Backward pass: inverse transformation.

        For K=3: x_c = y_c - t(y_a, y_b, mu), x_b = y_b - t(y_a, mu), x_a = y_a

        Args:
            y: the input tensor part (y_a), shape `(batch_size, size)`
            mu: the conditional input, shape `(batch_size, conditional_size)`
            other_parts: list of transformed parts [y_b, y_c, ...]
            with_last_layer: whether to use the last layer

        Returns:
            Recombined tensor with y and inverse-transformed other_parts
        """
        # Inverse transformation: apply networks in reverse order
        for i in reversed(range(len(self.networks))):
            network = self.networks[i]

            # Build input: [y_a, mu, other_parts[0:i]]
            network_input = [y, mu]
            for j in range(i):
                network_input.append(other_parts[j])
            network_input = torch.cat(network_input, dim=-1)

            # Apply inverse transformation: x_i = y_i - t_theta(inputs)
            translation = network(network_input)
            other_parts[i] = other_parts[i] - translation

        # Recombine all parts in correct order
        all_parts = [None] * (len(other_parts) + 1)
        all_parts[self.split_index] = y
        for i, idx in enumerate(self.other_indices):
            all_parts[idx] = other_parts[i]

        return torch.cat(all_parts, dim=-1)

    def log_abs_det_jacobian(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
    ) -> torch.Tensor:
        """Log absolute determinant of Jacobian.

        For NICE-style constant flow, the Jacobian is triangular with 1s on
        the diagonal, so det = 1 and log|det| = 0.

        Args:
            y: the input tensor part, shape `(batch_size, size)`
            mu: the conditional input, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts

        Returns:
            Zeros tensor of shape `(batch_size,)`
        """
        return torch.zeros(y.shape[0], device=y.device, dtype=y.dtype)

    def abs_det_jacobian(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
    ) -> torch.Tensor:
        """Absolute determinant of Jacobian.

        For NICE-style constant flow, det = 1.

        Args:
            y: the input tensor part, shape `(batch_size, size)`
            mu: the conditional input, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts

        Returns:
            Ones tensor of shape `(batch_size,)`
        """
        return torch.ones(y.shape[0], device=y.device, dtype=y.dtype)
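

# ---------------------------------------------------------------------------
# Illustration (not part of the library): a minimal, self-contained sketch of
# the additive (NICE-style) coupling that ConstantFlowLayer implements, using
# plain ``nn.Linear`` stand-ins instead of ``GenericMLP``.  Every name defined
# below (``_additive_coupling_demo``, ``t1``, ``t2``, ...) is hypothetical and
# exists only for this example.
def _additive_coupling_demo() -> None:
    torch.manual_seed(0)
    n_a, n_b, n_c, n_mu = 2, 3, 4, 1  # split sizes and conditional size
    x_a, x_b, x_c = torch.randn(8, n_a), torch.randn(8, n_b), torch.randn(8, n_c)
    mu = torch.randn(8, n_mu)

    # Stand-ins for the translation networks built in __init__: their input
    # sizes accumulate exactly as there (x_a + mu, then x_a + mu + y_b).
    t1 = nn.Linear(n_a + n_mu, n_b)
    t2 = nn.Linear(n_a + n_mu + n_b, n_c)

    # Forward: x_a is left unchanged, the other parts are shifted progressively.
    y_a = x_a
    y_b = x_b + t1(torch.cat([x_a, mu], dim=-1))
    y_c = x_c + t2(torch.cat([x_a, mu, y_b], dim=-1))

    # Backward: undo the shifts in reverse order; note that y_b is still the
    # transformed value when y_c is inverted, as in ConstantFlowLayer.backward.
    x_c_rec = y_c - t2(torch.cat([y_a, mu, y_b], dim=-1))
    x_b_rec = y_b - t1(torch.cat([y_a, mu], dim=-1))

    assert torch.allclose(x_b_rec, x_b, atol=1e-6)
    assert torch.allclose(x_c_rec, x_c, atol=1e-6)
    # The Jacobian of (x_a, x_b, x_c) -> (y_a, y_b, y_c) is triangular with
    # ones on the diagonal, so log|det J| = 0, as log_abs_det_jacobian returns.
# ---------------------------------------------------------------------------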

class AffineFlowLayer(ODESplittedLayer):
    """Affine flow layer for RealNVP-style transformations.

    This layer applies affine transformations: y = exp(t) ⊙ x + s
    where ⊙ is element-wise multiplication, t is the log-scale and s is the
    translation.

    Args:
        size: dimension of the input part to process (split_sizes[split_index])
        conditional_size: dimension of the conditional input data
        split_sizes: list of sizes for all split parts
        split_index: index of the part this layer processes
        other_indices: list of indices of other parts to use as conditioning
        **kwargs: other arguments for the neural networks
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        split_sizes: list[int],
        split_index: int,
        other_indices: list[int],
        **kwargs,
    ):
        super().__init__(size, conditional_size, split_index, other_indices, **kwargs)
        self.split_sizes = split_sizes
        self.t_networks = nn.ModuleList()  # Networks for log-scale (t)
        self.s_networks = nn.ModuleList()  # Networks for translation (s)
        self.net_type = kwargs.get("net_type", GenericMLP)

        # Create len(other_indices) pairs of networks (one for t, one for s)
        for i, other_idx in enumerate(other_indices):
            # Input size accumulates: split_sizes[split_index] + conditional_size
            # + sum of previous other parts
            input_size = split_sizes[split_index] + conditional_size
            for j in range(i):
                input_size += split_sizes[other_indices[j]]

            # Output size is the size of the current other part
            output_size = split_sizes[other_idx]

            # Create network for log-scale (t)
            t_network = self.net_type(
                in_size=input_size, out_size=output_size, **kwargs
            )
            self.t_networks.append(t_network)

            # Create network for translation (s)
            s_network = self.net_type(
                in_size=input_size, out_size=output_size, **kwargs
            )
            self.s_networks.append(s_network)
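
    # The sizing mirrors ConstantFlowLayer: with the same hypothetical example
    # (split_sizes = [2, 3, 4], split_index = 0, other_indices = [1, 2],
    # conditional_size = 1), both t_networks and s_networks contain two nets
    # with (in_size, out_size) = (3, 3) and (6, 4).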

    def forward(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
        with_last_layer: bool = True,
    ) -> torch.Tensor:
        """Forward pass: x_a stays unchanged, others are affinely transformed.

        For K=3:
        - y_a = x_a
        - y_b = exp(t(x_a, mu)) ⊙ x_b + s(x_a, mu)
        - y_c = exp(t(x_a, y_b, mu)) ⊙ x_c + s(x_a, y_b, mu)

        Args:
            y: the input tensor part (x_a), shape `(batch_size, size)`
            mu: the conditional input, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts [x_b, x_c, ...]
            with_last_layer: whether to use the last layer

        Returns:
            Recombined tensor with y and transformed other_parts
        """
        # x_a = y stays unchanged

        # Transform each other part progressively
        for i in range(len(self.t_networks)):
            # Build input: [x_a, mu, transformed_parts[0:i]]
            network_input = [y, mu]
            for j in range(i):
                network_input.append(other_parts[j])
            network_input = torch.cat(network_input, dim=-1)

            # Compute log-scale (t) and translation (s) from separate networks
            t = self.t_networks[i](network_input)
            s = self.s_networks[i](network_input)

            # Apply affine transformation: y_i = exp(t) ⊙ x_i + s
            other_parts[i] = torch.exp(t) * other_parts[i] + s

        # Recombine all parts in correct order
        all_parts = [None] * (len(other_parts) + 1)
        all_parts[self.split_index] = y
        for i, idx in enumerate(self.other_indices):
            all_parts[idx] = other_parts[i]

        return torch.cat(all_parts, dim=-1)

    def backward(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
        with_last_layer: bool = True,
    ) -> torch.Tensor:
        """Backward pass: inverse affine transformation.

        For K=3:
        - x_c = (y_c - s(y_a, y_b, mu)) * exp(-t(y_a, y_b, mu))
        - x_b = (y_b - s(y_a, mu)) * exp(-t(y_a, mu))
        - x_a = y_a

        Args:
            y: the input tensor part (y_a), shape `(batch_size, size)`
            mu: the conditional input, shape `(batch_size, conditional_size)`
            other_parts: list of transformed parts [y_b, y_c, ...]
            with_last_layer: whether to use the last layer

        Returns:
            Recombined tensor with y and inverse-transformed other_parts
        """
        # Inverse transformation: apply networks in reverse order
        for i in reversed(range(len(self.t_networks))):
            # Build input: [y_a, mu, other_parts[0:i]]
            network_input = [y, mu]
            for j in range(i):
                network_input.append(other_parts[j])
            network_input = torch.cat(network_input, dim=-1)

            # Compute log-scale (t) and translation (s) from separate networks
            t = self.t_networks[i](network_input)
            s = self.s_networks[i](network_input)

            # Apply inverse affine transformation: x_i = (y_i - s) * exp(-t)
            other_parts[i] = (other_parts[i] - s) * torch.exp(-t)

        # Recombine all parts in correct order
        all_parts = [None] * (len(other_parts) + 1)
        all_parts[self.split_index] = y
        for i, idx in enumerate(self.other_indices):
            all_parts[idx] = other_parts[i]

        return torch.cat(all_parts, dim=-1)

    def log_abs_det_jacobian(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
    ) -> torch.Tensor:
        """Log absolute determinant of Jacobian.

        For the affine transformation y = exp(t) ⊙ x + s:
        log|det(J)| = sum(t_i), summed over all networks.

        Args:
            y: the input tensor part, shape `(batch_size, size)`
            mu: the conditional input, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts

        Returns:
            Log determinant tensor of shape `(batch_size,)`
        """
        log_det = torch.zeros(y.shape[0], device=y.device, dtype=y.dtype)

        # Accumulate log determinants from each transformation
        for i in range(len(self.t_networks)):
            # Build input: [y, mu, other_parts[0:i]]
            network_input = [y, mu]
            for j in range(i):
                network_input.append(other_parts[j])
            network_input = torch.cat(network_input, dim=-1)

            # Compute log-scale (t) from t_network
            t = self.t_networks[i](network_input)

            # log|det| = sum of all components of t
            log_det += t.sum(dim=-1)

        return log_det

    def abs_det_jacobian(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
    ) -> torch.Tensor:
        """Absolute determinant of Jacobian.

        For the affine transformation: det = exp(sum(t_i)), summed over all
        networks.

        Args:
            y: the input tensor part, shape `(batch_size, size)`
            mu: the conditional input, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts

        Returns:
            Determinant tensor of shape `(batch_size,)`
        """
        return torch.exp(self.log_abs_det_jacobian(y, mu, other_parts))
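

# ---------------------------------------------------------------------------
# Illustration (not part of the library): a minimal sketch of the affine
# (RealNVP-style) coupling implemented by AffineFlowLayer, with ``nn.Linear``
# stand-ins for the log-scale (t) and translation (s) networks.  Every name
# defined below is hypothetical and exists only for this example.
def _affine_coupling_demo() -> None:
    torch.manual_seed(0)
    n_a, n_b, n_mu = 2, 3, 1
    x_a, x_b, mu = torch.randn(5, n_a), torch.randn(5, n_b), torch.randn(5, n_mu)

    t_net = nn.Linear(n_a + n_mu, n_b)  # log-scale network
    s_net = nn.Linear(n_a + n_mu, n_b)  # translation network

    # Forward: y_a = x_a, y_b = exp(t(x_a, mu)) * x_b + s(x_a, mu)
    h = torch.cat([x_a, mu], dim=-1)
    t, s = t_net(h), s_net(h)
    y_b = torch.exp(t) * x_b + s

    # Backward: x_b = (y_b - s) * exp(-t), with t and s recomputed from (y_a, mu)
    x_b_rec = (y_b - s_net(h)) * torch.exp(-t_net(h))
    assert torch.allclose(x_b_rec, x_b, atol=1e-5)

    # log|det J| of x_b -> y_b (with x_a and mu fixed) is the row-wise sum of t,
    # matching AffineFlowLayer.log_abs_det_jacobian; check sample 0 via autograd.
    t0, s0 = t[0].detach(), s[0].detach()
    jac = torch.autograd.functional.jacobian(
        lambda xb: torch.exp(t0) * xb + s0, x_b[0]
    )
    assert torch.allclose(torch.logdet(jac), t0.sum(), atol=1e-4)
# ---------------------------------------------------------------------------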