scimba_torch.neural_nets.structure_preserving_nets.ode_splitted_layer
ODE-based splitted layer for invertible networks.
Classes
ODESplittedLayer: Abstract class for ODE-based invertible layers that operate on split tensors.
- class ODESplittedLayer(size, conditional_size, split_index, other_indices, **kwargs)
Bases: InvertibleLayer
Abstract class for ODE-based invertible layers that operate on split tensors.
This layer processes one part of a split tensor while being conditioned on:
  - the conditional input (mu);
  - the other parts of the split tensor (selected by other_indices).
- Parameters:
    - size (int) – dimension of the input part to process
    - conditional_size (int) – dimension of the conditional input data
    - split_index (int) – index of the part this layer processes (0-based)
    - other_indices (list[int]) – indices of the other parts to use as conditioning
    - **kwargs – other arguments for the invertible layer
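To make the roles of split_index and other_indices concrete, here is a minimal sketch of how a caller might cut a full tensor into parts and hand one of them to the layer as y, with the selected remaining parts as other_parts. It assumes a column-wise split along the feature axis and uses NumPy arrays in place of torch Tensors; the helper split_and_gather and the part sizes are hypothetical, not part of scimba_torch.

```python
import numpy as np

def split_and_gather(x, sizes, split_index, other_indices):
    """Hypothetical helper: split x column-wise and pick the parts one layer sees."""
    # Boundaries between parts, e.g. sizes [2, 3, 1] -> split points [2, 5]
    split_points = np.cumsum(sizes)[:-1]
    parts = np.split(x, split_points, axis=1)
    y = parts[split_index]                          # part this layer transforms
    other_parts = [parts[i] for i in other_indices]  # parts used only as conditioning
    return y, other_parts

x = np.arange(12.0).reshape(2, 6)  # batch_size=2, total feature dimension 6
y, others = split_and_gather(x, sizes=[2, 3, 1], split_index=1, other_indices=[0, 2])
print(y.shape, [p.shape for p in others])  # (2, 3) [(2, 2), (2, 1)]
```

Here size would be 3 for the layer handling split_index=1, and the other parts keep their own widths.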
- abstract forward(y, mu, other_parts, with_last_layer=True)
Forward pass of the ODE-based layer.
- Parameters:
    - y (Tensor) – the input tensor part to transform, shape (batch_size, size)
    - mu (Tensor) – the conditional input tensor, shape (batch_size, conditional_size)
    - other_parts (list[Tensor]) – list of other tensor parts used for conditioning
    - with_last_layer (bool) – whether to use the last layer of the network
- Return type:
    Tensor
- Returns:
    The transformed tensor of shape (batch_size, size)
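ODESplittedLayer leaves the actual ODE to its subclasses. As an illustration only, the sketch below assumes a hypothetical velocity field dy/dt = a(c) * y + b(c) whose coefficients depend on the conditioning c = (mu, other_parts) but not on y, and integrates it from t = 0 to t = 1 with classical fourth-order Runge-Kutta (RK4). NumPy stands in for torch; conditioner, velocity, W_a and W_b are made-up names, and with_last_layer is not modeled.

```python
import numpy as np

rng = np.random.default_rng(0)
SIZE, COND_SIZE, OTHER_SIZE = 3, 2, 4
# Hypothetical weights of a tiny conditioner (not the library's parameters).
W_a = rng.normal(size=(COND_SIZE + OTHER_SIZE, SIZE))
W_b = rng.normal(size=(COND_SIZE + OTHER_SIZE, SIZE))

def conditioner(mu, other_parts):
    """Map (mu, other_parts) to per-component ODE coefficients a and b."""
    c = np.concatenate([mu, *other_parts], axis=1)
    a = 1.0 + 0.5 * np.tanh(c @ W_a)  # rate, kept in (0.5, 1.5)
    b = np.tanh(c @ W_b)              # source term
    return a, b

def velocity(y, a, b):
    # dy/dt = a * y + b, with a and b fixed by the conditioning inputs
    return a * y + b

def forward(y, mu, other_parts, n_steps=20, t_final=1.0):
    """Integrate the conditioned ODE from t=0 to t_final with fixed-step RK4."""
    a, b = conditioner(mu, other_parts)
    h = t_final / n_steps
    for _ in range(n_steps):
        k1 = velocity(y, a, b)
        k2 = velocity(y + 0.5 * h * k1, a, b)
        k3 = velocity(y + 0.5 * h * k2, a, b)
        k4 = velocity(y + h * k3, a, b)
        y = y + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y

batch = 5
y0 = rng.normal(size=(batch, SIZE))
mu = rng.normal(size=(batch, COND_SIZE))
others = [rng.normal(size=(batch, 1)), rng.normal(size=(batch, 3))]
y1 = forward(y0, mu, others)
print(y1.shape)  # (5, 3)
```

Because this particular field is affine in y, the result can be checked against the closed form exp(a) * y0 + (b / a) * (exp(a) - 1); a subclass with a field that is nonlinear in y would rely on the integrator alone.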
- abstract backward(y, mu, other_parts, with_last_layer=True)
Backward pass (inverse) of the ODE-based layer.
- Parameters:
    - y (Tensor) – the input tensor part to transform, shape (batch_size, size)
    - mu (Tensor) – the conditional input tensor, shape (batch_size, conditional_size)
    - other_parts (list[Tensor]) – list of other tensor parts used for conditioning
    - with_last_layer (bool) – whether to use the last layer of the network
- Return type:
    Tensor
- Returns:
    The inverse-transformed tensor of shape (batch_size, size)
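For an ODE flow, one natural way to invert the forward map is to integrate the same velocity field backward in time, from t = 1 to t = 0. The sketch below assumes a hypothetical affine field dy/dt = a(c) * y + b(c) and a fixed-step RK4 integrator (names such as conditioner and integrate are illustrative, not from scimba_torch), with NumPy in place of torch, and checks the round trip. It works because the layer does not modify mu or other_parts, so the backward pass can rebuild exactly the same coefficients a and b.

```python
import numpy as np

rng = np.random.default_rng(1)
SIZE, COND_DIM = 3, 6  # COND_DIM = conditional_size + width of other_parts (hypothetical)
W_a = rng.normal(size=(COND_DIM, SIZE))
W_b = rng.normal(size=(COND_DIM, SIZE))

def conditioner(mu, other_parts):
    """Coefficients of the ODE; they never depend on the transformed part y."""
    c = np.concatenate([mu, *other_parts], axis=1)
    return 1.0 + 0.5 * np.tanh(c @ W_a), np.tanh(c @ W_b)

def integrate(y, a, b, t_final, n_steps=20):
    """RK4 for dy/dt = a*y + b; a negative t_final integrates backward in time."""
    h = t_final / n_steps
    f = lambda z: a * z + b
    for _ in range(n_steps):
        k1 = f(y)
        k2 = f(y + 0.5 * h * k1)
        k3 = f(y + 0.5 * h * k2)
        k4 = f(y + h * k3)
        y = y + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return y

def forward(y, mu, other_parts):
    a, b = conditioner(mu, other_parts)
    return integrate(y, a, b, t_final=1.0)

def backward(y, mu, other_parts):
    # Same velocity field, integrated from t=1 back to t=0.
    a, b = conditioner(mu, other_parts)
    return integrate(y, a, b, t_final=-1.0)

batch = 4
y0 = rng.normal(size=(batch, SIZE))
mu = rng.normal(size=(batch, 2))
others = [rng.normal(size=(batch, 4))]
recovered = backward(forward(y0, mu, others), mu, others)
print(np.max(np.abs(recovered - y0)))  # round-trip error from the fixed-step solver
```

With a fixed-step solver the round trip is exact only up to discretization error; an adaptive or higher-resolution solver trades cost for a tighter inverse.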
- abstract log_abs_det_jacobian(y, mu, other_parts)
Computes the log absolute determinant of the Jacobian.
- Parameters:
    - y (Tensor) – the input tensor part, shape (batch_size, size)
    - mu (Tensor) – the conditional input tensor, shape (batch_size, conditional_size)
    - other_parts (list[Tensor]) – list of other tensor parts used for conditioning
- Return type:
    Tensor
- Returns:
    The log absolute determinant as a tensor of shape (batch_size,)
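For an ODE flow dy/dt = v(t, y), Liouville's formula gives log|det ∂y(T)/∂y(0)| = ∫_0^T tr(∂v/∂y) dt, which is one standard way for ODE-based layers to obtain this quantity. Under a hypothetical affine field v = a(c) * y + b(c) with coefficients depending only on the conditioning, the trace is sum(a) and constant in time, so the result reduces to T * sum(a) per sample, with shape (batch_size,) as documented. The names below are illustrative, not scimba_torch API, and NumPy stands in for torch.

```python
import numpy as np

rng = np.random.default_rng(2)
SIZE, COND_DIM, T = 3, 5, 1.0
W_a = rng.normal(size=(COND_DIM, SIZE))
W_b = rng.normal(size=(COND_DIM, SIZE))

def conditioner(mu, other_parts):
    c = np.concatenate([mu, *other_parts], axis=1)
    return 1.0 + 0.5 * np.tanh(c @ W_a), np.tanh(c @ W_b)

def flow(y, mu, other_parts):
    """Exact time-T flow of dy/dt = a*y + b (a, b constant in t and y)."""
    a, b = conditioner(mu, other_parts)
    return np.exp(a * T) * y + b / a * np.expm1(a * T)

def log_abs_det_jacobian(y, mu, other_parts):
    # Liouville: integral of tr(dv/dy) dt = T * sum(a); y is unused because
    # a depends only on the conditioning. Result has shape (batch_size,).
    a, _ = conditioner(mu, other_parts)
    return T * a.sum(axis=1)

batch = 3
y = rng.normal(size=(batch, SIZE))
mu = rng.normal(size=(batch, 2))
others = [rng.normal(size=(batch, 3))]
print(log_abs_det_jacobian(y, mu, others).shape)  # (3,)
```

For a field that depends nonlinearly on y, the trace is no longer constant and would have to be integrated along the trajectory alongside y.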