# Source code for scimba_torch.neural_nets.structure_preserving_nets.coupling_symplectic_layers

"""Symplectic coupling layers for structure-preserving neural networks."""

import torch

from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
    InvertibleLayer,
)
from scimba_torch.neural_nets.structure_preserving_nets.separated_symplectic_layers import (  # noqa: E501
    ActivationSymplecticLayer,
    GradPotentialSymplecticLayer,
    LinearSymplecticLayer,
    PeriodicGradPotentialSymplecticLayer,
)
from scimba_torch.neural_nets.structure_preserving_nets.split_layer import (
    SplittingLayer,
)


class GSymplecticLayer(InvertibleLayer):
    """A G-symplectic coupling layer.

    Chains two gradient-potential coupling maps: each one updates half of
    the state while the other half stays fixed, so the composite map keeps
    the symplectic structure (and hence is volume preserving).

    Args:
        size: Total dimension of the state space (will be split in half).
        conditional_size: Dimension of the conditional input.
        width: Width of the internal layers.
        **kwargs: Additional keyword arguments forwarded to the sub-layers.
    """

    def __init__(self, size: int, conditional_size: int, width: int, **kwargs):
        InvertibleLayer.__init__(self, size, conditional_size)
        self.n = size // 2
        self.width = width
        # Two gradient-potential sub-layers; the second one swaps the roles
        # of the two half-states relative to the first.
        self.G1layer = GradPotentialSymplecticLayer(
            size=self.n,
            conditional_size=conditional_size,
            width=width,
            **kwargs,
        )
        self.G2layer = GradPotentialSymplecticLayer(
            size=self.n,
            conditional_size=conditional_size,
            width=width,
            **kwargs,
        )
        self.split_layer = SplittingLayer(size=size, conditional_size=0, num_splits=2)

    def forward(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Applies the forward coupling transformation.

        Args:
            y: Input tensor of shape `(batch_size, size)`; the splitter
                yields the two half-states (q, p).
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Transformed tensor of shape `(batch_size, size)`.
        """
        q0, p0 = self.split_layer.split(y)
        # First coupling acts as (q, p); the second swaps the argument order.
        q1, p1 = self.G1layer.forward(q0, p0, mu)
        p2, q2 = self.G2layer.forward(p1, q1, mu)
        return self.split_layer.unsplit([q2, p2])

    def backward(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Applies the inverse coupling transformation.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Inverse transformed tensor of shape `(batch_size, size)`.
        """
        # Undo the forward pass in reverse order: y holds [q_2, p_2].
        q2, p2 = self.split_layer.split(y)
        # Inverse of G2: (p_2, q_2) -> (p_1, q_1).
        p1, q1 = self.G2layer.backward(p2, q2, mu)
        # Inverse of G1: (q_1, p_1) -> (q, p).
        q0, p0 = self.G1layer.backward(q1, p1, mu)
        return self.split_layer.unsplit([q0, p0])

    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Log determinant of shape `(batch_size,)`.
        """
        # Symplectic maps preserve volume: |det J| = 1, so its log is 0.
        return torch.zeros((y.shape[0],), device=y.device)

    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Determinant of shape `(batch_size,)`, identically 1.
        """
        return torch.ones((y.shape[0],), device=y.device)
class PeriodicGSymplecticLayer(InvertibleLayer):
    """A periodic G-symplectic coupling layer.

    Chains two gradient-potential coupling maps, each updating one half of
    the state while the other is held fixed, preserving the symplectic
    structure. The first sub-layer uses a periodic potential; the second
    one does not.

    Args:
        size: Total dimension of the state space (will be split in half).
        conditional_size: Dimension of the conditional input.
        width: Width of the internal layers.
        period: The period of the potential.
        **kwargs: Additional keyword arguments forwarded to the sub-layers.
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        width: int,
        period: torch.Tensor,
        **kwargs,
    ):
        InvertibleLayer.__init__(self, size, conditional_size)
        self.n = size // 2
        self.width = width
        self.period = period
        # Periodic potential on the first half-state, plain potential on the
        # second; G2 swaps the roles of the two halves relative to G1.
        self.G1layer = PeriodicGradPotentialSymplecticLayer(
            size=self.n,
            conditional_size=conditional_size,
            width=width,
            period=self.period,
            **kwargs,
        )
        self.G2layer = GradPotentialSymplecticLayer(
            size=self.n,
            conditional_size=conditional_size,
            width=width,
            **kwargs,
        )
        self.split_layer = SplittingLayer(size=size, conditional_size=0, num_splits=2)

    def forward(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Applies the forward coupling transformation.

        Args:
            y: Input tensor of shape `(batch_size, size)`; the splitter
                yields the two half-states (q, p).
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Transformed tensor of shape `(batch_size, size)`.
        """
        q0, p0 = self.split_layer.split(y)
        # Periodic coupling first, then the plain one with swapped arguments.
        q1, p1 = self.G1layer.forward(q0, p0, mu)
        p2, q2 = self.G2layer.forward(p1, q1, mu)
        return self.split_layer.unsplit([q2, p2])

    def backward(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Applies the inverse coupling transformation.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Inverse transformed tensor of shape `(batch_size, size)`.
        """
        # Undo the forward pass in reverse order: y holds [q_2, p_2].
        q2, p2 = self.split_layer.split(y)
        # Inverse of G2: (p_2, q_2) -> (p_1, q_1).
        p1, q1 = self.G2layer.backward(p2, q2, mu)
        # Inverse of G1: (q_1, p_1) -> (q, p).
        q0, p0 = self.G1layer.backward(q1, p1, mu)
        return self.split_layer.unsplit([q0, p0])

    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Log determinant of shape `(batch_size,)`.
        """
        # Symplectic maps preserve volume: |det J| = 1, so its log is 0.
        return torch.zeros((y.shape[0],), device=y.device)

    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Determinant of shape `(batch_size,)`, identically 1.
        """
        return torch.ones((y.shape[0],), device=y.device)
class LASymplecticLayer(InvertibleLayer):
    """A LA-symplectic coupling layer.

    Applies an alternating Linear / Activation sequence of symplectic
    coupling maps (Linear, Activation, Linear, Activation), preserving the
    symplectic structure.

    Args:
        size: Total dimension of the state space (will be split in half).
        conditional_size: Dimension of the conditional input.
        width: Width of the internal layers (stored; not passed to the
            sub-layers here).
        **kwargs: Additional keyword arguments forwarded to the sub-layers.
    """

    def __init__(self, size: int, conditional_size: int, width: int, **kwargs):
        InvertibleLayer.__init__(self, size, conditional_size)
        self.n = size // 2
        self.width = width
        # Two linear/activation pairs operating on the half-states.
        self.Linear1layer = LinearSymplecticLayer(
            size=self.n,
            conditional_size=conditional_size,
            **kwargs,
        )
        self.Activation1layer = ActivationSymplecticLayer(
            size=self.n,
            conditional_size=conditional_size,
            **kwargs,
        )
        self.Linear2layer = LinearSymplecticLayer(
            size=self.n,
            conditional_size=conditional_size,
            **kwargs,
        )
        self.Activation2layer = ActivationSymplecticLayer(
            size=self.n,
            conditional_size=conditional_size,
            **kwargs,
        )
        self.split_layer = SplittingLayer(size=size, num_splits=2)

    def forward(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Applies the forward coupling transformation.

        Args:
            y: Input tensor of shape `(batch_size, size)` containing (p, q).
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Transformed tensor of shape `(batch_size, size)`.
        """
        # Each stage re-splits the running state; the boolean flags whether
        # the half-states are fed in swapped (q, p) order for that stage.
        stages = (
            (self.Linear1layer, False),
            (self.Activation1layer, True),
            (self.Linear2layer, True),
            (self.Activation2layer, False),
        )
        for layer, swapped in stages:
            p, q = self.split_layer.split(y)
            y = layer.forward(q, p, mu) if swapped else layer.forward(p, q, mu)
        return y

    def backward(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Applies the inverse coupling transformation.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Inverse transformed tensor of shape `(batch_size, size)`.
        """
        # Invert the forward stages in reverse order; the swap pattern is
        # palindromic, so the flags match the forward pass stage-for-stage.
        stages = (
            (self.Activation2layer, False),
            (self.Linear2layer, True),
            (self.Activation1layer, True),
            (self.Linear1layer, False),
        )
        for layer, swapped in stages:
            p, q = self.split_layer.split(y)
            y = layer.backward(q, p, mu) if swapped else layer.backward(p, q, mu)
        return y

    def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the log absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Log determinant of shape `(batch_size,)`.
        """
        # Symplectic maps preserve volume: |det J| = 1, so its log is 0.
        return torch.zeros((y.shape[0],), device=y.device)

    def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
        """Computes the absolute determinant of the Jacobian.

        Args:
            y: Input tensor of shape `(batch_size, size)`.
            mu: Conditional input of shape `(batch_size, conditional_size)`.

        Returns:
            Determinant of shape `(batch_size,)`, identically 1.
        """
        return torch.ones((y.shape[0],), device=y.device)