"""Multi-Layer Perceptron (MLP) architectures."""
from typing import Callable
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from scimba_torch.neural_nets.coordinates_based_nets.scimba_module import ScimbaModule
from .activation import activation_function
def factorized_glorot_normal(mean: float = 1.0, stddev: float = 0.1) -> Callable:
"""Initializes parameters.
Use a factorized version of the Glorot normal initialization.
Args:
mean: Mean of the log-normal distribution used to scale the singular values.
stddev: Standard deviation of the log-normal distribution.
Returns:
A function that takes a shape tuple and returns the factorized parameters `s`
and `v`.
Example:
>>> init_fn = factorized_glorot_normal()
>>> s, v = init_fn((64, 128))
"""
def init(shape: tuple) -> tuple[torch.Tensor, torch.Tensor]:
"""Inner function to initialize weights.
Args:
shape: Shape of the weight matrix (fan_in, fan_out).
Returns:
Two tensors:
- `s`: Scaling factors for each column (log-normal distributed).
- `v`: Normalized weight matrix after division by `s`.
"""
fan_in, fan_out = shape
std = (2.0 / (fan_in + fan_out)) ** 0.5
w = torch.randn(shape) * std
s = mean + torch.randn(shape[-1]) * stddev
s = torch.exp(s)
v = w / s
return s, v
return init
class FactorizedLinear(nn.Module):
"""A linear transformation with factorized parameterization of the weights.
The weight matrix is the elementwise (broadcast) product `s * v` of two factors:
- `s`: A column-wise scaling factor.
- `v`: A normalized weight matrix.
Args:
input_dim: Size of each input sample.
output_dim: Size of each output sample.
has_bias: Whether to include a bias term (default: True).
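Example:
A minimal usage sketch; the output shape follows from `forward`:
>>> layer = FactorizedLinear(3, 4)
>>> layer(torch.randn(8, 3)).shape
torch.Size([8, 4])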
"""
def __init__(self, input_dim: int, output_dim: int, has_bias: bool = True):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.has_bias = has_bias
# Initialize kernel parameters (s and v) using factorized_glorot_normal
init_fn = factorized_glorot_normal()
s, v = init_fn((input_dim, output_dim))
self.s = nn.Parameter(s) #: Column-wise scaling factors
self.v = nn.Parameter(v) #: Normalized weight matrix
# Initialize bias
if self.has_bias:
#: Bias vector added to the output
self.bias = nn.Parameter(torch.zeros(output_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass of the FactorizedLinear layer.
Args:
x: Input tensor of shape (batch_size, input_dim).
Returns:
Output tensor of shape (batch_size, output_dim).
"""
kernel = self.s * self.v
y = torch.matmul(x, kernel)
if self.has_bias:
y = y + self.bias
return y
class GenericMLP(ScimbaModule):
"""A general Multi-Layer Perceptron (MLP) architecture.
Args:
in_size: Dimension of the input
out_size: Dimension of the output
**kwargs: Additional keyword arguments:
- `activation_type` (:code:`str`, default="tanh"): The activation function
type to use in hidden layers.
- `activation_output` (:code:`str`, default="id"): The activation function
type to use in the output layer.
- `layer_sizes` (:code:`list[int]`, default=[20]*6): A list of integers
representing the number of neurons in each hidden layer.
- `weights_norm_bool` (:code:`bool`, default=False): If True, applies weight
normalization to the layers.
- `random_fact_weights_bool` (:code:`bool`, default=False): If True, applies
factorized weights to the layers.
- `last_layer_has_bias` (:code:`bool`, default=False): If True, adds a bias
term to the final output layer.
Example:
>>> model = GenericMLP(10, 1, activation_type='relu', layer_sizes=[64, 128, 64])
"""
def __init__(self, in_size: int, out_size: int, **kwargs):
super().__init__(in_size, out_size, **kwargs)
activation_type = kwargs.get("activation_type", "tanh")
activation_output = kwargs.get("activation_output", "id")
layer_sizes = kwargs.get("layer_sizes", [20] * 6)
weights_norm_bool = kwargs.get("weights_norm_bool", False)
random_fact_weights_bool = kwargs.get("random_fact_weights_bool", False)
self.last_layer_has_bias = kwargs.get("last_layer_has_bias", False)
self.in_size = in_size
self.out_size = out_size
self.layer_sizes = [in_size] + layer_sizes + [out_size]
#: A list of hidden linear layers.
self.hidden_layers = []
for l1, l2 in zip(self.layer_sizes[:-2], self.layer_sizes[1:-1]):
if weights_norm_bool:
self.hidden_layers.append(weight_norm(nn.Linear(l1, l2)))
elif random_fact_weights_bool:
self.hidden_layers.append(FactorizedLinear(l1, l2))
else:
self.hidden_layers.append(nn.Linear(l1, l2))
self.hidden_layers = nn.ModuleList(self.hidden_layers)
if weights_norm_bool:
#: The final output linear layer.
self.output_layer = weight_norm(
nn.Linear(
self.layer_sizes[-2],
self.layer_sizes[-1],
bias=self.last_layer_has_bias,
)
)
elif random_fact_weights_bool:
self.output_layer = FactorizedLinear(
self.layer_sizes[-2],
self.layer_sizes[-1],
has_bias=self.last_layer_has_bias,
)
else:
self.output_layer = nn.Linear(
self.layer_sizes[-2],
self.layer_sizes[-1],
bias=self.last_layer_has_bias,
)
self.activation = []
for _ in range(len(self.layer_sizes) - 1):
self.activation.append(
activation_function(activation_type, in_size=in_size, **kwargs)
)
self.activation = nn.ModuleList(self.activation)
self.activation_output = activation_function(
activation_output, in_size=in_size, **kwargs
)
def forward(
self, inputs: torch.Tensor, with_last_layer: bool = True
) -> torch.Tensor:
"""Apply the network to the inputs.
Args:
inputs: Input tensor
with_last_layer: Whether to apply the final output layer
Returns:
The result of the network
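Example:
A shape sketch; the hidden sizes below are illustrative only:
>>> net = GenericMLP(10, 1, layer_sizes=[64, 128, 64])
>>> net(torch.randn(5, 10)).shape
torch.Size([5, 1])
>>> net(torch.randn(5, 10), with_last_layer=False).shape
torch.Size([5, 64])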
"""
for hidden_layer, activation in zip(self.hidden_layers, self.activation):
inputs = activation(hidden_layer(inputs))
if with_last_layer:
inputs = self.activation_output(self.output_layer(inputs))
return inputs
def __str__(self) -> str:
"""String representation of the model.
Returns:
A string describing the MLP network and its layer sizes.
"""
return f"MLP network with {self.layer_sizes} layers"
def expand_hidden_layers(
self, new_layer_sizes: list[int], set_to_zero: bool = True
):
"""Expands the hidden layers of the MLP to new sizes.
The new sizes must match the number of hidden layers in the MLP.
The weights of the new layers are initialized to zero, and the weights of the
old layers are copied into the new layers.
The output layer is also expanded to match the new sizes.
Args:
new_layer_sizes: list of integers representing
the new sizes of the hidden layers.
set_to_zero: If True, initializes the weights of the
new layers to zero. Otherwise, set them to small random
values.
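Example:
A minimal sketch of growing a small default (plain `nn.Linear`) network:
>>> net = GenericMLP(2, 1, layer_sizes=[8, 8])
>>> net.expand_hidden_layers([16, 16])
>>> net.layer_sizes
[2, 16, 16, 1]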
"""
assert len(new_layer_sizes) == len(self.hidden_layers), (
f"Expected {len(self.hidden_layers)} new sizes, got {len(new_layer_sizes)}."
)
# Compute the new input/output sizes layer by layer
new_shapes = list(zip([self.in_size] + new_layer_sizes[:-1], new_layer_sizes))
updated_layers = []
for i, (new_in, new_out) in enumerate(new_shapes):
old_layer = self.hidden_layers[i]
old_in, old_out = old_layer.in_features, old_layer.out_features
# New, enlarged layer
new_layer = nn.Linear(new_in, new_out)
with torch.no_grad():
# Zero by default
new_layer.weight.zero_()
new_layer.bias.zero_()
# Copy the old weights into the top-left block
new_layer.weight[:old_out, :old_in] = old_layer.weight
new_layer.bias[:old_out] = old_layer.bias
if not set_to_zero:
new_layer.weight[:, old_in:] = torch.randn(
new_layer.weight[:, old_in:].shape
)
new_layer.bias[old_out:] = torch.randn(
new_layer.bias[old_out:].shape
)
updated_layers.append(new_layer)
# New output layer
new_output_layer = nn.Linear(
new_layer_sizes[-1], self.out_size, bias=self.last_layer_has_bias
)
with torch.no_grad():
old_w = self.output_layer.weight.data
new_output_layer.weight.zero_()
new_output_layer.weight[:, : old_w.shape[1]] = old_w
if not set_to_zero:
new_output_layer.weight[:, old_w.shape[1] :] = torch.randn(
new_output_layer.weight[:, old_w.shape[1] :].shape
)
if self.last_layer_has_bias:
old_b = self.output_layer.bias.data
new_output_layer.bias.copy_(old_b)
self.hidden_layers = nn.ModuleList(updated_layers)
self.output_layer = new_output_layer
self.layer_sizes = [self.in_size] + new_layer_sizes + [self.out_size]
class MultiMLP(ScimbaModule):
"""A Multi-MLP architecture that creates a separate MLP for each output variable.
Each output variable is computed by its own MLP that takes all inputs and produces
a single output. The outputs are concatenated to form the final output.
Args:
in_size: Dimension of the input
out_size: Dimension of the output (number of output variables)
**kwargs: Additional keyword arguments:
- `activation_type` (:code:`str`, default="tanh"): The activation function
type to use in hidden layers.
- `activation_output` (:code:`str`, default="id"): The activation function
type to use in the output layer.
- `layer_sizes` (:code:`list[int]`, default=[20]*6): A list of integers
representing the number of neurons in each hidden layer per MLP.
- `weights_norm_bool` (:code:`bool`, default=False): If True, applies weight
normalization to the layers.
- `random_fact_weights_bool` (:code:`bool`, default=False): If True, applies
factorized weights to the layers.
Example:
>>> model = MultiMLP(3, 2, activation_type='relu', layer_sizes=[32, 64, 32])
>>> # Creates 2 MLPs, each taking 3 inputs and producing 1 output
"""
def __init__(self, in_size: int, out_size: int, **kwargs):
super().__init__(in_size, out_size, **kwargs)
self.in_size = in_size
self.out_size = out_size
#: A list of individual MLPs, one per output variable.
self.mlps = []
for _ in range(out_size):
# Create a MLP for each output variable (in_size -> 1)
mlp = GenericMLP(in_size, 1, **kwargs)
self.mlps.append(mlp)
self.mlps = nn.ModuleList(self.mlps)
def forward(
self, inputs: torch.Tensor, with_last_layer: bool = True
) -> torch.Tensor:
"""Apply the network to the inputs.
Args:
inputs: Input tensor of shape (batch_size, in_size)
with_last_layer: Whether to apply the final output layer
Returns:
The result of the network of shape (batch_size, out_size)
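Example:
A shape sketch using the class defaults:
>>> model = MultiMLP(3, 2)
>>> model(torch.randn(4, 3)).shape
torch.Size([4, 2])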
"""
outputs = []
# Apply each MLP to the full input to get one output per MLP
for mlp in self.mlps:
# Apply the MLP to get one output
y_i = mlp(inputs, with_last_layer=with_last_layer)
outputs.append(y_i)
# Concatenate all outputs along the feature dimension
result = torch.cat(outputs, dim=-1)
return result
def __str__(self) -> str:
"""String representation of the model.
Returns:
A string describing the MultiMLP network.
"""
return (
f"MultiMLP network with {self.out_size} MLPs, "
f"each taking {self.in_size} inputs and outputting 1 dimension"
)
class GenericMMLP(ScimbaModule):
"""A general Multiplicative Multi-Layer Perceptron (MMLP) architecture.
As proposed by Yanfei Xiang.
Args:
in_size: Dimension of the input
out_size: Dimension of the output
**kwargs: Additional keyword arguments:
- `activation_type` (:code:`str`, default="tanh"): The activation function
type to use in hidden layers.
- `activation_output` (:code:`str`, default="id"): The activation function
type to use in the output layer.
- `layer_sizes` (:code:`list[int]`, default=[10, 20, 20, 20, 5]): A list of
integers representing the number of neurons in each hidden layer.
- `weights_norm_bool` (:code:`bool`, default=False): If True, applies weight
normalization to the layers.
- `random_fact_weights_bool` (:code:`bool`, default=False): If True, applies
factorized weights to the layers.
Example:
>>> model = GenericMMLP(
... 10, 5, activation_type='relu', layer_sizes=[64, 128, 64]
... )
"""
def __init__(self, in_size: int, out_size: int, **kwargs):
super().__init__(in_size, out_size, **kwargs)
activation_type = kwargs.get("activation_type", "tanh")
activation_output = kwargs.get("activation_output", "id")
layer_sizes = kwargs.get("layer_sizes", [10, 20, 20, 20, 5])
weights_norm_bool = kwargs.get("weights_norm_bool", False)
random_fact_weights_bool = kwargs.get("random_fact_weights_bool", False)
# Store sizes; self.in_size is used below when building the multiplicative branch
self.in_size = in_size
self.out_size = out_size
self.layer_sizes = [in_size] + layer_sizes + [out_size]
#: A list of hidden linear layers.
self.hidden_layers = []
for l1, l2 in zip(self.layer_sizes[:-2], self.layer_sizes[1:-1]):
if weights_norm_bool:
self.hidden_layers.append(weight_norm(nn.Linear(l1, l2)))
elif random_fact_weights_bool:
self.hidden_layers.append(FactorizedLinear(l1, l2))
else:
self.hidden_layers.append(nn.Linear(l1, l2))
self.hidden_layers = nn.ModuleList(self.hidden_layers)
#: A list of multiplicative linear layers.
self.hidden_layers_mult = []
for layer_size in self.layer_sizes[1:-1]:
if weights_norm_bool:
self.hidden_layers_mult.append(
weight_norm(nn.Linear(self.in_size, layer_size))
)
elif random_fact_weights_bool:
self.hidden_layers_mult.append(
FactorizedLinear(self.in_size, layer_size)
)
else:
self.hidden_layers_mult.append(nn.Linear(self.in_size, layer_size))
self.hidden_layers_mult = nn.ModuleList(self.hidden_layers_mult)
if weights_norm_bool:
#: The final output linear layer.
self.output_layer = weight_norm(
nn.Linear(self.layer_sizes[-2], self.layer_sizes[-1])
)
elif random_fact_weights_bool:
self.output_layer = FactorizedLinear(
self.layer_sizes[-2], self.layer_sizes[-1]
)
else:
self.output_layer = nn.Linear(self.layer_sizes[-2], self.layer_sizes[-1])
self.activation = []
self.activation_mult = []
for _ in range(len(self.layer_sizes) - 1):
self.activation.append(
activation_function(activation_type, in_size=in_size, **kwargs)
)
self.activation_mult.append(
activation_function(activation_type, in_size=in_size, **kwargs)
)
self.activation = nn.ModuleList(self.activation)
self.activation_mult = nn.ModuleList(self.activation_mult)
self.activation_output = activation_function(
activation_output, in_size=in_size, **kwargs
)
def forward(
self, inputs: torch.Tensor, with_last_layer: bool = True
) -> torch.Tensor:
"""Apply the network to the inputs.
Args:
inputs: Input tensor
with_last_layer: Whether to apply the final output layer (default: True)
Returns:
The result of the network
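Example:
A shape sketch; layer sizes are illustrative only:
>>> net = GenericMMLP(10, 5, layer_sizes=[64, 128, 64])
>>> net(torch.randn(4, 10)).shape
torch.Size([4, 5])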
"""
multiplicators = []
for hidden_layer_mult, activation_mult in zip(
self.hidden_layers_mult,
self.activation_mult,
):
multiplicators.append(activation_mult(hidden_layer_mult(inputs)))
for hidden_layer, activation, multiplicator in zip(
self.hidden_layers, self.activation, multiplicators
):
inputs = multiplicator * activation(hidden_layer(inputs))
if with_last_layer:
inputs = self.activation_output(self.output_layer(inputs))
return inputs
def __str__(self) -> str:
"""String representation of the model.
Returns:
A string describing the MMLP network and its layer sizes.
"""
return f"MMLP network, with {self.layer_sizes} layers"
class GenericModulationMLP(ScimbaModule):
"""A Multi-Layer Perceptron with modulation based on auxiliary input.
Each layer applies the transformation: gamma_l(y) * W_l * x + b_l(y)
where gamma_l and b_l are small modulation networks that take y as input.
Args:
x_size: Dimension of the main input x
y_size: Dimension of the modulation input y
out_size: Dimension of the output
**kwargs: Additional keyword arguments:
- `activation_type` (:code:`str`, default="tanh"): The activation function
type to use in hidden layers.
- `activation_output` (:code:`str`, default="id"): The activation function
type to use in the output layer.
- `layer_sizes` (:code:`list[int]`, default=[20]*6): A list of integers
representing the number of neurons in each hidden layer for x.
- `modulation_layer_sizes` (:code:`list[int]`, default=[10]): A list of
integers representing the hidden layer sizes for the gamma and b networks.
- `weights_norm_bool` (:code:`bool`, default=False): If True, applies weight
normalization to the layers.
- `random_fact_weights_bool` (:code:`bool`, default=False): If True, applies
factorized weights to the layers.
- `last_layer_has_bias` (:code:`bool`, default=False): If True, adds a bias
term to the final output layer.
Example:
>>> model = GenericModulationMLP(
... x_size=3, y_size=2, out_size=1,
... layer_sizes=[64, 64], modulation_layer_sizes=[16, 16]
... )
"""
def __init__(self, x_size: int, y_size: int, out_size: int, **kwargs):
# Pass x_size as in_size for compatibility with ScimbaModule
super().__init__(x_size, out_size, **kwargs)
activation_type = kwargs.get("activation_type", "tanh")
activation_output = kwargs.get("activation_output", "id")
layer_sizes = kwargs.get("layer_sizes", [20] * 6)
modulation_layer_sizes = kwargs.get("modulation_layer_sizes", [10])
weights_norm_bool = kwargs.get("weights_norm_bool", False)
random_fact_weights_bool = kwargs.get("random_fact_weights_bool", False)
self.last_layer_has_bias = kwargs.get("last_layer_has_bias", False)
self.x_size = x_size
self.y_size = y_size
self.out_size = out_size
# Main pathway layer sizes for x: [x_size, hidden1, hidden2, ..., out_size]
self.layer_sizes = [x_size] + layer_sizes + [out_size]
# Create main linear layers (without bias since b_l(y) provides it)
self.hidden_layers = []
for l1, l2 in zip(self.layer_sizes[:-2], self.layer_sizes[1:-1]):
if weights_norm_bool:
self.hidden_layers.append(weight_norm(nn.Linear(l1, l2, bias=False)))
elif random_fact_weights_bool:
self.hidden_layers.append(FactorizedLinear(l1, l2, has_bias=False))
else:
self.hidden_layers.append(nn.Linear(l1, l2, bias=False))
self.hidden_layers = nn.ModuleList(self.hidden_layers)
# Output layer
if weights_norm_bool:
self.output_layer = weight_norm(
nn.Linear(
self.layer_sizes[-2],
self.layer_sizes[-1],
bias=self.last_layer_has_bias,
)
)
elif random_fact_weights_bool:
self.output_layer = FactorizedLinear(
self.layer_sizes[-2],
self.layer_sizes[-1],
has_bias=self.last_layer_has_bias,
)
else:
self.output_layer = nn.Linear(
self.layer_sizes[-2],
self.layer_sizes[-1],
bias=self.last_layer_has_bias,
)
# Create modulation networks for gamma (scalar multiplier per neuron)
self.gamma_networks = []
for layer_size in self.layer_sizes[1:-1]: # For each hidden layer
# Build a small MLP: y_size -> modulation_layer_sizes -> layer_size
layers = []
sizes = [y_size] + modulation_layer_sizes + [layer_size]
for i, (in_dim, out_dim) in enumerate(zip(sizes[:-1], sizes[1:])):
layers.append(nn.Linear(in_dim, out_dim))
# Add activation except for the last layer
if i < len(sizes) - 2:
layers.append(nn.Tanh())
self.gamma_networks.append(nn.Sequential(*layers))
self.gamma_networks = nn.ModuleList(self.gamma_networks)
# Create modulation networks for b (bias per neuron)
self.b_networks = []
for layer_size in self.layer_sizes[1:-1]: # For each hidden layer
# Build a small MLP: y_size -> modulation_layer_sizes -> layer_size
layers = []
sizes = [y_size] + modulation_layer_sizes + [layer_size]
for i, (in_dim, out_dim) in enumerate(zip(sizes[:-1], sizes[1:])):
layers.append(nn.Linear(in_dim, out_dim))
# Add activation except for the last layer
if i < len(sizes) - 2:
layers.append(nn.Tanh())
self.b_networks.append(nn.Sequential(*layers))
self.b_networks = nn.ModuleList(self.b_networks)
# Activation functions for main pathway
self.activation = []
for _ in range(len(self.layer_sizes) - 1):
self.activation.append(
activation_function(activation_type, in_size=x_size, **kwargs)
)
self.activation = nn.ModuleList(self.activation)
self.activation_output = activation_function(
activation_output, in_size=x_size, **kwargs
)
def forward(
self,
x: torch.Tensor,
y: torch.Tensor | None = None,
with_last_layer: bool = True,
) -> torch.Tensor:
"""Apply the modulated network to the inputs.
Args:
x: Main input data of shape (batch_size, x_size), or if y is None,
concatenated input of shape (batch_size, x_size + y_size)
y: Modulation input data of shape (batch_size, y_size). Optional if
x contains both x and y concatenated.
with_last_layer: Whether to apply the final output layer
Returns:
The result of the network of shape (batch_size, out_size)
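Example:
A sketch of both calling conventions (sizes chosen for illustration):
>>> net = GenericModulationMLP(x_size=3, y_size=2, out_size=1)
>>> x, y = torch.randn(4, 3), torch.randn(4, 2)
>>> net(x, y).shape
torch.Size([4, 1])
>>> net(torch.cat([x, y], dim=-1)).shape
torch.Size([4, 1])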
"""
# Handle both calling conventions: model(x, y) or model(inputs)
if y is None:
# inputs is concatenated [x, y]
x_data = x[..., : self.x_size]
y_data = x[..., self.x_size :]
else:
x_data = x
y_data = y
# Apply modulated hidden layers
for hidden_layer, activation, gamma_net, b_net in zip(
self.hidden_layers, self.activation, self.gamma_networks, self.b_networks
):
# Compute modulation factors
gamma = gamma_net(y_data) # Shape: (batch_size, layer_size)
b = b_net(y_data) # Shape: (batch_size, layer_size)
# Apply modulated transformation: gamma * (W * x) + b
x_data = gamma * hidden_layer(x_data) + b
x_data = activation(x_data)
# Apply output layer
if with_last_layer:
x_data = self.activation_output(self.output_layer(x_data))
return x_data
def __str__(self) -> str:
"""String representation of the model.
Returns:
A string describing the Modulation MLP network.
"""
return (
f"GenericModulationMLP with x_size={self.x_size}, y_size={self.y_size}, "
f"layer_sizes={self.layer_sizes}, "
f"modulation_layers={len(self.gamma_networks)}"
)