Source code for scimba_torch.neural_nets.structure_preserving_nets.separated_symplectic_layers

"""Defines the SympNet class for symplectic neural networks."""

import torch
from torch import nn

from scimba_torch.neural_nets.coordinates_based_nets.activation import (
    activation_function,
)
from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
    InvertibleLayer,
)


class LinearSymplecticLayer(InvertibleLayer):
    """Linear symplectic shear acting on the momentum tensor.

    Updates the momentum :code:`p` with a symmetrized linear transformation of
    the position :code:`q` plus a linear transformation of the conditional
    input :code:`mu`, while leaving :code:`q` unchanged. Such shears are
    symplectic, hence volume-preserving, and serve as building blocks for
    symplectic neural network architectures.

    Args:
        size: Total dimension of the state space (will be split into p and q).
        conditional_size: Dimension of the conditional input tensor.
        **kwargs: Additional keyword arguments (unused by this layer).
    """

    def __init__(self, size: int, conditional_size: int, **kwargs):
        InvertibleLayer.__init__(self, size, conditional_size)
        self.n = size // 2
        #: Linear transformation for the `q` input tensor.
        self.linear_q: nn.Linear = nn.Linear(self.n, self.n, bias=False)
        #: Linear transformation for the conditional input `mu`.
        self.linear_mu: nn.Linear = nn.Linear(conditional_size, self.n)
    def forward(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the forward pass.

        The momentum is shifted by ``q @ (W + W^T) + linear_mu(mu)``, where
        ``W`` is the weight of ``linear_q``; the symmetrization makes the
        shear symplectic. The position is returned unchanged.

        Args:
            p: The momentum tensor.
            q: The position tensor.
            mu: The conditional input tensor.

        Returns:
            The concatenated tensor ``(p_new, q)``.
        """
        # linear_q(q) = q @ W^T; adding q @ W applies the symmetric matrix W + W^T
        p = p + self.linear_q(q) + q @ self.linear_q.weight + self.linear_mu(mu)
        return torch.cat((p, q), dim=-1)
    def backward(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the backward (inverse) pass.

        This method inverts the transformation applied in the forward pass.

        Args:
            p: The momentum tensor of dimension `(batch_size, n)`.
            q: The position tensor of dimension `(batch_size, n)`.
            mu: The conditional input tensor of dimension
                `(batch_size, conditional_size)`.

        Returns:
            The concatenated tensor of the original `(p, q)`.
        """
        # Forward:  p_new = p + q @ (W + W^T) + linear_mu(mu)
        # Backward: p_old = p_new - q @ (W + W^T) - linear_mu(mu)
        p = p - self.linear_q(q) - q @ self.linear_q.weight - self.linear_mu(mu)
        return torch.cat((p, q), dim=-1)
    def log_abs_det_jacobian(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the log absolute value of the determinant of the Jacobian.

        Args:
            p: the momentum tensor of shape `(batch_size, n)`.
            q: the position tensor of shape `(batch_size, n)`.
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`.

        Returns:
            The log absolute determinant of the Jacobian as a tensor of shape
            `(batch_size,)`.
        """
        return torch.zeros(p.shape[0], device=p.device)
    def abs_det_jacobian(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the absolute value of the determinant of the Jacobian.

        Args:
            p: the momentum tensor of shape `(batch_size, n)`.
            q: the position tensor of shape `(batch_size, n)`.
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`.

        Returns:
            The absolute determinant of the Jacobian as a tensor of shape
            `(batch_size,)`.
        """
        return torch.ones(p.shape[0], device=p.device)
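A minimal round-trip sketch (not part of the library) illustrating the layer above; it assumes the module path shown in the page header and calls `forward`/`backward` directly:

import torch

from scimba_torch.neural_nets.structure_preserving_nets.separated_symplectic_layers import (
    LinearSymplecticLayer,
)

torch.manual_seed(0)
layer = LinearSymplecticLayer(size=4, conditional_size=2)  # n = size // 2 = 2

p = torch.randn(8, 2)   # momenta
q = torch.randn(8, 2)   # positions
mu = torch.randn(8, 2)  # conditional inputs

out = layer.forward(p, q, mu)          # shape (8, 4): cat((p_new, q), dim=-1)
p_new, q_new = out[:, :2], out[:, 2:]

# backward undoes forward, so the original momenta are recovered
rec = layer.backward(p_new, q_new, mu)
assert torch.allclose(rec[:, :2], p, atol=1e-5)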
class ActivationSymplecticLayer(InvertibleLayer):
    """Activation-based symplectic shear acting on the momentum tensor.

    Updates the momentum :code:`p` by adding a learnable elementwise scaling
    of an activation of the position :code:`q`, while leaving :code:`q`
    unchanged. The shear is symplectic and therefore volume-preserving.

    Args:
        size: Total dimension of the state space (will be split into p and q).
        conditional_size: Dimension of the conditional input tensor.
        **kwargs: Additional keyword arguments. The activation function type
            can be passed via the keyword argument ``activation_type``
            (e.g., "tanh", "relu").
    """

    def __init__(self, size: int, conditional_size: int, **kwargs):
        InvertibleLayer.__init__(self, size, conditional_size)
        self.n = size // 2
        #: Learnable elementwise scaling applied to the activated `q` tensor.
        self.linear_a: nn.Linear = nn.Linear(self.n, 1, bias=False)
        self.act = activation_function(kwargs.get("activation_type", "tanh"), **kwargs)
    def forward(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the forward pass.

        The momentum is shifted by ``a * act(q)``, where ``a`` is the
        learnable weight of ``linear_a``; the position is returned unchanged.

        Args:
            p: The momentum tensor.
            q: The position tensor.
            mu: The conditional input tensor (unused in this layer).

        Returns:
            The concatenated tensor ``(p_new, q)``.
        """
        p = p + self.linear_a.weight.squeeze() * self.act(q)
        return torch.cat((p, q), dim=-1)
    def backward(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the backward (inverse) pass.

        This method inverts the transformation applied in the forward pass.

        Args:
            p: The momentum tensor of dimension `(batch_size, n)`.
            q: The position tensor of dimension `(batch_size, n)`.
            mu: The conditional input tensor of dimension
                `(batch_size, conditional_size)`.

        Returns:
            The concatenated tensor of the original `(p, q)`.
        """
        # Forward:  p_new = p + a * act(q)
        # Backward: p_old = p_new - a * act(q)
        p = p - self.linear_a.weight.squeeze() * self.act(q)
        return torch.cat((p, q), dim=-1)
    def log_abs_det_jacobian(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the log absolute value of the determinant of the Jacobian.

        Args:
            p: the momentum tensor of shape `(batch_size, n)`.
            q: the position tensor of shape `(batch_size, n)`.
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`.

        Returns:
            The log absolute determinant of the Jacobian as a tensor of shape
            `(batch_size,)`.
        """
        # The Jacobian is [[I, diag(a) * act'(q)], [0, I]], which has determinant 1
        return torch.zeros(p.shape[0], device=p.device)
    def abs_det_jacobian(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the absolute value of the determinant of the Jacobian.

        Args:
            p: the momentum tensor of shape `(batch_size, n)`.
            q: the position tensor of shape `(batch_size, n)`.
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`.

        Returns:
            The absolute determinant of the Jacobian as a tensor of shape
            `(batch_size,)`.
        """
        # Symplectic transformations preserve volume, determinant = 1
        return torch.ones(p.shape[0], device=p.device)
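A sketch (not part of the library) of the activation shear in use; it assumes the same import path as above and passes the activation type through ``activation_type`` as the constructor does:

import torch

from scimba_torch.neural_nets.structure_preserving_nets.separated_symplectic_layers import (
    ActivationSymplecticLayer,
)

torch.manual_seed(0)
layer = ActivationSymplecticLayer(size=6, conditional_size=1, activation_type="tanh")

p, q = torch.randn(4, 3), torch.randn(4, 3)
mu = torch.randn(4, 1)  # accepted for interface uniformity, unused by this layer

out = layer.forward(p, q, mu)
p_new, q_new = out[:, :3], out[:, 3:]
assert torch.equal(q_new, q)  # the position block passes through unchanged

rec = layer.backward(p_new, q_new, mu)
assert torch.allclose(rec[:, :3], p, atol=1e-5)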
class GradPotentialSymplecticLayer(InvertibleLayer):
    """Symplectic shear driven by the gradient of a learned potential.

    The momentum :code:`p` is updated by a term of gradient-of-potential form,
    ``(scaling(mu) * activation(W q + linear_mu(mu))) @ W``, built from a
    one-hidden-layer network of the position :code:`q` conditioned on
    :code:`mu`; the position is left unchanged. The module is used to model
    potential gradients in neural network architectures, especially in
    problems involving structured data.

    Args:
        size: Dimension of the position tensor :code:`q` (and of the momentum
            tensor :code:`p`).
        conditional_size: Dimension of the conditional input tensor.
        width: Width of the internal layers (i.e., the number of units in the
            hidden layers).
        **kwargs: Additional keyword arguments. The activation function type
            can be passed via the keyword argument ``activation``
            (e.g., "tanh", "relu").
    """

    def __init__(self, size: int, conditional_size: int, width: int, **kwargs):
        InvertibleLayer.__init__(self, size, conditional_size)
        self.width = width
        #: Linear transformation for the `q` input tensor.
        self.linear_q: nn.Linear = nn.Linear(size, width, bias=False)
        #: Linear transformation for the conditional input `mu`.
        self.linear_mu: nn.Linear = nn.Linear(conditional_size, width)
        #: Activation function type (e.g., 'tanh') applied to the sum of the
        #: linear transformations.
        self.activation_type: str = kwargs.get("activation", "tanh")
        #: Linear scaling transformation for the hidden layer, conditioned on `mu`.
        self.scaling: nn.Linear = nn.Linear(conditional_size, width)
        #: Activation function applied to the sum of the linear transformations.
        self.activation = activation_function(self.activation_type, **kwargs)
    def forward(
        self, q: torch.Tensor, p: torch.Tensor, mu: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Computes the forward pass.

        The momentum is shifted by ``(scaling(mu) * activation(W q +
        linear_mu(mu))) @ W``, where ``W`` is the weight of ``linear_q``; the
        position is returned unchanged.

        Args:
            q: The position tensor.
            p: The momentum tensor.
            mu: The conditional input tensor.

        Returns:
            The tuple ``(q, p_new)``.
        """
        p_int = self.activation(self.linear_q(q) + self.linear_mu(mu))
        p = p + (self.scaling(mu) * p_int) @ self.linear_q.weight
        return q, p
    def backward(
        self, q: torch.Tensor, p: torch.Tensor, mu: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Computes the backward (inverse) pass.

        This method inverts the transformation applied in the forward pass.

        Args:
            q: The position tensor of dimension `(batch_size, n)`.
            p: The momentum tensor of dimension `(batch_size, n)`.
            mu: The conditional input tensor of dimension
                `(batch_size, conditional_size)`.

        Returns:
            The tuple of the original ``(q, p)``.
        """
        # Forward:  p_new = p + (scaling(mu) * activation(...)) @ linear_q.weight
        # Backward: p_old = p_new - (scaling(mu) * activation(...)) @ linear_q.weight
        p_int = self.activation(self.linear_q(q) + self.linear_mu(mu))
        p = p - (self.scaling(mu) * p_int) @ self.linear_q.weight
        return q, p
    def log_abs_det_jacobian(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the log absolute value of the determinant of the Jacobian.

        Args:
            p: the momentum tensor of shape `(batch_size, n)`.
            q: the position tensor of shape `(batch_size, n)`.
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`.

        Returns:
            The log absolute determinant of the Jacobian as a tensor of shape
            `(batch_size,)`.
        """
        # The Jacobian is [[I, grad_function], [0, I]], which has determinant 1
        return torch.zeros(p.shape[0], device=p.device)
    def abs_det_jacobian(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the absolute value of the determinant of the Jacobian.

        Args:
            p: the momentum tensor of shape `(batch_size, n)`.
            q: the position tensor of shape `(batch_size, n)`.
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`.

        Returns:
            The absolute determinant of the Jacobian as a tensor of shape
            `(batch_size,)`.
        """
        # Symplectic transformations preserve volume, determinant = 1
        return torch.ones(p.shape[0], device=p.device)
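A sketch (not part of the library) highlighting that, unlike the two layers above, this layer takes arguments in ``(q, p, mu)`` order and returns a tuple rather than a concatenated tensor:

import torch

from scimba_torch.neural_nets.structure_preserving_nets.separated_symplectic_layers import (
    GradPotentialSymplecticLayer,
)

torch.manual_seed(0)
layer = GradPotentialSymplecticLayer(size=2, conditional_size=1, width=8)

q, p = torch.randn(5, 2), torch.randn(5, 2)
mu = torch.randn(5, 1)

# forward returns the tuple (q, p_new); q passes through unchanged
q_new, p_new = layer.forward(q, p, mu)
_, p_rec = layer.backward(q_new, p_new, mu)
assert torch.allclose(p_rec, p, atol=1e-5)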
class PeriodicGradPotentialSymplecticLayer(InvertibleLayer):
    """Symplectic shear driven by the gradient of a learned periodic potential.

    The position :code:`q` is embedded through cosine and sine features of
    period :code:`period`, so the resulting momentum update is periodic in
    :code:`q`; the position itself is left unchanged. The module is used to
    model periodic potential gradients in neural network architectures,
    especially in problems involving structured data.

    Args:
        size: Dimension of the position tensor :code:`q` (and of the momentum
            tensor :code:`p`).
        conditional_size: Dimension of the conditional input tensor.
        width: Width of the internal layers (i.e., the number of units in the
            hidden layers).
        period: The period of the potential.
        **kwargs: Additional keyword arguments. The activation function type
            can be passed via the keyword argument ``activation``
            (e.g., "tanh", "relu").
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        width: int,
        period: torch.Tensor,
        **kwargs,
    ):
        InvertibleLayer.__init__(self, size, conditional_size)
        self.width = width
        #: Linear transformations for the cosine and sine features of `q`.
        self.linear_q1: nn.Linear = nn.Linear(size, width, bias=False)
        self.linear_q2: nn.Linear = nn.Linear(size, width, bias=False)
        self.b1 = nn.Parameter(torch.zeros(width))
        self.b2 = nn.Parameter(torch.zeros(width))
        #: Linear transformation for the conditional input `mu`.
        self.linear_mu: nn.Linear = nn.Linear(conditional_size, width)
        #: Activation function type (e.g., 'tanh') applied to the sum of the
        #: linear transformations.
        self.activation_type: str = kwargs.get("activation", "tanh")
        #: Linear scaling transformation for the hidden layer, conditioned on `mu`.
        self.scaling: nn.Linear = nn.Linear(conditional_size, width)
        #: Activation function applied to the sum of the linear transformations.
        self.activation = activation_function(self.activation_type, **kwargs)
        #: Period of the potential.
        self.L = period
    def forward(
        self, q: torch.Tensor, p: torch.Tensor, mu: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Computes the forward pass.

        The position is mapped to the periodic features ``cos(2*pi*q/L)`` and
        ``sin(2*pi*q/L)``; the momentum is shifted by a gradient-of-potential
        term built from these features, and the position is returned
        unchanged.

        Args:
            q: The position tensor.
            p: The momentum tensor.
            mu: The conditional input tensor.

        Returns:
            The tuple ``(q, p_new)``.
        """
        phase = 2 * torch.pi * q / self.L
        cos_phase = torch.cos(phase)
        sin_phase = torch.sin(phase)
        z1 = self.linear_q1(cos_phase) + self.b1
        z2 = self.linear_q2(sin_phase) + self.b2  # (batch_size, width)
        p_int1 = self.activation(z1 + self.linear_mu(mu))
        p_int2 = self.activation(z2 + self.linear_mu(mu))
        # Chain rule through the embedding: d/dq sin ~ cos, d/dq cos ~ -sin
        result = ((self.scaling(mu) * p_int2) @ self.linear_q2.weight) * cos_phase
        result -= ((self.scaling(mu) * p_int1) @ self.linear_q1.weight) * sin_phase
        p = p + result
        return q, p
    def backward(
        self, q: torch.Tensor, p: torch.Tensor, mu: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Computes the backward (inverse) pass.

        This method inverts the transformation applied in the forward pass.

        Args:
            q: The position tensor of dimension `(batch_size, n)`.
            p: The momentum tensor of dimension `(batch_size, n)`.
            mu: The conditional input tensor of dimension
                `(batch_size, conditional_size)`.

        Returns:
            The tuple of the original ``(q, p)``.
        """
        # Forward:  p_new = p + result(q, mu)
        # Backward: p_old = p_new - result(q, mu)
        phase = 2 * torch.pi * q / self.L
        cos_phase = torch.cos(phase)
        sin_phase = torch.sin(phase)
        z1 = self.linear_q1(cos_phase) + self.b1
        z2 = self.linear_q2(sin_phase) + self.b2  # (batch_size, width)
        p_int1 = self.activation(z1 + self.linear_mu(mu))
        p_int2 = self.activation(z2 + self.linear_mu(mu))
        result = ((self.scaling(mu) * p_int2) @ self.linear_q2.weight) * cos_phase
        result -= ((self.scaling(mu) * p_int1) @ self.linear_q1.weight) * sin_phase
        p = p - result
        return q, p
    def log_abs_det_jacobian(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the log absolute value of the determinant of the Jacobian.

        Args:
            p: the momentum tensor of shape `(batch_size, n)`.
            q: the position tensor of shape `(batch_size, n)`.
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`.

        Returns:
            The log absolute determinant of the Jacobian as a tensor of shape
            `(batch_size,)`.
        """
        # The Jacobian is [[I, grad_function], [0, I]], which has determinant 1
        return torch.zeros(p.shape[0], device=p.device)
    def abs_det_jacobian(
        self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
    ) -> torch.Tensor:
        """Computes the absolute value of the determinant of the Jacobian.

        Args:
            p: the momentum tensor of shape `(batch_size, n)`.
            q: the position tensor of shape `(batch_size, n)`.
            mu: the conditional input tensor of shape
                `(batch_size, conditional_size)`.

        Returns:
            The absolute determinant of the Jacobian as a tensor of shape
            `(batch_size,)`.
        """
        # Symplectic transformations preserve volume, determinant = 1
        return torch.ones(p.shape[0], device=p.device)
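A sketch (not part of the library) checking the defining property of the periodic layer: since the update depends on ``q`` only through ``cos(2*pi*q/L)`` and ``sin(2*pi*q/L)``, shifting ``q`` by one period leaves the momentum kick unchanged (up to floating-point error):

import torch

from scimba_torch.neural_nets.structure_preserving_nets.separated_symplectic_layers import (
    PeriodicGradPotentialSymplecticLayer,
)

torch.manual_seed(0)
L = torch.tensor(2.0)
layer = PeriodicGradPotentialSymplecticLayer(
    size=3, conditional_size=1, width=8, period=L
)

q = torch.randn(5, 3)
p = torch.randn(5, 3)
mu = torch.randn(5, 1)

_, p1 = layer.forward(q, p, mu)
_, p2 = layer.forward(q + L, p, mu)  # one full period: same phase features
assert torch.allclose(p1, p2, atol=1e-5)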