Source code for scimba_torch.neural_nets.structure_preserving_nets.nilpotent_symplectic_layer

"""Nilpotent layers for invertible neural networks."""

import torch
from torch import nn

from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
    InvertibleLayer,
)


class PSymplecticLayer(InvertibleLayer):
    """A nilpotent symplectic layer.

    This layer implements a polynomial-based nilpotent symplectic transformation:

        z = y + α'(W @ y + W_mu @ mu) * (J @ W @ y)

    where α(x) = sum_(i=0)^(deg) a_i * x^i is a polynomial, α' is its derivative,
    the product with α' is elementwise, and J is the symplectic matrix.

    Args:
        size: Total dimension of the state space.
        conditional_size: Dimension of the conditional input.
        deg: Degree of the polynomial α.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, size: int, conditional_size: int, deg: int, **kwargs):
        InvertibleLayer.__init__(self, size, conditional_size)
        self.size = size
        self.n = size // 2
        self.conditional_size = conditional_size
        self.deg = deg

        #: Linear transformation W for y
        self.W: nn.Linear = nn.Linear(size, size, bias=False)
        #: Linear transformation W_mu for mu
        self.W_mu: nn.Linear = nn.Linear(conditional_size, size, bias=False)
        #: Polynomial coefficients a_i for i = 0 to deg (initialized small)
        self.poly_coeffs: nn.Parameter = nn.Parameter(torch.randn(deg + 1) * 0.1)

        # Create the symplectic matrix J = [[0, I], [-I, 0]]
        # Note: J is antisymmetric (J^T = -J) and satisfies J @ J = -I
        J = torch.zeros(size, size)
        J[: self.n, self.n :] = torch.eye(self.n)
        J[self.n :, : self.n] = -torch.eye(self.n)
        self.register_buffer("J", J)

    def _eval_poly_derivative(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluates the derivative of the polynomial α at x.

        α'(x) = sum_(i=1)^(deg) i * a_i * x^(i-1)

        Args:
            x: Input tensor of shape `(batch_size, size)`.

        Returns:
            Derivative values of shape `(batch_size, size)`.
        """
        result = torch.zeros_like(x)
        for i in range(1, self.deg + 1):
            result = result + i * self.poly_coeffs[i] * torch.pow(x, i - 1)
        return result
    def forward(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the forward pass.

        z = y + α'(W @ y + W_mu @ mu) * (J @ W @ y)

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input tensor of shape `(batch_size, conditional_size)`.

        Returns:
            Transformed tensor of shape `(batch_size, size)`.
        """
        # Compute W @ y + W_mu @ mu
        linear_comb = self.W(y) + self.W_mu(mu)

        # Evaluate α'(W @ y + W_mu @ mu)
        alpha_prime = self._eval_poly_derivative(linear_comb)

        # Compute J @ W @ y, transposing so J multiplies each sample's vector
        Wy = self.W(y)
        JWy = (self.J @ Wy.T).T  # back to (batch_size, size)

        # Final transformation: z = y + α'(...) * (J @ W @ y)
        z = y + alpha_prime * JWy
        return z
    def backward(
        self, z: torch.Tensor, mu: torch.Tensor, max_iter: int = 2000
    ) -> torch.Tensor:
        """Computes the inverse transformation.

        No closed-form inverse is available in the general case, so the
        nilpotent structure is exploited through a damped fixed-point
        iteration on y ↦ z - α'(W @ y + W_mu @ mu) * (J @ W @ y).

        Args:
            z: Output tensor of shape `(batch_size, size)`.
            mu: Conditional input tensor of shape `(batch_size, conditional_size)`.
            max_iter: Maximum number of iterations for the fixed-point iteration.

        Returns:
            Original input tensor y of shape `(batch_size, size)`.

        Raises:
            ValueError: when the iteration exceeds max_iter without convergence.
        """
        y = z.clone()
        damping = 0.5  # Damping factor for stability

        for _ in range(max_iter):
            linear_comb = self.W(y) + self.W_mu(mu)
            alpha_prime = self._eval_poly_derivative(linear_comb)
            Wy = self.W(y)
            JWy = (self.J @ Wy.T).T

            # Damped update: y_new = (1 - damping) * y + damping * (z - α'(...) * J @ W @ y)
            y_new = (1 - damping) * y + damping * (z - alpha_prime * JWy)

            # Converged: return the fixed point instead of falling through to the error
            if torch.allclose(y_new, y, atol=1e-6):
                return y_new
            y = y_new

        raise ValueError(
            "backward of PSymplecticLayer did not converge within max_iter"
        )
    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input tensor of shape `(batch_size, conditional_size)`.

        Returns:
            Log determinant of shape `(batch_size,)`.
        """
        # Symplectic transformations preserve volume: det(Jacobian) = 1, so log|det| = 0
        return torch.zeros(y.shape[0], device=y.device)
    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input tensor of shape `(batch_size, conditional_size)`.

        Returns:
            Determinant of shape `(batch_size,)`.
        """
        # Volume preservation of the symplectic map gives |det(Jacobian)| = 1
        return torch.ones(y.shape[0], device=y.device)
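
A quick way to sanity-check _eval_poly_derivative is to evaluate α directly from the stored coefficients and differentiate it with autograd; the two derivatives must agree up to floating-point error. A minimal sketch, assuming the module is importable under the path in the title above; the sizes, degree, and batch shape are arbitrary:

import torch

from scimba_torch.neural_nets.structure_preserving_nets.nilpotent_symplectic_layer import (
    PSymplecticLayer,
)

layer = PSymplecticLayer(size=4, conditional_size=2, deg=3)

# Evaluate α(x) elementwise from the stored coefficients
x = torch.randn(8, 4, requires_grad=True)
alpha = layer.poly_coeffs[0] + sum(
    layer.poly_coeffs[i] * x**i for i in range(1, layer.deg + 1)
)

# Autograd's dα/dx should match the closed-form α'(x)
(grad,) = torch.autograd.grad(alpha.sum(), x)
assert torch.allclose(grad, layer._eval_poly_derivative(x), atol=1e-5)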
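
Calling forward is straightforward; because poly_coeffs are initialized small, the layer starts out as a modest perturbation of the identity. A sketch under the same import assumption:

import torch

from scimba_torch.neural_nets.structure_preserving_nets.nilpotent_symplectic_layer import (
    PSymplecticLayer,
)

layer = PSymplecticLayer(size=4, conditional_size=2, deg=3)
y = torch.randn(8, 4)
mu = torch.randn(8, 2)

z = layer.forward(y, mu)
print(z.shape)              # torch.Size([8, 4])
print((z - y).abs().max())  # perturbation from the identity, driven by poly_coeffs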
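
Invertibility can be checked end-to-end by pushing a batch through forward and recovering it with backward. With the small default initialization the damped iteration typically converges well within max_iter; a sketch under the same import assumption:

import torch

from scimba_torch.neural_nets.structure_preserving_nets.nilpotent_symplectic_layer import (
    PSymplecticLayer,
)

layer = PSymplecticLayer(size=4, conditional_size=2, deg=3)
y = torch.randn(8, 4)
mu = torch.randn(8, 2)

with torch.no_grad():
    z = layer.forward(y, mu)
    y_rec = layer.backward(z, mu)

# Reconstruction error should sit near the 1e-6 convergence tolerance
print((y - y_rec).abs().max())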
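
Finally, the unit-determinant value hard-coded in log_abs_det_jacobian and abs_det_jacobian can be probed numerically with torch.autograd.functional.jacobian on a single sample. A sketch comparing the autograd determinant against the layer's returned value (same import assumption; this only inspects one input point):

import torch

from scimba_torch.neural_nets.structure_preserving_nets.nilpotent_symplectic_layer import (
    PSymplecticLayer,
)

layer = PSymplecticLayer(size=4, conditional_size=2, deg=3)
y = torch.randn(1, 4)
mu = torch.randn(1, 2)

# Jacobian of the map for one sample, flattened from (1, 4, 1, 4) to (4, 4)
jac = torch.autograd.functional.jacobian(lambda v: layer.forward(v, mu), y)
jac = jac.reshape(4, 4)

print(torch.det(jac))                 # compare against the claimed unit determinant
print(layer.abs_det_jacobian(y, mu))  # tensor([1.])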