# Source code for scimba_torch.neural_nets.coordinates_based_nets.features

"""Neural networks with feature transformations such as Fourier features."""

from typing import Any

import torch
from torch import nn

from scimba_torch.neural_nets.coordinates_based_nets.scimba_module import ScimbaModule

from ..embeddings.periodic_embedding import FlippedEmbedding, PeriodicEmbedding
from .activation import activation_function
from .mlp import GenericMLP
from .res_net import GenericResNet


class EnhancedFeatureNet(nn.Module):
    """A network that generates learnable features such as Fourier features.

    The weights are initialized using a normal distribution.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        **kwargs: Additional keyword arguments:

            - `nb_features` (:code:`int`, default=1): The number of features
              generated by the network.
            - `type_feature` (:code:`str`, default="fourier"): The type of
              feature transformation to apply.
            - `mean` (:code:`float`, default=0.0): The mean used for
              initializing the weights.
            - `std` (:code:`float`, default=1.0): The standard deviation used
              for initializing the weights.
            - `activation` (:code:`str`, default="sine"): Name of the
              activation used for feature types other than "fourier",
              "periodic" and "flipped".

    Raises:
        KeyError: If `type_feature` is "periodic" and `periods` is missing.
    """

    def __init__(self, in_size: int, **kwargs: Any):
        super().__init__()
        self.in_size = in_size
        self.nb_features = kwargs.get("nb_features", 1)
        self.type_feature = kwargs.get("type_feature", "fourier")
        self.mean = kwargs.get("mean", 0.0)
        self.std = kwargs.get("std", 1.0)
        # Name of the activation (a string); the resolved callable for the
        # generic feature branch is stored separately in `activation_fn`.
        self.activation = kwargs.get("activation", "sine")

        # Layer initialization
        if self.type_feature == "periodic":
            if "periods" not in kwargs:
                # Raise instead of assert: asserts are stripped under -O.
                raise KeyError("Periods must be provided for periodic features.")
            self.layer = PeriodicEmbedding(in_size, self.nb_features, kwargs["periods"])
        elif self.type_feature == "flipped":
            assert in_size == 2, "Flipped features are only available for 2D inputs."
            self.layer = FlippedEmbedding(in_size, self.nb_features)
        else:
            self.layer = nn.Linear(in_size, self.nb_features, bias=False)

        if self.type_feature == "fourier":
            # Fourier features initialization
            nn.init.normal_(self.layer.weight, self.mean, self.std)
            self.ac_sine = activation_function("sine", **kwargs)
            self.ac_cosine = activation_function("cosine", **kwargs)
            # Sine and cosine outputs are concatenated, hence the factor 2.
            self.enhanced_dim = 2 * self.nb_features
        elif self.type_feature in ["periodic", "flipped"]:
            # Periodic or flipped features initialization
            self.enhanced_dim = self.nb_features
        else:
            # Generic feature types (e.g. wavelets).
            # BUG FIX: the original kept only the activation *name* (a string)
            # and attempted to call it in forward(), raising a TypeError.
            # Resolve it to a callable once here.
            self.activation_fn = activation_function(self.activation, **kwargs)
            # BUG FIX: enhanced_dim was never set on this branch, causing an
            # AttributeError in networks reading `features.enhanced_dim`.
            self.enhanced_dim = self.nb_features

    def re_init(self, mean: float, std: float):
        """Reinitialize the weights of the linear layer using a normal distribution.

        Args:
            mean: Mean value for normal distribution initialization.
            std: Standard deviation for normal distribution initialization.
        """
        # NOTE(review): assumes `self.layer` exposes a `weight` tensor; this
        # holds for nn.Linear, presumably also for the embedding layers —
        # verify against their definitions.
        nn.init.normal_(self.layer.weight, mean, std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feature transformation to the input tensor `x`.

        Args:
            x: Input tensor.

        Returns:
            Transformed feature tensor.
        """
        # Apply the linear (or embedding) transformation
        transformed_features = self.layer(x)

        if self.type_feature == "fourier":
            # Apply sine and cosine transformations for Fourier features
            out_sine = self.ac_sine(transformed_features)
            out_cosine = self.ac_cosine(transformed_features)
            out = torch.cat([out_sine, out_cosine], dim=-1)
        elif self.type_feature in ["periodic", "flipped"]:
            # For periodic/flipped features, use the embedding output as is
            out = transformed_features
        else:
            # For other feature types (e.g., wavelets), apply the resolved
            # activation callable (see __init__).
            out = self.activation_fn(transformed_features)
        return out
class GenericFeatureNet(ScimbaModule):
    """A template for a general network with enhanced features.

    A feature can be a periodic embedding, Fourier features, etc.
    """

    def _scoped_parameter_iter(self, flag_scope: str):
        """Return an iterator over the parameters selected by `flag_scope`.

        Args:
            flag_scope: One of 'all', 'last_layer', 'except_last_layer'.

        Returns:
            An iterator over the selected parameters.

        Raises:
            ValueError: If an unknown flag_scope is provided.
        """
        if flag_scope == "all":
            return super().parameters()
        if flag_scope == "last_layer":
            return self.net.output_layer.parameters()
        if flag_scope == "except_last_layer":
            # Everything except the parameters of the final output layer.
            return (
                param
                for name, param in self.named_parameters()
                if not name.startswith("net.output_layer")
            )
        raise ValueError(f"Unknown flag_scope: {flag_scope}")

    def parameters(
        self, flag_scope: str = "all", flag_format: str = "list"
    ) -> list[nn.Parameter] | torch.Tensor:
        """Get parameters of the neural net.

        Args:
            flag_scope: Specifies which parameters to return.
                Options: 'all', 'last_layer', 'except_last_layer'.
            flag_format: Specifies the format. Options: 'list', 'tensor'.

        Returns:
            A list of parameters or a single tensor containing all parameters.

        Raises:
            ValueError: If an unknown flag_scope or flag_format is provided.
        """
        selected = self._scoped_parameter_iter(flag_scope)
        if flag_format == "list":
            return list(selected)
        if flag_format == "tensor":
            return torch.nn.utils.parameters_to_vector(selected)
        raise ValueError(f"Unknown flag_format: {flag_format}")

    def set_parameters(self, new_params: torch.Tensor, flag_scope: str = "all"):
        """Set parameters.

        Args:
            new_params: new parameters.
            flag_scope: 'all', 'last_layer', 'except_last_layer'

        Raises:
            ValueError: If an unknown flag_scope is provided.
        """
        selected = self._scoped_parameter_iter(flag_scope)
        torch.nn.utils.vector_to_parameters(new_params, selected)

    def re_init_features(self, mean: float, std: float):
        """Reinitialize the weights of the `EnhancedFeatureNet` layer.

        Use a normal distribution with the specified mean and standard
        deviation.

        Args:
            mean: Mean value for the normal distribution.
            std: Standard deviation for the normal distribution.
        """
        self.features.re_init(mean, std)
class GenericFourierNet(GenericFeatureNet):
    """Combines Fourier feature transformations with a specified neural network.

    The enhanced features (such as Fourier features) are generated first,
    concatenated with the raw input, and the augmented tensor is then fed to
    a network architecture chosen by the user (set as `self.net` in
    subclasses).

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Additional keyword arguments:

            - `nb_features` (:code:`int`, default=1): Number of features
              generated by EnhancedFeatureNet.
            - `type_feature` (:code:`str`, default="fourier"): Type of feature
              transformation
            - Other keyword arguments are passed to the `EnhancedFeatureNet`
              and to whichever network class is specified by the user.

    Learnable Parameters:
        - features (:code:`EnhancedFeatureNet`): A network that generates
          enhanced features such as Fourier features.
        - net: (:code:`ScimbaModule`): A neural network that processes the
          input and features.
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__(in_size, out_size, **kwargs)
        self.nb_features = kwargs.get("nb_features", 1)
        self.type_feature = kwargs.get("type_feature", "fourier")
        self.features = EnhancedFeatureNet(in_size=in_size, **kwargs)
        # Raw input and enhanced features are concatenated before `self.net`.
        self.inputs_size = self.in_size + self.features.enhanced_dim

    def forward(self, x: torch.Tensor, with_last_layer: bool = True) -> torch.Tensor:
        """Compute the forward pass.

        Apply the feature transformation, concatenate the features with the
        original input, and pass the result through the neural network to
        produce the output.

        Args:
            x: Input tensor.
            with_last_layer: Whether to include the last layer in the forward
                pass. (default: True)

        Returns:
            Output tensor after passing through the neural network.
        """
        enhanced = self.features.forward(x)
        augmented = torch.cat([x, enhanced], dim=-1)
        return self.net.forward(augmented, with_last_layer)
class FourierMLP(GenericFourierNet):
    """Combines Fourier feature transformations with a Multi-Layer Perceptron (MLP).

    Enhanced features (such as Fourier features) are generated, concatenated
    with the raw input, and the result is processed by a fully connected MLP.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Additional keyword arguments:

            - `nb_features` (:code:`int`, default=1): Number of features
              generated by EnhancedFeatureNet.
            - `type_feature` (:code:`str`, default="fourier"): Type of feature
              transformation applied.
            - Other keyword arguments are passed to EnhancedFeatureNet and
              GenericMLP classes.
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        # The parent computes `inputs_size` (raw input + enhanced features).
        super().__init__(in_size, out_size, **kwargs)
        self.net = GenericMLP(
            in_size=self.inputs_size, out_size=self.out_size, **kwargs
        )
class FourierResNet(GenericFourierNet):
    """Combines Fourier feature transformations with a ResNet architecture.

    Enhanced features (such as Fourier features) are generated, concatenated
    with the raw input, and the result is processed by a series of residual
    blocks. It is a specialization of `GenericFourierNet`.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        # The parent computes `inputs_size` (raw input + enhanced features).
        super().__init__(in_size, out_size, **kwargs)
        self.net = GenericResNet(
            in_size=self.inputs_size, out_size=self.out_size, **kwargs
        )
class GenericPeriodicNet(GenericFeatureNet):
    """A neural network that appends a periodic embedding before the first layer.

    The network first generates periodic features, passes the input through
    it, and then passes the result through the chosen architecture.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Additional keyword arguments:

            - `domain_bounds` (:code:`required`): The bounds of the domain
              for the periodic features.
            - `nb_features` (:code:`optional`): The number of features
              generated by the `EnhancedFeatureNet`.
            - `type_feature` (:code:`optional`): The type of feature
              transformation applied (e.g., "periodic").
            - Other arguments passed to the `EnhancedFeatureNet` class and
              whichever network class is specified by the user.

    Raises:
        KeyError: If `domain_bounds` is not provided in `kwargs`.
        ValueError: If `domain_bounds` is not a `torch.Tensor` or has
            incorrect shape.

    Learnable Parameters:
        - features (:code:`EnhancedFeatureNet`): A network that generates
          enhanced features such as Fourier features.
        - net: (:code:`ScimbaModule`): A neural network that processes the
          input and features.
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__(in_size, out_size, **kwargs)
        layer_sizes = kwargs.get("layer_sizes", [20] * 6)

        # Validation with explicit exceptions only.  The original code
        # duplicated each `raise` with an equivalent `assert` (dead code,
        # stripped under -O anyway); the asserts have been removed.
        if "domain_bounds" not in kwargs:
            raise KeyError("Domain bounds must be provided for periodic features.")
        domain_bounds = kwargs["domain_bounds"]
        if not isinstance(domain_bounds, torch.Tensor):
            raise ValueError("domain_bounds argument must be a torch.Tensor")
        # BUG FIX: the error message promises a (1,2) or (in_size,2) tensor,
        # but only the first dimension was checked; compute_periods requires
        # the second dimension to be exactly 2, so check it here too.
        if (
            domain_bounds.ndim != 2
            or domain_bounds.shape[1] != 2
            or domain_bounds.shape[0] not in (1, in_size)
        ):
            raise ValueError(
                "domain_bounds argument must be a (1,2) or (%d,2)-shaped torch.Tensor"
                % in_size
            )

        # The first hidden layer is replaced by the periodic embedding.
        kwargs["nb_features"] = layer_sizes[0]
        kwargs["type_feature"] = "periodic"
        kwargs["periods"] = self.compute_periods(domain_bounds)
        self.features = EnhancedFeatureNet(in_size=in_size, **kwargs)
        self.inputs_size = self.features.enhanced_dim
        # The remaining sizes are consumed by the subclass-built network.
        kwargs["layer_sizes"] = layer_sizes[1:]

    def compute_periods(self, domain_bounds: torch.Tensor) -> torch.Tensor:
        """Compute the periods for the periodic embedding from the domain bounds.

        Args:
            domain_bounds: The bounds of the domain for periodic features,
                shaped (1, 2) or (in_size, 2) as [lower, upper] rows.

        Returns:
            torch.Tensor: The computed periods (upper - lower per dimension).
        """
        lower, upper = domain_bounds.T
        return upper - lower

    def forward(self, x: torch.Tensor, with_last_layer: bool = True) -> torch.Tensor:
        """Compute the forward pass.

        Apply the periodic transformation and pass the result through the
        network to produce the output.

        Args:
            x: Input tensor.
            with_last_layer: Whether to include the last layer in the forward
                pass. (default: True)

        Returns:
            Output tensor after passing through the neural network.
        """
        return self.net.forward(self.features.forward(x), with_last_layer)
class PeriodicMLP(GenericPeriodicNet):
    """A neural network that combines periodic feature transformations with an MLP.

    The periodic features are generated first; the input is passed through
    them and the result is processed by a fully connected MLP.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Other keyword arguments including:

            - `domain_bounds` (:code:`required`): The bounds of the domain
              for the periodic features.
            - `nb_features` (:code:`optional`): Number of features generated
              by EnhancedFeatureNet.
            - `type_feature` (:code:`optional`): Type of feature
              transformation applied (e.g., "periodic").
            - Other arguments passed to EnhancedFeatureNet and GenericMLP
              classes.
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        # The parent validates domain_bounds and sets up the embedding.
        super().__init__(in_size, out_size, **kwargs)
        self.net = GenericMLP(
            in_size=self.inputs_size, out_size=self.out_size, **kwargs
        )
class PeriodicResNet(GenericPeriodicNet):
    """Combine a periodic feature transformations with a ResNet architecture.

    The periodic features are generated first; the input is passed through
    them and the result is processed by a ResNet.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Other keyword arguments including:

            - `domain_bounds` (:code:`required`): The bounds of the domain
              for the periodic features.
            - `nb_features` (:code:`optional`): The number of features
              generated by the `EnhancedFeatureNet`.
            - `type_feature` (:code:`optional`): The type of feature
              transformation applied (e.g., "periodic").
            - Other arguments passed to the `EnhancedFeatureNet` and
              `GenericResNet` classes
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        # The parent validates domain_bounds and sets up the embedding.
        super().__init__(in_size, out_size, **kwargs)
        self.net = GenericResNet(
            in_size=self.inputs_size, out_size=self.out_size, **kwargs
        )
class FlippedMLP(GenericFeatureNet):
    """Combine flipping feature transformations with a Multi-Layer Perceptron (MLP).

    The flipped features are generated first; the input is passed through
    them and the result is processed by a fully connected MLP. This class is
    only available for 2D inputs on the unit square.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Other keyword arguments including:

            - `nb_features` (:code:`optional`): Number of features generated
              by EnhancedFeatureNet.
            - `type_feature` (:code:`optional`): Type of feature
              transformation applied (e.g., "flipped").
            - Other arguments passed to EnhancedFeatureNet and GenericMLP
              classes.
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        assert in_size == 2, "Flipped features are only available for 2D inputs."
        super().__init__(in_size, out_size, **kwargs)
        hidden_sizes = kwargs.get("layer_sizes", [20] * 6)
        # The first hidden layer is replaced by the flipped embedding; the
        # remaining sizes are consumed by the MLP below.
        kwargs["nb_features"] = hidden_sizes[0]
        kwargs["type_feature"] = "flipped"
        self.features = EnhancedFeatureNet(in_size=in_size, **kwargs)
        self.inputs_size = self.features.enhanced_dim
        kwargs["layer_sizes"] = hidden_sizes[1:]
        self.net = GenericMLP(
            in_size=self.inputs_size, out_size=self.out_size, **kwargs
        )

    def forward(self, x: torch.Tensor, with_last_layer: bool = True) -> torch.Tensor:
        """Compute the forward pass.

        Apply the flipped transformation and pass the result through the MLP
        to produce the output.

        Args:
            x: Input tensor.
            with_last_layer: Whether to include the last layer in the forward
                pass (default: True)

        Returns:
            Output tensor after passing through the MLP.
        """
        return self.net.forward(self.features.forward(x), with_last_layer)
class GenericMultiScaleFourierNet(ScimbaModule):
    """Combines Fourier feature transformations with a specified neural network.

    One `EnhancedFeatureNet` is built per (mean, std) pair; each generates
    enhanced features (such as Fourier features) that are concatenated with
    the raw input and fed to its own sub-network (set as `self.nets` in
    subclasses). The final result is a linear combination of the per-scale
    outputs.

    Args:
        in_size: int
        out_size: int
        **kwargs: Additional keyword arguments:

            - `means` (:code:`list[float]`, default=[0.0]): Initialize the
              weights of the EnhancedFeatureNet layers
            - `stds` (:code:`list[float]`, default=[1.0]): Initialize the
              weights of the EnhancedFeatureNet layers
            - `nb_features` (:code:`int`, default=1): The number of features
              generated by the EnhancedFeatureNet
            - `type_feature` (:code:`str`, default="fourier"): The type of
              feature transformation applied
            - Other keyword arguments are passed to the EnhancedFeatureNet
              and to whichever network class is specified by the user.

    Learnable Parameters:
        - features (:code:`EnhancedFeatureNet`): A network that generates
          enhanced features such as Fourier features.
        - net: (:code:`ScimbaModule`): A neural network that processes the
          input and features.
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__(in_size, out_size, **kwargs)
        self.nb_features = kwargs.get("nb_features", 1)
        self.type_feature = kwargs.get("type_feature", "fourier")
        self.means = kwargs.get("means", [0.0])
        self.stds = kwargs.get("stds", [1.0])
        # One feature generator per scale (mean/std pair).
        scale_features = [
            EnhancedFeatureNet(in_size=in_size, mean=m, std=s, **kwargs)
            for m, s in zip(self.means, self.stds)
        ]
        self.features = nn.ModuleList(scale_features)
        # Every scale shares the same augmented input width.
        self.inputs_size = self.in_size + self.features[0].enhanced_dim
        # Linear combination of the per-scale network outputs.
        self.output_layer = nn.Linear(len(self.stds) * self.out_size, self.out_size)

    def re_init_features(self, means: list[float], stds: list[float]):
        """Reinitialize the weights of the `EnhancedFeatureNet` layer.

        Use a normal distribution with the specified mean and standard
        deviation.

        Args:
            means: List of mean values for the normal distribution.
            stds: List of standard deviations for the normal distribution.
        """
        for idx, feature_net in enumerate(self.features):
            feature_net.re_init(means[idx], stds[idx])

    def forward(self, x: torch.Tensor, with_last_layer: bool = True) -> torch.Tensor:
        """Compute the forward pass.

        Apply each feature transformation, concatenate the features with the
        original input, run each augmented tensor through its sub-network and
        linearly combine the per-scale outputs.

        Args:
            x: Input tensor.
            with_last_layer: Whether to include the last layer in the forward
                pass. (default: True)

        Returns:
            Output tensor after passing through the neural network.
        """
        per_scale = []
        for feature_net, subnet in zip(self.features, self.nets):
            augmented = torch.cat([x, feature_net.forward(x)], dim=-1)
            # Sub-networks always run with their own last layer; only the
            # combining output_layer below is optional.
            per_scale.append(subnet.forward(augmented, with_last_layer=True))
        combined = torch.cat(per_scale, dim=-1)
        if with_last_layer:
            combined = self.output_layer.forward(combined)
        return combined
class MultiScaleFourierMLP(GenericMultiScaleFourierNet):
    """A linear combination of Fourier feature transformations with MLP.

    Each scale generates enhanced features (such as Fourier features),
    concatenates them with the raw input and processes the result through a
    fully connected MLP; the final result linearly combines the per-scale
    outputs.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Additional keyword arguments:

            - `means` (:code:`list[float]`, default=[0.0]): Initialize the
              weights of the EnhancedFeatureNet layers
            - `stds` (:code:`list[float]`, default=[1.0]): Initialize the
              weights of the EnhancedFeatureNet layers
            - `nb_features` (:code:`int`, default=1): The number of features
              generated by the EnhancedFeatureNet
            - `type_feature` (:code:`str`, default="fourier"): The type of
              feature transformation applied
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__(in_size, out_size, **kwargs)
        # One MLP per scale, matching the feature generators built by the parent.
        subnets = [
            GenericMLP(in_size=self.inputs_size, out_size=self.out_size, **kwargs)
            for _ in range(len(self.stds))
        ]
        self.nets = nn.ModuleList(subnets)
class MultiScaleFourierResNet(GenericMultiScaleFourierNet):
    """Combine Fourier feature transformations with Residual Networks (ResNet).

    Each scale generates enhanced features (such as Fourier features),
    concatenates them with the raw input and processes the result through a
    ResNet; the final result linearly combines the per-scale outputs. It is a
    specialization of the generic class `GenericMultiScaleFourierNet`.

    Args:
        in_size: The input dimension (number of features in the input tensor).
        out_size: The output dimension (number of features in the output tensor).
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__(in_size, out_size, **kwargs)
        # One ResNet per scale, matching the feature generators built by the parent.
        subnets = [
            GenericResNet(in_size=self.inputs_size, out_size=self.out_size, **kwargs)
            for _ in range(len(self.stds))
        ]
        self.nets = nn.ModuleList(subnets)