"""Symplectic neural networks composed of invertible layers."""
import torch
from torch import nn
from scimba_torch.neural_nets.structure_preserving_nets.coupling_symplectic_layers import ( # noqa: E501
GSymplecticLayer,
LASymplecticLayer,
PeriodicGSymplecticLayer,
)
from scimba_torch.neural_nets.structure_preserving_nets.invertible_nn import (
InvertibleNet,
)
from scimba_torch.neural_nets.structure_preserving_nets.nilpotent_symplectic_layer import ( # noqa: E501
PSymplecticLayer,
)
class GSymplecticNet(InvertibleNet):
"""An invertible neural network composed of multiple invertible layers.
Args:
size: dimension of the input data
conditional_size: dimension of the conditional input data
width: width of the hidden layers in each layer
nb_layers: number of invertible layers to compose
**kwargs: other arguments for the invertible layers
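    Example:
        A minimal usage sketch with illustrative shapes; it assumes the
        inherited `InvertibleNet` call consumes the state `y` and the
        conditional input `mu` concatenated along the last dimension, as
        `PeriodicGSymplecticNet.forward` does below:

        >>> net = GSymplecticNet(size=4, conditional_size=2, width=16, nb_layers=3)
        >>> y, mu = torch.randn(8, 4), torch.randn(8, 2)
        >>> out = net(torch.cat([y, mu], dim=-1))  # expected shape: (8, 4)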
"""
def __init__(
self,
size: int,
conditional_size: int,
width: int,
nb_layers: int,
**kwargs,
):
super().__init__(size=size, conditional_size=conditional_size)
self.size = size
self.conditional_size = conditional_size
self.nb_layers = nb_layers
self.width = width
self.layers = nn.ModuleList(
[
GSymplecticLayer(size, conditional_size, width, **kwargs)
for _ in range(self.nb_layers)
]
)
def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
"""Computes the log absolute determinant of the Jacobian.
Args:
y: Input tensor of shape `(batch_size, size)`.
mu: Conditional input of shape `(batch_size, conditional_size)`.
Returns:
Log determinant of shape `(batch_size,)`.
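        Example:
            A small illustrative check (constructor arguments are arbitrary);
            since the map is volume preserving, the log-determinant is zero:

            >>> net = GSymplecticNet(size=4, conditional_size=2, width=16, nb_layers=3)
            >>> ld = net.log_abs_det_jacobian(torch.randn(8, 4), torch.randn(8, 2))
            >>> bool(torch.all(ld == 0))
            True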
"""
# G-symplectic transformations preserve volume, determinant = 1
return torch.zeros(y.shape[0], device=y.device)
def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
"""Computes the absolute determinant of the Jacobian.
Args:
y: Input tensor of shape `(batch_size, size)`.
mu: Conditional input of shape `(batch_size, conditional_size)`.
Returns:
Determinant of shape `(batch_size,)`.
"""
return torch.ones(y.shape[0], device=y.device)
class PeriodicGSymplecticNet(InvertibleNet):
"""An invertible neural network composed of multiple invertible layers.
Args:
size: dimension of the input data
conditional_size: dimension of the conditional input data
width: width of the hidden layers in each layer
nb_layers: number of invertible layers to compose
        **kwargs: other arguments for the invertible layers; the keyword
            `period` (default `None`) gives the period used to wrap the
            position components back into `[0, period)`
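    Example:
        A minimal usage sketch with illustrative shapes; `period=1.0` is an
        arbitrary choice for the position wrap:

        >>> net = PeriodicGSymplecticNet(
        ...     size=4, conditional_size=2, width=16, nb_layers=3, period=1.0
        ... )
        >>> y, mu = torch.rand(8, 4), torch.randn(8, 2)
        >>> out = net(torch.cat([y, mu], dim=-1))  # first 2 columns wrapped into [0, 1)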
"""
def __init__(
self,
size: int,
conditional_size: int,
width: int,
nb_layers: int,
**kwargs,
):
super().__init__(size=size, conditional_size=conditional_size)
self.size = size
self.conditional_size = conditional_size
self.nb_layers = nb_layers
self.width = width
self.period = kwargs.pop("period", None)
self.layers = nn.ModuleList(
[
PeriodicGSymplecticLayer(
size, conditional_size, width, period=self.period, **kwargs
)
for _ in range(self.nb_layers)
]
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""Applies all layers in sequence.
Args:
            inputs: Input tensor of shape `(batch_size, size + conditional_size)`.
Returns:
Transformed tensor of shape `(batch_size, size)`.
"""
y = inputs[:, : self.size]
mu = inputs[:, self.size :]
print(f"inputs.shape = {inputs.shape}")
for layer in self.layers:
y = layer.forward(y, mu)
if self.period is not None:
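            # wrap the position-like components q (first half of y) back into [0, period)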
split_idx = self.size // 2
q = y[:, :split_idx]
p = y[:, split_idx:]
q = q % self.period
y = torch.cat([q, p], dim=-1)
return y
def backward(self, inputs: torch.Tensor) -> torch.Tensor:
"""Applies all layers in reverse order.
Args:
            inputs: Input tensor of shape `(batch_size, size + conditional_size)`.
Returns:
Inverse transformed tensor of shape `(batch_size, size)`.
"""
y = inputs[:, : self.size]
mu = inputs[:, self.size :]
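        # invert the composition by applying the layers in reverse order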
for layer in reversed(self.layers):
y = layer.backward(y, mu)
if self.period is not None:
split_idx = self.size // 2
q = y[:, :split_idx]
p = y[:, split_idx:]
q = q % self.period
y = torch.cat([q, p], dim=-1)
return y
def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
"""Computes the log absolute determinant of the Jacobian.
Args:
y: Input tensor of shape `(batch_size, size)`.
mu: Conditional input of shape `(batch_size, conditional_size)`.
Returns:
Log determinant of shape `(batch_size,)`.
"""
# G-symplectic transformations preserve volume, determinant = 1
return torch.zeros(y.shape[0], device=y.device)
def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
"""Computes the absolute determinant of the Jacobian.
Args:
y: Input tensor of shape `(batch_size, size)`.
mu: Conditional input of shape `(batch_size, conditional_size)`.
Returns:
Determinant of shape `(batch_size,)`.
"""
return torch.ones(y.shape[0], device=y.device)
class LASymplecticNet(InvertibleNet):
"""An invertible neural network composed of multiple invertible layers.
Args:
size: dimension of the input data
conditional_size: dimension of the conditional input data
width: width of the hidden layers in each layer
nb_layers: number of invertible layers to compose
**kwargs: other arguments for the invertible layers
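    Example:
        A minimal usage sketch with illustrative shapes; it assumes the
        inherited `InvertibleNet` call consumes the state and the conditional
        input concatenated along the last dimension:

        >>> net = LASymplecticNet(size=4, conditional_size=2, width=16, nb_layers=3)
        >>> out = net(torch.cat([torch.randn(8, 4), torch.randn(8, 2)], dim=-1))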
"""
def __init__(
self,
size: int,
conditional_size: int,
width: int,
nb_layers: int,
**kwargs,
):
        super().__init__(size=size, conditional_size=conditional_size)
        self.size = size
self.conditional_size = conditional_size
self.nb_layers = nb_layers
self.width = width
self.layers = nn.ModuleList(
[
LASymplecticLayer(size, conditional_size, width, **kwargs)
for _ in range(self.nb_layers)
]
)
def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
"""Computes the log absolute determinant of the Jacobian.
Args:
y: Input tensor of shape `(batch_size, size)`.
mu: Conditional input of shape `(batch_size, conditional_size)`.
Returns:
Log determinant of shape `(batch_size,)`.
"""
        # Symplectic transformations preserve phase-space volume, so the determinant is 1
return torch.zeros(y.shape[0], device=y.device)
def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
"""Computes the absolute determinant of the Jacobian.
Args:
y: Input tensor of shape `(batch_size, size)`.
mu: Conditional input of shape `(batch_size, conditional_size)`.
Returns:
Determinant of shape `(batch_size,)`.
"""
return torch.ones(y.shape[0], device=y.device)
class PSymplecticNet(InvertibleNet):
"""An invertible neural network composed of multiple invertible layers.
Args:
size: dimension of the input data
conditional_size: dimension of the conditional input data
deg: degree of the polynomial transformation
nb_layers: number of invertible layers to compose
**kwargs: other arguments for the invertible layers
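    Example:
        A minimal usage sketch with illustrative shapes; `deg=3` is an
        arbitrary polynomial degree:

        >>> net = PSymplecticNet(size=4, conditional_size=2, deg=3, nb_layers=3)
        >>> out = net(torch.cat([torch.randn(8, 4), torch.randn(8, 2)], dim=-1))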
"""
def __init__(
self,
size: int,
conditional_size: int,
deg: int,
nb_layers: int,
**kwargs,
):
super().__init__(size=size, conditional_size=conditional_size)
self.size = size
self.conditional_size = conditional_size
self.nb_layers = nb_layers
self.deg = deg
self.layers = nn.ModuleList(
[
PSymplecticLayer(size, conditional_size, deg, **kwargs)
for _ in range(self.nb_layers)
]
)
def log_abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
"""Computes the log absolute determinant of the Jacobian.
Args:
y: Input tensor of shape `(batch_size, size)`.
mu: Conditional input of shape `(batch_size, conditional_size)`.
Returns:
Log determinant of shape `(batch_size,)`.
"""
        # Symplectic transformations preserve phase-space volume, so the determinant is 1
return torch.zeros(y.shape[0], device=y.device)
def abs_det_jacobian(self, y: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
"""Computes the absolute determinant of the Jacobian.
Args:
y: Input tensor of shape `(batch_size, size)`.
mu: Conditional input of shape `(batch_size, conditional_size)`.
Returns:
Determinant of shape `(batch_size,)`.
"""
return torch.ones(y.shape[0], device=y.device)