scimba_torch.neural_nets.structure_preserving_nets.nilpotent_symplectic_layer
Nilpotent layers for invertible neural networks.
Classes
- PSymplecticLayer – A nilpotent symplectic layer.
- class PSymplecticLayer(size, conditional_size, deg, **kwargs)
Bases: InvertibleLayer
A nilpotent symplectic layer.
This layer implements a polynomial-based nilpotent symplectic transformation:
z = y + α'(W*y + W_mu*mu) * J @ W @ y
where α(x) = Σ_{i=0}^{deg} a_i * x^i is a polynomial of degree deg, α' is its derivative, and J is the symplectic matrix.
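The polynomial α and its derivative α' can be evaluated directly from the coefficients a_0, ..., a_deg. The sketch below does this in NumPy (rather than PyTorch, to stay dependency-light) and also builds J, assuming the common block convention J = [[0, I], [-I, 0]] for an even state dimension. The helper names are hypothetical and are not part of scimba_torch.

```python
import numpy as np


def alpha(x, coeffs):
    """Evaluate alpha(x) = sum_i coeffs[i] * x**i elementwise (coeffs[i] = a_i)."""
    return sum(a * x**i for i, a in enumerate(coeffs))


def alpha_prime(x, coeffs):
    """Evaluate alpha'(x) = sum_{i>=1} i * coeffs[i] * x**(i - 1) elementwise."""
    return sum(i * a * x ** (i - 1) for i, a in enumerate(coeffs) if i > 0)


def symplectic_matrix(size):
    """Return J = [[0, I], [-I, 0]] (assumed convention) for an even size."""
    if size % 2 != 0:
        raise ValueError("a symplectic matrix needs an even dimension")
    n = size // 2
    eye, zero = np.eye(n), np.zeros((n, n))
    return np.block([[zero, eye], [-eye, zero]])


coeffs = [0.1, 0.2, 0.3]          # a_0, a_1, a_2, so deg = 2
print(alpha(2.0, coeffs))         # 0.1 + 0.2*2 + 0.3*2**2
print(alpha_prime(2.0, coeffs))   # 0.2 + 2*0.3*2
J = symplectic_matrix(4)
print(np.allclose(J @ J, -np.eye(4)))  # J squared is -I
```

Under this convention J is antisymmetric and squares to minus the identity, which is why it only exists for an even state dimension.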
- Parameters:
size (int) – Total dimension of the state space.
conditional_size (int) – Dimension of the conditional input.
deg (int) – Degree of the polynomial α.
**kwargs – Additional keyword arguments.
- W: nn.Linear
Linear transformation W applied to y.
- W_mu: nn.Linear
Linear transformation W_mu applied to mu.
- poly_coeffs: nn.Parameter
Polynomial coefficients a_i for i = 0 to deg (initialized to small values).
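The formula fixes the shapes of these attributes: J @ W @ y only type-checks if W maps R^size to R^size, and adding W_mu*mu to W*y requires W_mu to map R^conditional_size to R^size. The hypothetical NumPy stand-ins below (not the library's code) record those shapes. The scale 1e-2, the zero bias, and the use of a single bias for both linear maps are illustrative assumptions; the page only says the coefficients are initialized small.

```python
import numpy as np


def init_params(size, conditional_size, deg, scale=1e-2, seed=0):
    """Hypothetical stand-ins for the layer's W, W_mu and poly_coeffs."""
    rng = np.random.default_rng(seed)
    return {
        # W: R^size -> R^size, so that J @ W @ y has shape (size,)
        "W": scale * rng.standard_normal((size, size)),
        # one bias standing in for the biases of both linear maps (assumption)
        "b": np.zeros(size),
        # W_mu: R^conditional_size -> R^size, so W*y + W_mu*mu is well defined
        "W_mu": scale * rng.standard_normal((size, conditional_size)),
        # a_0, ..., a_deg, drawn small
        "poly_coeffs": scale * rng.standard_normal(deg + 1),
    }


params = init_params(size=4, conditional_size=2, deg=3)
for name, value in params.items():
    print(name, value.shape)
```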
- forward(y, mu)
Computes the forward pass:
z = y + α'(W*y + W_mu*mu) * J @ W @ y
- Parameters:
y (Tensor) – Input tensor of shape (batch_size, size).
mu (Tensor) – Conditional input tensor of shape (batch_size, conditional_size).
- Return type:
Tensor
- Returns:
Transformed tensor of shape (batch_size, size).
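A literal NumPy reading of the forward formula for a batch of row vectors is sketched below. It assumes that W is a (size, size) matrix whose bias b enters only the argument of α', that W_mu is a (size, conditional_size) matrix without bias, that '*' is elementwise, and that J = [[0, I], [-I, 0]]. These are assumptions about the implementation, the function names are hypothetical, and the sketch does not check that the resulting map is symplectic.

```python
import numpy as np


def symplectic_matrix(size):
    # Assumed convention J = [[0, I], [-I, 0]]; size must be even.
    n = size // 2
    eye, zero = np.eye(n), np.zeros((n, n))
    return np.block([[zero, eye], [-eye, zero]])


def alpha_prime(x, coeffs):
    # Derivative of alpha(x) = sum_i coeffs[i] * x**i, elementwise.
    return sum(i * a * x ** (i - 1) for i, a in enumerate(coeffs) if i > 0)


def forward_sketch(y, mu, W, b, W_mu, coeffs):
    """z = y + alpha'(W y + b + W_mu mu) * (J @ W @ y) for each row of y.

    y: (batch_size, size), mu: (batch_size, conditional_size).
    """
    J = symplectic_matrix(y.shape[1])
    u = y @ W.T + b + mu @ W_mu.T   # argument of alpha', shape (batch_size, size)
    v = y @ W.T @ J.T               # row k equals J @ W @ y[k]
    return y + alpha_prime(u, coeffs) * v


rng = np.random.default_rng(0)
y = rng.standard_normal((8, 4))
mu = rng.standard_normal((8, 2))
W = 0.1 * rng.standard_normal((4, 4))
W_mu = 0.1 * rng.standard_normal((4, 2))
z = forward_sketch(y, mu, W, np.zeros(4), W_mu, coeffs=[0.0, 0.1, 0.05])
print(z.shape)  # (8, 4)
```

Writing the batch as rows means J @ W @ y becomes y @ W.T @ J.T, which keeps the whole batch in two matrix products instead of a Python loop.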
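- backward(z, mu, max_iter=2000)
Computes the inverse transformation, recovering y from z.
In the general case the inverse is not available in closed form, because y appears on both sides of y = z - α'(W*y + W_mu*mu) * J @ W @ y. This method computes it iteratively (bounded by max_iter); for nilpotent transformations, the nilpotent structure can alternatively be used to compute the inverse.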
- Parameters:
z (Tensor) – Output tensor of shape (batch_size, size).
mu (Tensor) – Conditional input tensor of shape (batch_size, conditional_size).
max_iter (int) – Maximum number of fixed-point iterations (default: 2000).
- Return type:
Tensor
- Returns:
Original input tensor y of shape (batch_size, size).
- Raises:
ValueError – If the fixed-point iteration does not converge within max_iter iterations.
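Since y must satisfy y = z - α'(W*y + W_mu*mu) * J @ W @ y, one natural inversion scheme is the fixed-point iteration y_{k+1} = z - α'(W*y_k + W_mu*mu) * J @ W @ y_k, started from y_0 = z. The NumPy sketch below implements that iteration for a literal reading of the formula (square W with bias b in the argument of α', W_mu without bias, elementwise '*', J = [[0, I], [-I, 0]]). The tolerance tol, the stopping rule, and all function names are hypothetical; convergence is only expected when the weights and coefficients are small enough for the update to be a contraction.

```python
import numpy as np


def symplectic_matrix(size):
    # Assumed convention J = [[0, I], [-I, 0]]; size must be even.
    n = size // 2
    eye, zero = np.eye(n), np.zeros((n, n))
    return np.block([[zero, eye], [-eye, zero]])


def alpha_prime(x, coeffs):
    # Derivative of alpha(x) = sum_i coeffs[i] * x**i, elementwise.
    return sum(i * a * x ** (i - 1) for i, a in enumerate(coeffs) if i > 0)


def update_term(y, mu, W, b, W_mu, coeffs):
    # alpha'(W y + b + W_mu mu) * (J @ W @ y), computed row by row.
    J = symplectic_matrix(y.shape[1])
    return alpha_prime(y @ W.T + b + mu @ W_mu.T, coeffs) * (y @ W.T @ J.T)


def forward_sketch(y, mu, W, b, W_mu, coeffs):
    return y + update_term(y, mu, W, b, W_mu, coeffs)


def backward_sketch(z, mu, W, b, W_mu, coeffs, max_iter=2000, tol=1e-12):
    """Recover y from z by iterating y <- z - update_term(y)."""
    y = z.copy()  # initial guess y_0 = z
    for _ in range(max_iter):
        y_next = z - update_term(y, mu, W, b, W_mu, coeffs)
        if np.max(np.abs(y_next - y)) < tol:
            return y_next
        y = y_next
    raise ValueError(f"no convergence within max_iter={max_iter} iterations")


rng = np.random.default_rng(0)
y = rng.standard_normal((8, 4))
mu = rng.standard_normal((8, 2))
W = 0.1 * rng.standard_normal((4, 4))
W_mu = 0.1 * rng.standard_normal((4, 2))
b = np.zeros(4)
coeffs = [0.0, 0.1, 0.05, 0.02]
z = forward_sketch(y, mu, W, b, W_mu, coeffs)
y_rec = backward_sketch(z, mu, W, b, W_mu, coeffs)
print(np.max(np.abs(y_rec - y)))  # reconstruction error, expected to be tiny
```

Starting from y_0 = z is cheap and close to the solution when the update term is small, which matches the documented choice of initializing the polynomial coefficients small.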
- log_abs_det_jacobian(y, mu)
Computes the log absolute determinant of the Jacobian of the forward transformation with respect to y.
- Parameters:
y (Tensor) – Input tensor of shape (batch_size, size).
mu (Tensor) – Conditional input tensor of shape (batch_size, conditional_size).
- Return type:
Tensor
- Returns:
Log absolute determinant of shape (batch_size,).
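For a single sample, differentiating a literal reading of the forward formula (u = W y + b + W_mu mu, v = J W y, z = y + α'(u) * v) gives dz/dy = I + diag(α''(u) * v) W + diag(α'(u)) J W, whose log absolute determinant can be taken with numpy.linalg.slogdet. The NumPy sketch below computes this per row and checks it against central finite differences. The reading of the formula (square W with bias b, W_mu without bias, elementwise '*', J = [[0, I], [-I, 0]]) and all helper names are assumptions, not the library's implementation.

```python
import numpy as np


def symplectic_matrix(size):
    # Assumed convention J = [[0, I], [-I, 0]]; size must be even.
    n = size // 2
    eye, zero = np.eye(n), np.zeros((n, n))
    return np.block([[zero, eye], [-eye, zero]])


def alpha_prime(x, coeffs):
    # alpha'(x) = sum_{i>=1} i * a_i * x**(i - 1); zeros_like keeps array output
    return sum((i * a * x ** (i - 1) for i, a in enumerate(coeffs) if i > 0), np.zeros_like(x))


def alpha_second(x, coeffs):
    # alpha''(x) = sum_{i>=2} i * (i - 1) * a_i * x**(i - 2)
    return sum((i * (i - 1) * a * x ** (i - 2) for i, a in enumerate(coeffs) if i > 1), np.zeros_like(x))


def forward_row(y, mu, W, b, W_mu, coeffs):
    # z = y + alpha'(u) * v for a single sample y of shape (size,)
    J = symplectic_matrix(y.shape[0])
    return y + alpha_prime(W @ y + b + W_mu @ mu, coeffs) * (J @ W @ y)


def log_abs_det_analytic(y, mu, W, b, W_mu, coeffs):
    """log|det dz/dy| per row, using dz/dy = I + diag(a''(u) v) W + diag(a'(u)) J W."""
    J = symplectic_matrix(y.shape[1])
    out = []
    for y_k, mu_k in zip(y, mu):
        u = W @ y_k + b + W_mu @ mu_k
        v = J @ W @ y_k
        jac = (np.eye(len(y_k))
               + np.diag(alpha_second(u, coeffs) * v) @ W
               + np.diag(alpha_prime(u, coeffs)) @ J @ W)
        out.append(np.linalg.slogdet(jac)[1])  # slogdet returns (sign, log|det|)
    return np.array(out)


def log_abs_det_fd(y, mu, W, b, W_mu, coeffs, eps=1e-6):
    """Same quantity from a central finite-difference Jacobian, for checking."""
    out = []
    for y_k, mu_k in zip(y, mu):
        size = len(y_k)
        jac = np.empty((size, size))
        for j in range(size):
            e = np.zeros(size)
            e[j] = eps
            jac[:, j] = (forward_row(y_k + e, mu_k, W, b, W_mu, coeffs)
                         - forward_row(y_k - e, mu_k, W, b, W_mu, coeffs)) / (2 * eps)
        out.append(np.linalg.slogdet(jac)[1])
    return np.array(out)


rng = np.random.default_rng(0)
y = rng.standard_normal((5, 4))
mu = rng.standard_normal((5, 2))
W = 0.3 * rng.standard_normal((4, 4))
W_mu = 0.3 * rng.standard_normal((4, 2))
b = np.zeros(4)
coeffs = [0.0, 0.2, 0.1, 0.05]
ld = log_abs_det_analytic(y, mu, W, b, W_mu, coeffs)
print(ld.shape)  # (5,)
```

slogdet is used instead of log(abs(det(...))) because it avoids overflow and underflow in the determinant itself.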