"""Defines the SympNet class for symplectic neural networks."""
import torch
from torch import nn
from scimba_torch.neural_nets.coordinates_based_nets.activation import (
activation_function,
)
from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
InvertibleLayer,
)
class LinearSymplecticLayer(InvertibleLayer):
"""Combines a linear transformation on two input tensors :code:`y` and :code:`p`.
Applies an activation function, scales the result based on :code:`p`,
and returns a matrix product of the transformed tensors.
The module is used to model potential gradients in neural network architectures,
especially in problems involving structured data.
Args:
size: Total dimension of the state space (will be split into p and q).
conditional_size: Dimension of the conditional input tensor.
**kwargs: Additional keyword arguments. The activation function type can be
passed as a keyword argument (e.g., "tanh", "relu").
"""
def __init__(self, size: int, conditional_size: int, **kwargs):
InvertibleLayer.__init__(self, size, conditional_size)
self.n = size // 2
        #: Linear transformation for the position tensor `q`.
self.linear_q: nn.Linear = nn.Linear(self.n, self.n, bias=False)
        #: Linear transformation for the conditional tensor `mu`.
self.linear_mu: nn.Linear = nn.Linear(conditional_size, self.n)
def forward(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the forward pass.
This method combines the transformations of the input tensors,
applies an activation function, scales the result,
and returns the matrix product.
Args:
p: The momentum tensor.
q: The position tensor.
mu: The conditional input tensor.
Returns:
The output tensor after applying the transformation and scaling.
"""
p = p + self.linear_q(q) + q @ self.linear_q.weight + self.linear_mu(mu)
return torch.cat((p, q), dim=-1)
def backward(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the backward (inverse) pass.
This method inverts the transformation applied in the forward pass.
Args:
p: The momentum tensor of dimension `(batch_size, n)`.
q: The position tensor of dimension `(batch_size, n)`.
mu: The conditional input tensor of dimension
`(batch_size, conditional_size)`.
Returns:
The concatenated tensor of the original `(p, q)`.
"""
        # Forward:  p_new = p + q @ (W^T + W) + linear_mu(mu)
        # Backward: p_old = p_new - q @ (W^T + W) - linear_mu(mu)
p = p - self.linear_q(q) - q @ self.linear_q.weight - self.linear_mu(mu)
return torch.cat((p, q), dim=-1)
def log_abs_det_jacobian(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the log absolute value of the determinant of the Jacobian.
Args:
p: the momentum tensor of shape `(batch_size, n)`.
q: the position tensor of shape `(batch_size, n)`.
mu: the conditional input tensor of shape `(batch_size, conditional_size)`.
Returns:
The log absolute determinant of the Jacobian as a tensor of
shape `(batch_size,)`.
"""
return torch.zeros(p.shape[0], device=p.device)
def abs_det_jacobian(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the absolute value of the determinant of the Jacobian.
Args:
p: the momentum tensor of shape `(batch_size, n)`.
q: the position tensor of shape `(batch_size, n)`.
mu: the conditional input tensor of shape `(batch_size, conditional_size)`.
Returns:
The absolute determinant of the Jacobian as a tensor of
shape `(batch_size,)`.
"""
return torch.ones(p.shape[0], device=p.device)
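# A minimal usage sketch (illustrative, not part of the library API): checks
# that `backward` exactly inverts `forward` and that the position component
# passes through unchanged. Batch size and dimensions are arbitrary choices.
def _demo_linear_symplectic_layer() -> None:
    torch.manual_seed(0)
    layer = LinearSymplecticLayer(size=4, conditional_size=3)
    p, q = torch.randn(8, 2), torch.randn(8, 2)
    mu = torch.randn(8, 3)
    out = layer.forward(p, q, mu)  # concatenated (p_new, q), shape (8, 4)
    p_new, q_out = out[..., :2], out[..., 2:]
    back = layer.backward(p_new, q_out, mu)
    assert torch.allclose(back[..., :2], p, atol=1e-5)  # momentum restored
    assert torch.equal(q_out, q)  # the position is never modified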
class ActivationSymplecticLayer(InvertibleLayer):
"""Combines a linear transformation on two input tensors :code:`y` and :code:`p`.
Applies an activation function, scales the result based on :code:`p`,
and returns a matrix product of the transformed tensors.
The module is used to model potential gradients in neural network architectures,
especially in problems involving structured data.
Args:
size: Total dimension of the state space (will be split into p and q).
conditional_size: Dimension of the conditional input tensor.
**kwargs: Additional keyword arguments. The activation function type can be
passed as a keyword argument (e.g., "tanh", "relu").
"""
def __init__(self, size: int, conditional_size: int, **kwargs):
InvertibleLayer.__init__(self, size, conditional_size)
self.n = size // 2
        #: Learnable elementwise scaling vector `a`, stored as a bias-free
        #: linear layer.
self.linear_a: nn.Linear = nn.Linear(self.n, 1, bias=False)
self.act = activation_function(kwargs.get("activation_type", "tanh"), **kwargs)
def forward(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the forward pass.
This method combines the transformations of the input tensors,
applies an activation function, scales the result,
and returns the matrix product.
Args:
p: The momentum tensor.
q: The position tensor.
mu: The conditional input tensor.
Returns:
The output tensor after applying the transformation and scaling.
"""
p = p + self.linear_a.weight.squeeze() * self.act(q)
return torch.cat((p, q), dim=-1)
def backward(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the backward (inverse) pass.
This method inverts the transformation applied in the forward pass.
Args:
p: The momentum tensor of dimension `(batch_size, n)`.
q: The position tensor of dimension `(batch_size, n)`.
mu: The conditional input tensor of dimension
`(batch_size, conditional_size)`.
Returns:
The concatenated tensor of the original `(p, q)`.
"""
        # Forward:  p_new = p + a * act(q)  (elementwise)
        # Backward: p_old = p_new - a * act(q)
p = p - self.linear_a.weight.squeeze() * self.act(q)
return torch.cat((p, q), dim=-1)
def log_abs_det_jacobian(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the log absolute value of the determinant of the Jacobian.
Args:
p: the momentum tensor of shape `(batch_size, n)`.
q: the position tensor of shape `(batch_size, n)`.
mu: the conditional input tensor of shape `(batch_size, conditional_size)`.
Returns:
The log absolute determinant of the Jacobian as a tensor of
shape `(batch_size,)`.
"""
        # The Jacobian is [[I, diag(a * act'(q))], [0, I]], which has determinant 1
return torch.zeros(p.shape[0], device=p.device)
def abs_det_jacobian(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the absolute value of the determinant of the Jacobian.
Args:
p: the momentum tensor of shape `(batch_size, n)`.
q: the position tensor of shape `(batch_size, n)`.
mu: the conditional input tensor of shape `(batch_size, conditional_size)`.
Returns:
The absolute determinant of the Jacobian as a tensor of
shape `(batch_size,)`.
"""
# Symplectic transformations preserve volume, determinant = 1
return torch.ones(p.shape[0], device=p.device)
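# A minimal usage sketch (illustrative, not part of the library API): the
# activation layer's shift depends only on q, so `backward` is an exact
# inverse, and the Jacobian determinant is 1 for every sample.
def _demo_activation_symplectic_layer() -> None:
    torch.manual_seed(0)
    layer = ActivationSymplecticLayer(size=4, conditional_size=3)
    p, q = torch.randn(8, 2), torch.randn(8, 2)
    mu = torch.randn(8, 3)  # part of the interface, unused by this layer
    out = layer.forward(p, q, mu)
    back = layer.backward(out[..., :2], out[..., 2:], mu)
    assert torch.allclose(back[..., :2], p, atol=1e-5)  # momentum restored
    assert torch.all(layer.abs_det_jacobian(p, q, mu) == 1.0)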
class GradPotentialSymplecticLayer(InvertibleLayer):
"""Combines a linear transformation on two input tensors :code:`y` and :code:`p`.
Applies an activation function, scales the result based on :code:`p`,
and returns a matrix product of the transformed tensors.
The module is used to model potential gradients in neural network architectures,
especially in problems involving structured data.
Args:
size: Total dimension of the state space.
conditional_size: Dimension of the conditional input tensor.
width: Width of the internal layers (i.e., the number of units in the hidden
layers).
**kwargs: Additional keyword arguments. The activation function type can be
passed as a keyword argument (e.g., "tanh", "relu").
"""
def __init__(self, size: int, conditional_size: int, width: int, **kwargs):
InvertibleLayer.__init__(self, size, conditional_size)
self.width = width
        #: Linear transformation for the position tensor `q`.
self.linear_q: nn.Linear = nn.Linear(size, width, bias=False)
        #: Linear transformation for the conditional tensor `mu`.
self.linear_mu: nn.Linear = nn.Linear(conditional_size, width)
        #: Name of the activation function (e.g., "tanh") used for the hidden
        #: features.
self.activation_type: str = kwargs.get("activation", "tanh")
        #: Linear map producing the conditional scaling of the hidden units.
self.scaling: nn.Linear = nn.Linear(conditional_size, width)
#: Activation function applied to the sum of the linear transformations.
self.activation = activation_function(self.activation_type, **kwargs)
    def forward(
        self, q: torch.Tensor, p: torch.Tensor, mu: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes the forward pass.
This method combines the transformations of the input tensors,
applies an activation function, scales the result,
and returns the matrix product.
Args:
q: The position tensor.
p: The momentum tensor.
mu: The conditional input tensor.
Returns:
The output tensor after applying the transformation and scaling.
"""
p_int = self.activation(self.linear_q(q) + self.linear_mu(mu))
p = p + (self.scaling(mu) * p_int) @ self.linear_q.weight
return q, p
    def backward(
        self, q: torch.Tensor, p: torch.Tensor, mu: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes the backward (inverse) pass.
This method inverts the transformation applied in the forward pass.
Args:
q: The position tensor of dimension `(batch_size, n)`.
p: The momentum tensor of dimension `(batch_size, n)`.
mu: The conditional input tensor of dimension
`(batch_size, conditional_size)`.
Returns:
The concatenated tensor of the original `(p, q)`.
"""
# Forward: p_new = p + (scaling(mu) * activation(...)) @ linear_q.weight
# Backward: p_old = p_new - (scaling(mu) * activation(...)) @ linear_q.weight
p_int = self.activation(self.linear_q(q) + self.linear_mu(mu))
p = p - (self.scaling(mu) * p_int) @ self.linear_q.weight
return q, p
def log_abs_det_jacobian(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the log absolute value of the determinant of the Jacobian.
Args:
            p: the momentum tensor of shape `(batch_size, size)`.
            q: the position tensor of shape `(batch_size, size)`.
mu: the conditional input tensor of shape `(batch_size, conditional_size)`.
Returns:
The log absolute determinant of the Jacobian as a tensor of
shape `(batch_size,)`.
"""
        # The Jacobian is [[I, 0], [Hess_q V, I]], which has determinant 1
return torch.zeros(p.shape[0], device=p.device)
def abs_det_jacobian(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the absolute value of the determinant of the Jacobian.
Args:
            p: the momentum tensor of shape `(batch_size, size)`.
            q: the position tensor of shape `(batch_size, size)`.
mu: the conditional input tensor of shape `(batch_size, conditional_size)`.
Returns:
The absolute determinant of the Jacobian as a tensor of
shape `(batch_size,)`.
"""
# Symplectic transformations preserve volume, determinant = 1
return torch.ones(p.shape[0], device=p.device)
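# A minimal usage sketch (illustrative, not part of the library API): note
# that this layer feeds the full `size`-dimensional q through `linear_q`
# rather than splitting the state, and returns a (q, p) tuple.
def _demo_grad_potential_symplectic_layer() -> None:
    torch.manual_seed(0)
    layer = GradPotentialSymplecticLayer(size=2, conditional_size=3, width=16)
    q, p = torch.randn(8, 2), torch.randn(8, 2)
    mu = torch.randn(8, 3)
    q_new, p_new = layer.forward(q, p, mu)
    _, p_back = layer.backward(q_new, p_new, mu)
    assert torch.equal(q_new, q)  # position passes through unchanged
    assert torch.allclose(p_back, p, atol=1e-5)  # exact inverse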
class PeriodicGradPotentialSymplecticLayer(InvertibleLayer):
"""Combines a linear transformation on two input tensors :code:`y` and :code:`p`.
Applies an activation function, scales the result based on :code:`p`,
and returns a matrix product of the transformed tensors.
The module is used to model periodic potential gradients
in neural network architectures,
especially in problems involving structured data.
Args:
size: Total dimension of the state space.
conditional_size: Dimension of the conditional input tensor.
width: Width of the internal layers (i.e., the number of units in the hidden
layers).
period: The period of the potential.
**kwargs: Additional keyword arguments. The activation function type can be
passed as a keyword argument (e.g., "tanh", "relu").
"""
def __init__(
self,
size: int,
conditional_size: int,
width: int,
period: torch.Tensor,
**kwargs,
):
InvertibleLayer.__init__(self, size, conditional_size)
self.width = width
        #: Linear transformations for the cosine and sine phases of `q`.
self.linear_q1: nn.Linear = nn.Linear(size, width, bias=False)
self.linear_q2: nn.Linear = nn.Linear(size, width, bias=False)
self.b1 = nn.Parameter(torch.zeros(width))
self.b2 = nn.Parameter(torch.zeros(width))
        #: Linear transformation for the conditional tensor `mu`.
self.linear_mu: nn.Linear = nn.Linear(conditional_size, width)
        #: Name of the activation function (e.g., "tanh") used for the hidden
        #: features.
self.activation_type: str = kwargs.get("activation", "tanh")
        #: Linear map producing the conditional scaling of the hidden units.
self.scaling: nn.Linear = nn.Linear(conditional_size, width)
#: Activation function applied to the sum of the linear transformations.
self.activation = activation_function(self.activation_type, **kwargs)
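        #: Period of the potential.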
self.L = period
    def forward(
        self, q: torch.Tensor, p: torch.Tensor, mu: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes the forward pass.
This method combines the transformations of the input tensors,
applies an activation function, scales the result,
and returns the matrix product.
Args:
q: The position tensor.
p: The momentum tensor.
mu: The conditional input tensor.
Returns:
The output tensor after applying the transformation and scaling.
"""
        # Build the momentum shift from the cos/sin phases of q; the
        # chain-rule factors cos_phase and sin_phase appear below.
phase = 2 * torch.pi * q / self.L
cos_phase = torch.cos(phase)
sin_phase = torch.sin(phase)
z1 = self.linear_q1(cos_phase) + self.b1
z2 = self.linear_q2(sin_phase) + self.b2 # (batch_size, width)
p_int1 = self.activation(z1 + self.linear_mu(mu))
p_int2 = self.activation(z2 + self.linear_mu(mu))
result = ((self.scaling(mu) * p_int2) @ self.linear_q2.weight) * cos_phase
result -= ((self.scaling(mu) * p_int1) @ self.linear_q1.weight) * sin_phase
p = p + result
return q, p
    def backward(
        self, q: torch.Tensor, p: torch.Tensor, mu: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
"""Computes the backward (inverse) pass.
This method inverts the transformation applied in the forward pass.
Args:
q: The position tensor of dimension `(batch_size, n)`.
p: The momentum tensor of dimension `(batch_size, n)`.
mu: The conditional input tensor of dimension
`(batch_size, conditional_size)`.
Returns:
The concatenated tensor of the original `(p, q)`.
"""
        # The forward pass adds a term that depends only on (q, mu);
        # subtracting the same term exactly inverts it.
phase = 2 * torch.pi * q / self.L
cos_phase = torch.cos(phase)
sin_phase = torch.sin(phase)
z1 = self.linear_q1(cos_phase) + self.b1
z2 = self.linear_q2(sin_phase) + self.b2 # (batch_size, width)
p_int1 = self.activation(z1 + self.linear_mu(mu))
p_int2 = self.activation(z2 + self.linear_mu(mu))
result = ((self.scaling(mu) * p_int2) @ self.linear_q2.weight) * cos_phase
result -= ((self.scaling(mu) * p_int1) @ self.linear_q1.weight) * sin_phase
p = p - result
return q, p
def log_abs_det_jacobian(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the log absolute value of the determinant of the Jacobian.
Args:
            p: the momentum tensor of shape `(batch_size, size)`.
            q: the position tensor of shape `(batch_size, size)`.
mu: the conditional input tensor of shape `(batch_size, conditional_size)`.
Returns:
The log absolute determinant of the Jacobian as a tensor of
shape `(batch_size,)`.
"""
        # The Jacobian is [[I, 0], [Hess_q V, I]], which has determinant 1
return torch.zeros(p.shape[0], device=p.device)
def abs_det_jacobian(
self, p: torch.Tensor, q: torch.Tensor, mu: torch.Tensor
) -> torch.Tensor:
"""Computes the absolute value of the determinant of the Jacobian.
Args:
            p: the momentum tensor of shape `(batch_size, size)`.
            q: the position tensor of shape `(batch_size, size)`.
mu: the conditional input tensor of shape `(batch_size, conditional_size)`.
Returns:
The absolute determinant of the Jacobian as a tensor of
shape `(batch_size,)`.
"""
# Symplectic transformations preserve volume, determinant = 1
return torch.ones(p.shape[0], device=p.device)
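# A minimal usage sketch (illustrative, not part of the library API): shifting
# q by one period L leaves the momentum update unchanged, and `backward`
# remains an exact inverse.
def _demo_periodic_grad_potential_layer() -> None:
    torch.manual_seed(0)
    period = torch.tensor(2 * torch.pi)
    layer = PeriodicGradPotentialSymplecticLayer(
        size=2, conditional_size=3, width=16, period=period
    )
    q, p = torch.randn(8, 2), torch.randn(8, 2)
    mu = torch.randn(8, 3)
    _, p_new = layer.forward(q, p, mu)
    _, p_shifted = layer.forward(q + period, p, mu)
    assert torch.allclose(p_new, p_shifted, atol=1e-5)  # periodic in q
    _, p_back = layer.backward(q, p_new, mu)
    assert torch.allclose(p_back, p, atol=1e-5)  # exact inverse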