"""Nilpotent layers for invertible neural networks."""
import torch
from torch import nn
from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
InvertibleLayer,
)
class PSymplecticLayer(InvertibleLayer):
"""A nilpotent symplectic layer.
This layer implements a polynomial-based nilpotent symplectic transformation:
z = y + α'(W*y + W_mu*mu) * J @ W @ y
where α(x) = sum_(i=0)^(deg) a_i * x^i is a polynomial, α' is its derivative,
and J is the symplectic matrix.
Args:
size: Total dimension of the state space.
conditional_size: Dimension of the conditional input.
deg: Degree of the polynomial α.
**kwargs: Additional keyword arguments.
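
    Example:
        A minimal sketch (weights are randomly initialized, so only shapes
        are checked)::

            >>> layer = PSymplecticLayer(size=4, conditional_size=2, deg=3)
            >>> z = layer.forward(torch.randn(8, 4), torch.randn(8, 2))
            >>> z.shape
            torch.Size([8, 4])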
"""
def __init__(self, size: int, conditional_size: int, deg: int, **kwargs):
InvertibleLayer.__init__(self, size, conditional_size)
self.size = size
self.n = size // 2
self.conditional_size = conditional_size
self.deg = deg
#: Linear transformation W for y
self.W: nn.Linear = nn.Linear(size, size, bias=False)
#: Linear transformation W_mu for mu
self.W_mu: nn.Linear = nn.Linear(conditional_size, size, bias=False)
#: Polynomial coefficients a_i for i=0 to deg (initialized small)
self.poly_coeffs: nn.Parameter = nn.Parameter(torch.randn(deg + 1) * 0.1)
        # Create the canonical symplectic matrix J = [[0, I], [-I, 0]]
        # Note: J is antisymmetric (J^T = -J), orthogonal (J^T @ J = I),
        # and satisfies J @ J = -I
        J = torch.zeros(size, size)
        J[: self.n, self.n :] = torch.eye(self.n)
        J[self.n :, : self.n] = -torch.eye(self.n)
        self.register_buffer("J", J)
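        # Illustration: for size = 4, J maps a state (q1, q2, p1, p2)
        # to (p1, p2, -q1, -q2)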

    def _eval_poly_derivative(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluates the derivative of the polynomial α at x.

        α'(x) = sum_{i=1}^{deg} i * a_i * x^(i-1)

        Args:
            x: Input tensor of shape `(batch_size, size)`.

        Returns:
            Derivative values of shape `(batch_size, size)`.
        """
result = torch.zeros_like(x)
for i in range(1, self.deg + 1):
result = result + i * self.poly_coeffs[i] * torch.pow(x, i - 1)
return result

    def forward(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the forward pass.

        z = y + α'(W y + W_mu mu) * (J W y)

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input tensor of shape `(batch_size, conditional_size)`.

        Returns:
            Transformed tensor of shape `(batch_size, size)`.
        """
        # Compute W @ y once; it appears in both the polynomial argument
        # and the update direction
        Wy = self.W(y)
        # Evaluate α'(W y + W_mu mu)
        alpha_prime = self._eval_poly_derivative(Wy + self.W_mu(mu))
        # Batched J @ W @ y: (J @ Wy^T)^T == Wy @ J^T
        JWy = Wy @ self.J.T
        # Final transformation: z = y + α'(W y + W_mu mu) * (J W y)
        z = y + alpha_prime * JWy
        return z

    def backward(
        self, z: torch.Tensor, mu: torch.Tensor, max_iter: int = 2000
    ) -> torch.Tensor:
        """Computes the inverse transformation.

        A closed-form inverse is not available in the general case, so the
        inverse is computed with a damped fixed-point iteration that exploits
        the nilpotent structure of the transformation.

        Args:
            z: Output tensor of shape `(batch_size, size)`.
            mu: Conditional input tensor of shape `(batch_size, conditional_size)`.
            max_iter: Maximum number of iterations for the fixed-point iteration.

        Returns:
            Original input tensor y of shape `(batch_size, size)`.

        Raises:
            ValueError: If the iteration does not converge within max_iter steps.
        """
        # Invert z = y + α'(W y + W_mu mu) * (J W y) with a damped
        # fixed-point iteration for stability
        y = z.clone()
        damping = 0.5  # Damping factor for stability
        for _ in range(max_iter):
            Wy = self.W(y)
            alpha_prime = self._eval_poly_derivative(Wy + self.W_mu(mu))
            # Batched J @ W @ y, as in forward
            JWy = Wy @ self.J.T
            # Damped update: y_new = (1 - damping) * y + damping * (z - alpha_prime * JWy)
            y_new = (1 - damping) * y + damping * (z - alpha_prime * JWy)
            # Converged when successive iterates agree
            if torch.allclose(y_new, y, atol=1e-6):
                return y_new
            y = y_new
        raise ValueError(
            "backward of PSymplecticLayer did not converge within max_iter"
        )

    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input tensor of shape `(batch_size, conditional_size)`.

        Returns:
            Log determinant of shape `(batch_size,)`.
        """
        # A symplectic transformation has Jacobian determinant 1,
        # so the log absolute determinant is 0
        return torch.zeros(y.shape[0], device=y.device)

    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input tensor of shape `(batch_size, conditional_size)`.

        Returns:
            Determinant of shape `(batch_size,)`.
        """
        # The Jacobian determinant of a symplectic transformation is 1
        return torch.ones(y.shape[0], device=y.device)
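

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; not part of the documented API):
    # apply the layer, invert it with the fixed-point iteration, and confirm
    # that the symplectic log-determinant is zero.
    torch.manual_seed(0)
    layer = PSymplecticLayer(size=4, conditional_size=2, deg=3)
    y = torch.randn(8, 4)
    mu = torch.randn(8, 2)
    with torch.no_grad():
        z = layer.forward(y, mu)
        y_rec = layer.backward(z, mu)
    print("max |y - backward(forward(y))|:", (y - y_rec).abs().max().item())
    print("log|det J| per sample:", layer.log_abs_det_jacobian(y, mu))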