"""Affine and constant flow layers for invertible networks."""
from __future__ import annotations
import torch
from torch import nn
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.neural_nets.structure_preserving_nets.ode_splitted_layer import (
ODESplittedLayer,
)


class ConstantFlowLayer(ODESplittedLayer):
"""Constant flow layer for NICE-style transformations.
This layer creates len(other_indices) neural networks that progressively
incorporate information from the other split parts.
Args:
size: dimension of the input part to process (split_sizes[split_index])
conditional_size: dimension of the conditional input data
split_sizes: list of sizes for all split parts
split_index: index of the part this layer processes
other_indices: list of indices of other parts to use as conditioning
**kwargs: other arguments for the neural networks
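
    Example:
        A minimal construction sketch (assumes the default ``GenericMLP``
        can be built from ``in_size``/``out_size`` alone)::

            layer = ConstantFlowLayer(
                size=2,
                conditional_size=1,
                split_sizes=[2, 3],
                split_index=0,
                other_indices=[1],
            )

        The single conditioning network then maps 2 + 1 = 3 inputs to
        3 outputs (the size of the other part).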
"""
def __init__(
self,
size: int,
conditional_size: int,
split_sizes: list[int],
split_index: int,
other_indices: list[int],
**kwargs,
):
super().__init__(size, conditional_size, split_index, other_indices, **kwargs)
self.split_sizes = split_sizes
self.networks = nn.ModuleList()
        # Pop net_type so it is not forwarded to the network constructors below
        self.net_type = kwargs.pop("net_type", GenericMLP)
# Create len(other_indices) networks
for i, other_idx in enumerate(other_indices):
# Input size accumulates: split_sizes[split_index] + conditional_size
# + sum of previous other parts
input_size = split_sizes[split_index] + conditional_size
for j in range(i):
input_size += split_sizes[other_indices[j]]
# Output size is the size of the current other part
output_size = split_sizes[other_idx]
# Create the network
network = self.net_type(in_size=input_size, out_size=output_size, **kwargs)
self.networks.append(network)

    def forward(
self,
y: torch.Tensor,
mu: torch.Tensor,
other_parts: list[torch.Tensor],
with_last_layer: bool = True,
) -> torch.Tensor:
"""Forward pass: x_a stays unchanged, others are shifted.
For K=3: y_a = x_a, y_b = x_b + t(x_a, mu), y_c = x_c + t(x_a, y_b, mu)
Args:
y: the input tensor part (x_a), shape `(batch_size, size)`
mu: the conditional input, shape `(batch_size, conditional_size)`
other_parts: list of other tensor parts [x_b, x_c, ...]
with_last_layer: whether to use the last layer
Returns:
Recombined tensor with y and transformed other_parts
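
        Example:
            A sketch of a forward call (hypothetical shapes, reusing the
            ``layer`` built in the class example)::

                x_a, x_b = torch.randn(8, 5).split([2, 3], dim=-1)
                mu = torch.randn(8, 1)
                y = layer.forward(x_a, mu, [x_b])  # shape (8, 5)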
"""
        # x_a = y stays unchanged; the entries of other_parts are overwritten
        # one by one (the caller's list is mutated)
# Transform each other part progressively
for i, network in enumerate(self.networks):
# Build input: [x_a, mu, transformed_parts[0:i]]
network_input = [y, mu]
for j in range(i):
network_input.append(other_parts[j])
network_input = torch.cat(network_input, dim=-1)
# Apply transformation: y_i = x_i + t_theta(inputs)
translation = network(network_input)
other_parts[i] = other_parts[i] + translation
# Recombine all parts in correct order
all_parts = [None] * (len(other_parts) + 1)
all_parts[self.split_index] = y
for i, idx in enumerate(self.other_indices):
all_parts[idx] = other_parts[i]
return torch.cat(all_parts, dim=-1)

    def backward(
self,
y: torch.Tensor,
mu: torch.Tensor,
other_parts: list[torch.Tensor],
with_last_layer: bool = True,
) -> torch.Tensor:
"""Backward pass: inverse transformation.
For K=3: x_c = y_c - t(y_a, y_b, mu), x_b = y_b - t(y_a, mu), x_a = y_a
Args:
y: the input tensor part (y_a), shape `(batch_size, size)`
mu: the conditional input, shape `(batch_size, conditional_size)`
other_parts: list of transformed parts [y_b, y_c, ...]
with_last_layer: whether to use the last layer
Returns:
Recombined tensor with y and inverse-transformed other_parts
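
        Example:
            Round-trip sketch, continuing the ``forward`` example (pass a
            fresh list, since ``backward`` overwrites its ``other_parts``
            argument)::

                y_a, y_b = y.split([2, 3], dim=-1)
                x_rec = layer.backward(y_a, mu, [y_b])
                assert torch.allclose(x_rec, torch.cat([x_a, x_b], dim=-1))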
"""
# Inverse transformation: apply networks in reverse order
for i in reversed(range(len(self.networks))):
network = self.networks[i]
# Build input: [y_a, mu, other_parts[0:i]]
network_input = [y, mu]
for j in range(i):
network_input.append(other_parts[j])
network_input = torch.cat(network_input, dim=-1)
# Apply inverse transformation: x_i = y_i - t_theta(inputs)
translation = network(network_input)
other_parts[i] = other_parts[i] - translation
# Recombine all parts in correct order
all_parts = [None] * (len(other_parts) + 1)
all_parts[self.split_index] = y
for i, idx in enumerate(self.other_indices):
all_parts[idx] = other_parts[i]
return torch.cat(all_parts, dim=-1)

    def log_abs_det_jacobian(
self,
y: torch.Tensor,
mu: torch.Tensor,
other_parts: list[torch.Tensor],
) -> torch.Tensor:
"""Log absolute determinant of Jacobian.
For NICE-style constant flow, the Jacobian is triangular with 1s on diagonal,
so det = 1 and log|det| = 0.
Args:
y: the input tensor part, shape `(batch_size, size)`
mu: the conditional input, shape `(batch_size, conditional_size)`
other_parts: list of other tensor parts
Returns:
Zeros tensor of shape `(batch_size,)`
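
        For K=2, for example, the Jacobian has the block form

        .. math::

            J = \begin{pmatrix} I & 0 \\ \partial t / \partial x_a & I \end{pmatrix},

        whose determinant is 1.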
"""
return torch.zeros(y.shape[0], device=y.device, dtype=y.dtype)

    def abs_det_jacobian(
self,
y: torch.Tensor,
mu: torch.Tensor,
other_parts: list[torch.Tensor],
) -> torch.Tensor:
"""Absolute determinant of Jacobian.
For NICE-style constant flow, det = 1.
Args:
y: the input tensor part, shape `(batch_size, size)`
mu: the conditional input, shape `(batch_size, conditional_size)`
other_parts: list of other tensor parts
Returns:
Ones tensor of shape `(batch_size,)`
"""
return torch.ones(y.shape[0], device=y.device, dtype=y.dtype)


class AffineFlowLayer(ODESplittedLayer):
"""Affine flow layer for RealNVP-style transformations.
This layer applies affine transformations: y = exp(t) ⊙ x + s
where ⊙ is element-wise multiplication, t is log-scale and s is translation.
Args:
size: dimension of the input part to process (split_sizes[split_index])
conditional_size: dimension of the conditional input data
split_sizes: list of sizes for all split parts
split_index: index of the part this layer processes
other_indices: list of indices of other parts to use as conditioning
**kwargs: other arguments for the neural networks
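
    Example:
        A three-part construction sketch (assumes the default ``GenericMLP``
        can be built from ``in_size``/``out_size`` alone)::

            layer = AffineFlowLayer(
                size=2,
                conditional_size=1,
                split_sizes=[2, 2, 2],
                split_index=0,
                other_indices=[1, 2],
            )

        This creates two (t, s) network pairs: the first maps 2 + 1 = 3
        inputs to 2 outputs, the second 2 + 1 + 2 = 5 inputs to 2 outputs.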
"""
def __init__(
self,
size: int,
conditional_size: int,
split_sizes: list[int],
split_index: int,
other_indices: list[int],
**kwargs,
):
super().__init__(size, conditional_size, split_index, other_indices, **kwargs)
self.split_sizes = split_sizes
self.t_networks = nn.ModuleList() # Networks for log-scale (t)
self.s_networks = nn.ModuleList() # Networks for translation (s)
        # Pop net_type so it is not forwarded to the network constructors below
        self.net_type = kwargs.pop("net_type", GenericMLP)
# Create len(other_indices) pairs of networks (one for t, one for s)
for i, other_idx in enumerate(other_indices):
# Input size accumulates: split_sizes[split_index] + conditional_size
# + sum of previous other parts
input_size = split_sizes[split_index] + conditional_size
for j in range(i):
input_size += split_sizes[other_indices[j]]
# Output size is the size of the current other part
output_size = split_sizes[other_idx]
# Create network for log-scale (t)
t_network = self.net_type(
in_size=input_size, out_size=output_size, **kwargs
)
self.t_networks.append(t_network)
# Create network for translation (s)
s_network = self.net_type(
in_size=input_size, out_size=output_size, **kwargs
)
self.s_networks.append(s_network)

    def forward(
self,
y: torch.Tensor,
mu: torch.Tensor,
other_parts: list[torch.Tensor],
with_last_layer: bool = True,
) -> torch.Tensor:
"""Forward pass: x_a stays unchanged, others are affinely transformed.
For K=3:
- y_a = x_a
- y_b = exp(t(x_a, mu)) ⊙ x_b + s(x_a, mu)
- y_c = exp(t(x_a, y_b, mu)) ⊙ x_c + s(x_a, y_b, mu)
Args:
y: the input tensor part (x_a), shape `(batch_size, size)`
mu: the conditional input, shape `(batch_size, conditional_size)`
other_parts: list of other tensor parts [x_b, x_c, ...]
with_last_layer: whether to use the last layer
Returns:
Recombined tensor with y and transformed other_parts
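
        Example:
            A sketch of a forward call (hypothetical shapes, reusing the
            three-part ``layer`` built in the class example)::

                x_a, x_b, x_c = torch.randn(8, 6).split([2, 2, 2], dim=-1)
                mu = torch.randn(8, 1)
                y = layer.forward(x_a, mu, [x_b, x_c])  # shape (8, 6)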
"""
        # x_a = y stays unchanged; the entries of other_parts are overwritten
        # one by one (the caller's list is mutated)
# Transform each other part progressively
for i in range(len(self.t_networks)):
# Build input: [x_a, mu, transformed_parts[0:i]]
network_input = [y, mu]
for j in range(i):
network_input.append(other_parts[j])
network_input = torch.cat(network_input, dim=-1)
# Compute log-scale (t) and translation (s) from separate networks
t = self.t_networks[i](network_input)
s = self.s_networks[i](network_input)
# Apply affine transformation: y_i = exp(t) ⊙ x_i + s
other_parts[i] = torch.exp(t) * other_parts[i] + s
# Recombine all parts in correct order
all_parts = [None] * (len(other_parts) + 1)
all_parts[self.split_index] = y
for i, idx in enumerate(self.other_indices):
all_parts[idx] = other_parts[i]
return torch.cat(all_parts, dim=-1)

    def backward(
self,
y: torch.Tensor,
mu: torch.Tensor,
other_parts: list[torch.Tensor],
with_last_layer: bool = True,
) -> torch.Tensor:
"""Backward pass: inverse affine transformation.
For K=3:
- x_c = (y_c - s(y_a, y_b, mu)) * exp(-t(y_a, y_b, mu))
- x_b = (y_b - s(y_a, mu)) * exp(-t(y_a, mu))
- x_a = y_a
Args:
y: the input tensor part (y_a), shape `(batch_size, size)`
mu: the conditional input, shape `(batch_size, conditional_size)`
other_parts: list of transformed parts [y_b, y_c, ...]
with_last_layer: whether to use the last layer
Returns:
Recombined tensor with y and inverse-transformed other_parts
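
        Example:
            Round-trip sketch, continuing the ``forward`` example; up to
            floating-point error, ``backward`` inverts ``forward``::

                y_a, y_b, y_c = y.split([2, 2, 2], dim=-1)
                x_rec = layer.backward(y_a, mu, [y_b, y_c])
                assert torch.allclose(
                    x_rec, torch.cat([x_a, x_b, x_c], dim=-1), atol=1e-5
                )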
"""
# Inverse transformation: apply networks in reverse order
for i in reversed(range(len(self.t_networks))):
# Build input: [y_a, mu, other_parts[0:i]]
network_input = [y, mu]
for j in range(i):
network_input.append(other_parts[j])
network_input = torch.cat(network_input, dim=-1)
# Compute log-scale (t) and translation (s) from separate networks
t = self.t_networks[i](network_input)
s = self.s_networks[i](network_input)
# Apply inverse affine transformation: x_i = (y_i - s) * exp(-t)
other_parts[i] = (other_parts[i] - s) * torch.exp(-t)
# Recombine all parts in correct order
all_parts = [None] * (len(other_parts) + 1)
all_parts[self.split_index] = y
for i, idx in enumerate(self.other_indices):
all_parts[idx] = other_parts[i]
return torch.cat(all_parts, dim=-1)

    def log_abs_det_jacobian(
self,
y: torch.Tensor,
mu: torch.Tensor,
other_parts: list[torch.Tensor],
) -> torch.Tensor:
"""Log absolute determinant of Jacobian.
For affine transformation y = exp(t) ⊙ x + s:
log|det(J)| = sum(t_i) for each network
Args:
y: the input tensor part, shape `(batch_size, size)`
mu: the conditional input, shape `(batch_size, conditional_size)`
other_parts: list of other tensor parts
Returns:
Log determinant tensor of shape `(batch_size,)`
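
        In formula form, with :math:`t_i` the log-scale output of the i-th
        network:

        .. math::

            \log \lvert \det J \rvert = \sum_{i} \sum_{d} (t_i)_d .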
"""
log_det = torch.zeros(y.shape[0], device=y.device, dtype=y.dtype)
# Accumulate log determinants from each transformation
for i in range(len(self.t_networks)):
# Build input: [y, mu, other_parts[0:i]]
network_input = [y, mu]
for j in range(i):
network_input.append(other_parts[j])
network_input = torch.cat(network_input, dim=-1)
# Compute log-scale (t) from t_network
t = self.t_networks[i](network_input)
# log|det| = sum of all components of t
log_det += t.sum(dim=-1)
return log_det

    def abs_det_jacobian(
self,
y: torch.Tensor,
mu: torch.Tensor,
other_parts: list[torch.Tensor],
) -> torch.Tensor:
"""Absolute determinant of Jacobian.
For affine transformation: det = exp(sum(t_i)) for each network
Args:
y: the input tensor part, shape `(batch_size, size)`
mu: the conditional input, shape `(batch_size, conditional_size)`
other_parts: list of other tensor parts
Returns:
Determinant tensor of shape `(batch_size,)`
"""
return torch.exp(self.log_abs_det_jacobian(y, mu, other_parts))