# Source code for scimba_torch.neural_nets.structure_preserving_nets.invertible_nn

"""An invertible neural network made of RealNVP layers."""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch
from torch import nn

from scimba_torch.neural_nets.coordinates_based_nets.scimba_module import ScimbaModule


class InvertibleLayer(ScimbaModule, ABC):
    """Abstract base class for a single invertible (bijective) layer.

    An invertible layer maps an input of dimension ``size``, optionally
    conditioned on extra data of dimension ``conditional_size``, to an output
    of the same dimension ``size``, and supports an exact inverse pass plus
    Jacobian-determinant computation.

    Args:
        size: dimension of the input data.
        conditional_size: dimension of the conditional input data.
        **kwargs: other arguments forwarded to the ``ScimbaModule`` base.
    """

    def __init__(self, size: int, conditional_size: int, **kwargs):
        # The module consumes the data and its conditioning concatenated
        # together, and produces an output of dimension `size`.
        ScimbaModule.__init__(self, size + conditional_size, size, **kwargs)
        self.conditional_size = conditional_size
        self.size = size

    @abstractmethod
    def backward(self, inputs: torch.Tensor, with_last_layer: bool = True):
        """Apply the inverse (backward) pass of the layer.

        NOTE(review): ``InvertibleNet.backward`` invokes this as
        ``layer.backward(y, mu)``, i.e. the second positional argument it
        passes is the conditional tensor, not a boolean — subclasses appear
        to override with a ``(y, mu)`` signature. Confirm before relying on
        ``with_last_layer`` here.

        Args:
            inputs: the input tensor.
            with_last_layer: whether to use the last layer of the network
                or not (default: True).
        """

    @abstractmethod
    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Compute ``log |det J|`` of the layer's forward map.

        All subclasses must implement this; it is the primary Jacobian
        quantity because working in log-space is numerically stable.

        Args:
            y: the input tensor of shape ``(batch_size, size)``.
            mu: the conditional input tensor of shape
                ``(batch_size, conditional_size)``.

        Returns:
            The log absolute determinant of the Jacobian as a tensor of
            shape ``(batch_size,)``.
        """

    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Compute ``|det J|`` of the layer's forward map.

        The default implementation exponentiates
        :meth:`log_abs_det_jacobian`; subclasses may override it when a
        cheaper direct formula exists.

        Args:
            y: the input tensor of shape ``(batch_size, size)``.
            mu: the conditional input tensor of shape
                ``(batch_size, conditional_size)``.

        Returns:
            The absolute determinant of the Jacobian as a tensor of shape
            ``(batch_size,)``.
        """
        log_det = self.log_abs_det_jacobian(y, mu)
        return log_det.exp()
class InvertibleNet(ScimbaModule):
    """An invertible neural network composed of multiple invertible layers.

    The network splits its input into a data part ``y`` (first ``size``
    components) and a conditional part ``mu`` (remaining
    ``conditional_size`` components), then threads ``y`` through the layer
    stack while passing ``mu`` unchanged to every layer.

    Args:
        size: dimension of the input data.
        conditional_size: dimension of the conditional input data.
        layers_list: list of invertible layers to compose; ``None`` (the
            default) builds an empty network.
        **kwargs: accepted for interface compatibility; currently unused.
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        layers_list: list[InvertibleLayer] | None = None,
        **kwargs,
    ):
        super().__init__(in_size=size + conditional_size, out_size=size)
        self.size = size
        self.conditional_size = conditional_size
        # Normalize the None default once so both attributes agree.
        layers = layers_list if layers_list is not None else []
        self.nb_layers = len(layers)
        self.layers = nn.ModuleList(layers)

    def forward(
        self, inputs: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Applies the forward pass of the invertible network.

        Args:
            inputs: the input tensor of shape ``(batch_size, dim + p_dim)``.
            with_last_layer: whether to use the last layer of the network
                or not (default: True). NOTE: currently not forwarded to
                the layers.

        Returns:
            The output tensor after applying all layers.
        """
        y = inputs[..., : self.size]
        mu = inputs[..., self.size :]
        for layer in self.layers:
            y = layer.forward(y, mu)
        return y

    def backward(
        self, inputs: torch.Tensor, with_last_layer: bool = True
    ) -> torch.Tensor:
        """Applies the backward (inverse) pass of the invertible network.

        Args:
            inputs: the input tensor of shape ``(batch_size, dim + p_dim)``.
            with_last_layer: whether to use the last layer of the network
                or not. NOTE: currently not forwarded to the layers.

        Returns:
            The output tensor after applying all layers in reverse order.
        """
        y = inputs[..., : self.size]
        mu = inputs[..., self.size :]
        for layer in reversed(self.layers):
            y = layer.backward(y, mu)
        return y

    def log_abs_det_jacobian(self, inputs: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute value of the determinant of the Jacobian.

        Sums each layer's contribution while propagating ``y`` forward.
        Numerically stable; commonly used in log-probability computations
        for normalizing flows.

        Args:
            inputs: the input tensor of shape ``(batch_size, dim + p_dim)``.

        Returns:
            The log absolute determinant of the Jacobian as a tensor of
            shape ``(batch_size,)``.
        """
        y = inputs[..., : self.size]
        mu = inputs[..., self.size :]
        # Match the input dtype: a default (float32) accumulator would make
        # the in-place `+=` fail for float64 inputs.
        log_det_jacobian = torch.zeros(y.shape[0], device=y.device, dtype=y.dtype)
        for layer in self.layers:
            log_det_jacobian += layer.log_abs_det_jacobian(y, mu)
            y = layer.forward(y, mu)
        return log_det_jacobian

    def abs_det_jacobian(self, inputs: torch.Tensor) -> torch.Tensor:
        """Computes the absolute value of the determinant of the Jacobian.

        Multiplies each layer's determinant while propagating ``y`` forward;
        uses the layer's ``abs_det_jacobian`` (which may be overridden for
        efficiency). Useful for change of variables in integrals.

        Args:
            inputs: the input tensor of shape ``(batch_size, dim + p_dim)``.

        Returns:
            The absolute determinant of the Jacobian as a tensor of shape
            ``(batch_size,)``.
        """
        y = inputs[..., : self.size]
        mu = inputs[..., self.size :]
        # Match the input dtype: a default (float32) accumulator would make
        # the in-place `*=` fail for float64 inputs.
        det_jacobian = torch.ones(y.shape[0], device=y.device, dtype=y.dtype)
        for layer in self.layers:
            det_jacobian *= layer.abs_det_jacobian(y, mu)
            y = layer.forward(y, mu)
        return det_jacobian