"""Siren architecture implementation."""

import numpy as np
import torch
from torch import nn

from scimba_torch.neural_nets.coordinates_based_nets.scimba_module import ScimbaModule

from .activation import Sine


class SirenLayer(nn.Module):
    """Class representing a Siren layer.

    Args:
        in_size: Dimension of the inputs.
        out_size: Dimension of the outputs.
        w0: Frequency parameter. Defaults to 1.
        c: Initialization parameter. Defaults to 6.
        is_first: Whether this is the first layer. Defaults to False.
        use_bias: Whether to use bias. Defaults to True.
    """

    def __init__(
        self,
        in_size: int,
        out_size: int,
        w0: int = 1,
        c: int = 6,
        is_first: bool = False,
        use_bias: bool = True,
    ):
        super().__init__()
        self.in_size = in_size  #: Dimension of the inputs.
        self.is_first = is_first  #: Whether this is the first layer.
        self.out_size = out_size  #: Dimension of the outputs.
        self.layer = nn.Linear(
            in_size, out_size, bias=use_bias
        )  #: The linear layer applied to the vector of features.
        self.init_(self.layer.weight, self.layer.bias, c=c, w0=w0)
        self.activation = Sine(freq=w0)  #: The sine activation function.

    def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: int, w0: int):
        """Initialize the weights of the layer using the Siren initialization scheme.

        Args:
            weight: The weight tensor of the layer to initialize.
            bias: The bias tensor of the layer to initialize.
            c: A parameter for the weight initialization.
            w0: The frequency of the sine activation function.
        """
        dim = self.in_size
        # First layer: U(-1/dim, 1/dim); other layers: U(-sqrt(c/dim)/w0, sqrt(c/dim)/w0)
        w_std = (1 / dim) if self.is_first else (np.sqrt(c / dim) / w0)
        torch.nn.init.uniform_(weight, -w_std, w_std)
        if bias is not None:
            torch.nn.init.uniform_(bias, -w_std, w_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to the inputs.

        Args:
            x: Input tensor.

        Returns:
            The result of the layer.
        """
        return self.activation(self.layer(x))
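

# --- Illustrative usage sketch: an assumption for documentation, not shipped code ---
# Applies a standalone SirenLayer to a batch of 2D coordinates, assuming the Sine
# activation computes sin(freq * x); the shapes and w0=30 are example choices.
def _demo_siren_layer() -> None:
    layer = SirenLayer(in_size=2, out_size=16, w0=30, is_first=True)
    x = torch.rand(8, 2)  # batch of 8 two-dimensional coordinates
    y = layer(x)  # sin(w0 * (x @ W.T + b)), applied elementwise
    assert y.shape == (8, 16)
    assert bool((y.abs() <= 1).all())  # sine outputs lie in [-1, 1]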


class SirenNet(ScimbaModule):
    """Class representing a Siren architecture with optional ResNet layers.

    Args:
        in_size: Dimension of the inputs.
        out_size: Dimension of the outputs.
        **kwargs: Additional keyword arguments:

            - `w` (:code:`int`): frequency for the internal layers' activation
              functions
            - `w0` (:code:`int`): frequency for the first layer's activation
              function
            - `layer_sizes` (:code:`list[int]`): list of the size of each layer
            - `use_res_net` (:code:`bool`, default=False): whether to use a
              ResNet architecture
    """

    def __init__(self, in_size: int, out_size: int, **kwargs):
        super().__init__(in_size, out_size, **kwargs)

        self.layer_sizes = kwargs.get("layer_sizes", [20, 20, 20])
        self.w = kwargs.get("w", 1)
        self.w0 = kwargs.get("w0", 30)

        #: list of Siren layers, potentially with residual connections
        self.layers = nn.ModuleList([])

        # First layer (special initialization)
        self.layers.append(
            SirenLayer(
                in_size=self.in_size,
                out_size=self.layer_sizes[0],
                w0=self.w0,
                use_bias=True,
                is_first=True,
            )
        )

        # Hidden layers; iterate over consecutive pairs of sizes so that each
        # layer's input width matches the previous layer's output width
        for i in range(len(self.layer_sizes) - 1):
            self.layers.append(
                SirenLayer(
                    in_size=self.layer_sizes[i],
                    out_size=self.layer_sizes[i + 1],
                    w0=self.w,
                    use_bias=True,
                    is_first=False,
                )
            )

        # Output layer
        self.output_layer = nn.Linear(self.layer_sizes[-1], self.out_size)

    def forward(self, x: torch.Tensor, with_last_layer: bool = True) -> torch.Tensor:
        """Apply the network to the inputs x.

        Args:
            x: Input tensor.
            with_last_layer: Whether to apply the final output layer.

        Returns:
            The result of the network.
        """
        for layer in self.layers:
            x = layer(x)

        # Output layer
        if with_last_layer:
            x = self.output_layer(x)

        return x
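

# --- Illustrative usage sketch: an assumption for documentation, not shipped code ---
# Builds a SirenNet mapping 2D coordinates to a scalar field (a typical use of
# coordinate-based networks) and evaluates it on random points. The layer sizes
# and frequencies below are example values, not library defaults.
if __name__ == "__main__":
    net = SirenNet(in_size=2, out_size=1, layer_sizes=[32, 32, 32], w=1, w0=30)
    coords = torch.rand(100, 2)  # 100 random points in [0, 1]^2
    values = net(coords)
    print(values.shape)  # torch.Size([100, 1])
    # Skipping the output layer exposes the last hidden features instead:
    features = net(coords, with_last_layer=False)
    print(features.shape)  # torch.Size([100, 32])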