Source code for scimba_torch.neural_nets.structure_preserving_nets.ode_splitted_layer
"""ODE-based splitted layer for invertible networks."""from__future__importannotationsfromabcimportabstractmethodimporttorchfromscimba_torch.neural_nets.structure_preserving_nets.invertible_nnimport(InvertibleLayer,)
class ODESplittedLayer(InvertibleLayer):
    """Abstract class for ODE-based invertible layers that operate on split tensors.

    This layer processes one part of a split tensor while being conditioned on:

    - The conditional input (``mu``)
    - Other parts of the split tensor (specified by indices)

    Subclasses must implement :meth:`forward`, :meth:`backward` and
    :meth:`log_abs_det_jacobian`.

    Args:
        size: dimension of the input part to process
        conditional_size: dimension of the conditional input data
        split_index: index of the part this layer processes (0-based)
        other_indices: list of indices of other parts to use as conditioning;
            the layer stores its own copy, so later changes to the caller's
            list do not affect it
        **kwargs: other arguments for the invertible layer
    """

    def __init__(
        self,
        size: int,
        conditional_size: int,
        split_index: int,
        other_indices: list[int],
        **kwargs,
    ):
        super().__init__(size, conditional_size, **kwargs)
        self.split_index = split_index
        # Copy rather than alias: a caller that reuses or mutates the list it
        # passed in (e.g. while building several layers in a loop) must not
        # silently change which parts this layer is conditioned on.
        self.other_indices = list(other_indices)

    @abstractmethod
    def forward(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
        with_last_layer: bool = True,
    ) -> torch.Tensor:
        """Apply the forward pass of the ODE-based layer.

        Args:
            y: the input tensor part to transform, shape `(batch_size, size)`
            mu: the conditional input tensor, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts used for conditioning
            with_last_layer: whether to use the last layer of the network

        Returns:
            The transformed tensor of shape `(batch_size, size)`
        """

    @abstractmethod
    def backward(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
        with_last_layer: bool = True,
    ) -> torch.Tensor:
        """Apply the backward (inverse) pass of the ODE-based layer.

        Args:
            y: the input tensor part to transform, shape `(batch_size, size)`
            mu: the conditional input tensor, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts used for conditioning
            with_last_layer: whether to use the last layer of the network

        Returns:
            The inverse-transformed tensor of shape `(batch_size, size)`
        """

    @abstractmethod
    def log_abs_det_jacobian(
        self,
        y: torch.Tensor,
        mu: torch.Tensor,
        other_parts: list[torch.Tensor],
    ) -> torch.Tensor:
        """Compute the log absolute determinant of the Jacobian.

        Args:
            y: the input tensor part, shape `(batch_size, size)`
            mu: the conditional input tensor, shape `(batch_size, conditional_size)`
            other_parts: list of other tensor parts used for conditioning

        Returns:
            The log absolute determinant as a tensor of shape `(batch_size,)`
        """