Source code for scimba_torch.neural_nets.embeddings.periodic_embedding

"""Periodic and flipped embeddings."""

import torch


class PeriodicEmbedding(torch.nn.Module):
    """Creates a one-layer network to model a periodic embedding of the input data.

    The learnable parameters are the weights, phases and biases of the periodic
    functions. For an input ``x`` flattened to shape ``[N, p]`` the layer computes

    ``out[n, j] = sum_i (weight[i, j] * cos(2 * pi * x[n, i] / periods + phase[i, j])
    + bias[i, j])``

    where ``periods`` is broadcast against ``x``.

    Args:
        in_size: dimension of inputs
        out_size: dimension of outputs
        periods: periods of the periodic functions; a scalar or a tensor that
            broadcasts against inputs of shape ``[N, in_size]``. It is stored as
            a non-persistent buffer so that ``.to(device)`` / ``.double()`` move
            it together with the parameters.
    """

    periods: torch.Tensor  # registered as a (non-persistent) buffer in __init__

    def __init__(self, in_size: int, out_size: int, periods: torch.Tensor):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        # A buffer (rather than a plain attribute) follows the module through
        # device / dtype conversions; persistent=False keeps the state_dict
        # keys unchanged so previously saved checkpoints still load strictly.
        # as_tensor keeps tensors as-is and also accepts plain Python scalars.
        self.register_buffer("periods", torch.as_tensor(periods), persistent=False)

        weight = torch.randn(1, in_size, out_size)
        phase = torch.randn(1, in_size, out_size)
        bias = torch.randn(1, in_size, out_size)

        self.weight: torch.nn.Parameter = torch.nn.Parameter(
            weight
        )  #: the weights of the layer
        self.phase: torch.nn.Parameter = torch.nn.Parameter(
            phase
        )  #: the phase of the layer
        self.bias: torch.nn.Parameter = torch.nn.Parameter(
            bias
        )  #: the bias of the layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the periodic embedding layer.

        Args:
            x: input tensor whose last dimension is the feature axis (size
                ``in_size``); all leading dimensions are flattened into a
                single batch axis.

        Returns:
            output tensor of shape ``[N, out_size]``, or ``[out_size]`` when the
            flattened batch contains a single row.
        """
        # ensures shape [N, p] (in case of natural gradient)
        x = x.reshape(-1, x.shape[-1])
        # [N, p, 1] + [1, p, d] => [N, p, d]
        arg = (2 * torch.pi * x / self.periods)[..., None] + self.phase
        # [1, p, d] * [N, p, d] + [1, p, d]
        out = self.weight * torch.cos(arg) + self.bias
        # sum over p -> shape [N, d]
        out = out.sum(dim=1)
        # squeeze to remove the first dimension if it is 1
        return out.squeeze(0) if x.shape[0] == 1 else out
class FlippedEmbedding(torch.nn.Module):
    """Creates a one-layer network to model a flipped embedding of the input data.

    It is only available for 2D inputs on the unit square. For a point
    ``(x1, x2)`` the layer computes

    ``out[j] = sum_i (weight[i, j] * tanh(coeff_x1[i, j] * (x1 - 1/2))
    * tanh(coeff_x2[i, j] * (x2 - 1/2)) + bias[i, j])``

    Args:
        in_size: dimension of inputs (number of summed terms ``i``)
        out_size: dimension of outputs
    """

    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Initialisation is centred on fixed values with small random noise.
        weight = 1 + torch.randn(1, in_size, out_size) / 10
        bias = 1.5 + torch.randn(1, in_size, out_size) / 5
        coeff_x1 = 2 + torch.randn(1, in_size, out_size) / 5
        coeff_x2 = 2 + torch.randn(1, in_size, out_size) / 5

        self.weight = torch.nn.Parameter(weight)  # multiplicative weights
        self.bias = torch.nn.Parameter(bias)  # additive biases
        self.coeff_x1 = torch.nn.Parameter(coeff_x1)  # slope along x1
        self.coeff_x2 = torch.nn.Parameter(coeff_x2)  # slope along x2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the flipped embedding layer.

        Args:
            x: input tensor of shape ``[..., 2]``; only the first two entries of
                the last axis are read.

        Returns:
            output tensor of shape ``[..., out_size]`` for inputs of rank >= 2;
            a 1D input of shape ``[2]`` yields shape ``[1, out_size]``.
        """
        # [..., 1, 1]: the two trailing singleton axes broadcast against the
        # [1, in_size, out_size] parameters.
        x1, x2 = x[..., 0, None, None], x[..., 1, None, None]
        x1_ = self.coeff_x1 * (x1 - 1 / 2)
        x2_ = self.coeff_x2 * (x2 - 1 / 2)
        out = self.weight * torch.tanh(x1_) * torch.tanh(x2_) + self.bias
        # After broadcasting, in_size is always the second-to-last axis.
        # Reducing over dim=1 would sum over a batch axis for rank > 2 inputs.
        return out.sum(dim=-2)