"""An invertible neural network made of RealNVP layers."""
from __future__ import annotations
from abc import ABC, abstractmethod
import torch
from torch import nn
from scimba_torch.neural_nets.coordinates_based_nets.scimba_module import ScimbaModule
class InvertibleLayer(ScimbaModule, ABC):
    """An abstract class for an invertible layer.

    An invertible layer maps an input ``y`` (optionally conditioned on
    ``mu``) to an output of the same dimension, and exposes the inverse
    transform together with the (log) absolute Jacobian determinant needed
    for change-of-variables computations.

    Args:
        size: dimension of the input data
        conditional_size: dimension of the conditional input data
        **kwargs: other arguments for the layer
    """

    def __init__(self, size: int, conditional_size: int, **kwargs):
        # The module consumes y concatenated with mu, hence an input width
        # of size + conditional_size, and produces a transformed y of width
        # size.
        ScimbaModule.__init__(self, size + conditional_size, size, **kwargs)
        # Dimension of the conditional input mu.
        self.conditional_size = conditional_size
        # Dimension of the transformed input y.
        self.size = size

    @abstractmethod
    def backward(self, inputs: torch.Tensor, with_last_layer: bool = True):
        """Abstract method for the backward (inverse) pass of the layer.

        NOTE(review): ``InvertibleNet.backward`` invokes this as
        ``layer.backward(y, mu)``, which would bind the conditional tensor
        to ``with_last_layer`` under this signature -- confirm that concrete
        subclasses actually accept ``(y, mu)`` and align the signatures.

        Args:
            inputs: the input tensor
            with_last_layer: whether to use the last layer of the network or not
                (default: True)
        """

    @abstractmethod
    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute value of the determinant of the Jacobian.

        This method must be implemented by all subclasses. It is the primary
        method for computing the Jacobian determinant as it is numerically
        stable.

        Args:
            y: the input tensor of shape `(batch_size, size)`.
            mu: the conditional input tensor of shape `(batch_size, conditional_size)`.

        Returns:
            The log absolute determinant of the Jacobian as a tensor of
            shape `(batch_size,)`.
        """

    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the absolute value of the determinant of the Jacobian.

        Default implementation uses ``exp(log_abs_det_jacobian)``. Subclasses
        can override this method if they have a more efficient direct
        computation.

        Args:
            y: the input tensor of shape `(batch_size, size)`.
            mu: the conditional input tensor of shape `(batch_size, conditional_size)`.

        Returns:
            The absolute determinant of the Jacobian as a tensor of
            shape `(batch_size,)`.
        """
        return torch.exp(self.log_abs_det_jacobian(y, mu))
class InvertibleNet(ScimbaModule):
    """An invertible neural network composed of multiple invertible layers.

    Args:
        size: dimension of the input data
        conditional_size: dimension of the conditional input data
        layers_list: list of invertible layers to compose
        **kwargs: other arguments for the invertible layers
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        layers_list: list[InvertibleLayer] | None = None,
        **kwargs,
    ):
        super().__init__(in_size=size + conditional_size, out_size=size)
        # Dimension of the transformed input y.
        self.size = size
        # Dimension of the conditional input mu.
        self.conditional_size = conditional_size
        # Number of composed layers; 0 when no layers are given.
        self.nb_layers = len(layers_list) if layers_list is not None else 0
        # ModuleList so the sub-layers' parameters are registered.
        self.layers = nn.ModuleList(layers_list if layers_list is not None else [])

    def forward(
        self, inputs: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Applies the forward pass of the invertible network.

        Args:
            inputs: the input tensor of shape `(batch_size, dim + p_dim)`.
            with_last_layer: whether to use the last layer of the network or not
                (default: True). NOTE(review): currently unused here; all
                layers are always applied.

        Returns:
            The output tensor of shape `(batch_size, dim + p_dim)` after applying all
            layers.
        """
        # Split the input into the transformed part y and the conditional
        # part mu; mu is passed unchanged to every layer.
        y = inputs[..., : self.size]
        mu = inputs[..., self.size :]
        for layer in self.layers:
            y = layer.forward(y, mu)
        return y

    def backward(
        self, inputs: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Applies the backward pass of the invertible network.

        Args:
            inputs: the input tensor of shape `(batch_size, dim + p_dim)`.
            with_last_layer: whether to use the last layer of the network or not.
                NOTE(review): currently unused here; all layers are always
                inverted.

        Returns:
            The output tensor of shape `(batch_size, dim + p_dim)` after applying all
            layers in reverse.
        """
        y = inputs[..., : self.size]
        mu = inputs[..., self.size :]
        # Invert the composition by applying the layer inverses in reverse
        # order.
        for layer in reversed(self.layers):
            y = layer.backward(y, mu)
        return y

    def log_abs_det_jacobian(self, inputs: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute value of the determinant of the Jacobian.

        This method is numerically more stable and is commonly used in
        log-probability computations for normalizing flows.

        Args:
            inputs: the input tensor of shape `(batch_size, dim + p_dim)`.

        Returns:
            The log absolute determinant of the Jacobian as a tensor of
            shape `(batch_size,)`.
        """
        y = inputs[..., : self.size]
        mu = inputs[..., self.size :]
        log_det_jacobian = torch.zeros(y.shape[0], device=y.device)
        for layer in self.layers:
            # Each layer's log-det is evaluated at the current y, i.e. the
            # output of the previous layer, then y is advanced (chain rule:
            # the total log-det is the sum of the per-layer log-dets).
            ldj = layer.log_abs_det_jacobian(y, mu)
            log_det_jacobian += ldj
            y = layer.forward(y, mu)
        return log_det_jacobian

    def abs_det_jacobian(self, inputs: torch.Tensor) -> torch.Tensor:
        """Computes the absolute value of the determinant of the Jacobian.

        This method is useful for change of variables in integrals.
        Uses the layer's ``abs_det_jacobian`` method if overridden, otherwise
        falls back to ``exp(log_abs_det_jacobian)``.

        Args:
            inputs: the input tensor of shape `(batch_size, dim + p_dim)`.

        Returns:
            The absolute determinant of the Jacobian as a tensor of
            shape `(batch_size,)`.
        """
        y = inputs[..., : self.size]
        mu = inputs[..., self.size :]
        det_jacobian = torch.ones(y.shape[0], device=y.device)
        for layer in self.layers:
            # Chain rule: the total determinant is the product of the
            # per-layer determinants, each evaluated at the running y.
            det = layer.abs_det_jacobian(y, mu)
            det_jacobian *= det
            y = layer.forward(y, mu)
        return det_jacobian