Source code for scimba_torch.approximation_space.kernelx_space

"""Defines the Kernel-based approximation space and its components."""

from typing import Any

import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.domain.meshless_domain.base import VolumetricDomain
from scimba_torch.integration.monte_carlo import TensorizedSampler
from scimba_torch.neural_nets.coordinates_based_nets.scimba_module import ScimbaModule
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor


[docs] class GaussianKernel(nn.Module): """Gaussian kernen approximation space. Args: **kwargs: Additional arguments for the Gaussian kernel. """ def __init__(self, **kwargs: Any): super().__init__()
[docs] def forward(self, vector_diff: torch.Tensor, aniso: torch.Tensor) -> torch.Tensor: """Compute the Gaussian kernel values. Args: vector_diff: Tensor of shape (num_samples, num_centers, input_dim) representing the difference between input points and kernel centers. aniso: Tensor of shape (num_centers, input_dim, input_dim) representing the anisotropic transformation matrices for each kernel center. Returns: Tensor of shape (num_samples, num_centers) containing the Gaussian kernel values. """ distances = torch.einsum("nkd,kdd,nkd->nk", vector_diff, aniso, vector_diff) return torch.exp(-distances)
[docs] class ExponentialKernel(nn.Module): """Exponential kernel approximation space. Args: **kwargs: Additional arguments for the Exponential kernel. - beta: Shape parameter of the Exponential kernel (default: 2). """ def __init__(self, **kwargs: Any): super().__init__() self.beta = kwargs.get("beta", 2)
[docs] def forward(self, vector_diff: torch.Tensor, aniso: torch.Tensor) -> torch.Tensor: """Compute the Exponential kernel values. Args: vector_diff: Tensor of shape (num_samples, num_centers, input_dim) representing the difference between input points and kernel centers. aniso: Tensor of shape (num_centers, input_dim, input_dim) representing the anisotropic transformation matrices for each kernel center. Returns: Tensor of shape (num_samples, num_centers) containing the Exponential kernel values. """ l1_norm = torch.sum(torch.abs(vector_diff), dim=-1) return torch.exp(-((l1_norm) ** self.beta))
[docs] class MultiquadraticKernel(nn.Module): """Multiquadratic kernel approximation space. Args: **kwargs: Additional arguments for the Multiquadratic kernel. - beta: Shape parameter of the Multiquadratic kernel (default: 2). """ def __init__(self, **kwargs: Any): super().__init__() self.beta = kwargs.get("beta", 2)
[docs] def forward(self, vector_diff: torch.Tensor, aniso: torch.Tensor) -> torch.Tensor: """Compute the Multiquadratic kernel values. Args: vector_diff: Tensor of shape (num_samples, num_centers, input_dim) representing the difference between input points and kernel centers. aniso: Tensor of shape (num_centers, input_dim, input_dim) representing the anisotropic transformation matrices for each kernel center. Returns: Tensor of shape (num_samples, num_centers) containing the Multiquadratic kernel values. """ distances = torch.einsum("nkd,kdd,nkd->nk", vector_diff, aniso, vector_diff) return (1 + (distances) ** 2) ** self.beta
[docs] class KernelxSpace(AbstractApproxSpace, ScimbaModule): """A nonlinear approximation space using a neural network model. This class represents a parametric approximation space, where the solution is modeled by a neural network. It integrates functionality for evaluating the network, setting/retrieving degrees of freedom, and computing the Jacobian. Args: nb_unknowns: Number of unknowns in the approximation problem. nb_parameters: Number of parameters in the approximation problem. kernel_type: The type of kernel to use for the approximation. nb_centers: Number of centers for the kernel functions. spatial_domain: The spatial domain of the problem. integrator: Sampler used for integration over the spatial and parameter domains. **kwargs: Additional arguments passed to the neural network model. """ def __init__( self, nb_unknowns: int, nb_parameters: int, kernel_type: nn.Module, nb_centers: int, spatial_domain: VolumetricDomain, integrator: TensorizedSampler, **kwargs, ): # Call the initializer of ScimbaModule ScimbaModule.__init__( self, in_size=spatial_domain.dim + nb_parameters, # problem out_size=nb_unknowns, **kwargs, ) # Call the initializer of AbstractApproxSpace AbstractApproxSpace.__init__(self, nb_unknowns) self.spatial_domain: VolumetricDomain = spatial_domain self.integrator: TensorizedSampler = integrator self.kernel_type: nn.Module = kernel_type self.anisotropic: bool = kwargs.get("anisotropic", False) self.type_space: str = "space" def default_pre_processing(x: LabelTensor, mu: LabelTensor): return torch.cat([x.x, mu.x], dim=1) def default_post_processing( inputs: torch.Tensor, x: LabelTensor, mu: LabelTensor ): return inputs self.pre_processing = kwargs.get("pre_processing", default_pre_processing) self.post_processing = kwargs.get("post_processing", default_post_processing) # self.centers = next(self.integrator.sample(nb_centers)) self.centers_x, self.centers_mu = self.integrator.sample(nb_centers) self.centers: torch.nn.Parameter = nn.Parameter( torch.cat([self.centers_x.x, self.centers_mu.x], dim=1) ) #: Centers of the kernel functions, initialized as learnable parameters. self.kernel = kernel_type(**kwargs) #: Parameter for the kernel functions, controlling their shape and behavior. self.beta: float = kwargs.get("beta", 2) # Création d'une liste de matrice anistrope E pour le noyau anisotrope # E = M*M^T+ Id*eps M: rdm if self.anisotropic: M = torch.randn( (self.centers.shape[0], self.centers.shape[1], self.centers.shape[1]), dtype=torch.get_default_dtype(), ) E = ( M @ M.transpose(-1, -2) + torch.eye(self.centers.shape[1], dtype=torch.get_default_dtype()) * 1e-5 ) #: Anisotropic transformation matrix for the kernel functions, #: initialized as learnable parameters. self.M_aniso: torch.nn.Parameter = nn.Parameter(kwargs.get("eps", 1.0) * E) else: #: Epsilon parameters for the kernel functions, #: initialized as learnable parameters. self.eps: torch.nn.Parameter = nn.Parameter( torch.ones(nb_centers) * kwargs.get("eps", 1.0) ) self.output_layer = torch.nn.Linear(nb_centers, self.out_size) #: Total number of degrees of freedom in the network. self.ndof: int = self.get_dof(flag_format="tensor").shape[0] self.Id = torch.eye(spatial_domain.dim + nb_parameters)
[docs] def forward( self, features: torch.Tensor, with_last_layer: bool = True ) -> torch.Tensor: """Forward pass through the kernel model. Args: features: Input tensor with concatenated spatial and parameter data. with_last_layer: Whether to apply the final linear layer. Returns: Output tensor from the kernel model. """ vect_diff = torch.zeros((features.shape[0], self.centers.shape[0], 2)) vect_diff = features.unsqueeze(1) - self.centers.unsqueeze(0) if self.anisotropic: # Apply anisotropic transformation Aniso = self.M_aniso else: Aniso = self.eps[:, None, None] * self.Id basis = self.kernel(vect_diff, Aniso) # (batch_size, nb_centers) if with_last_layer: res = self.output_layer(basis) else: res = basis return res
[docs] def evaluate( self, x: LabelTensor, mu: LabelTensor, with_last_layer: bool = True ) -> MultiLabelTensor: """Evaluate the parametric model for given inputs and parameters. Args: x: Input tensor from the spatial domain. mu: Input tensor from the parameter domain. with_last_layer: Whether to apply the final linear layer. Returns: Output tensor from the neural network, wrapped with multi-label metadata. """ features = self.pre_processing(x, mu) res = self.forward(features, with_last_layer) if with_last_layer: res = self.post_processing(res, x, mu) return MultiLabelTensor(res, [x.labels, mu.labels])
[docs] def set_dof(self, theta: torch.Tensor, flag_scope: str = "all") -> None: """Sets the degrees of freedom (DoF) for the neural network. Args: theta: A vector containing the network parameters. flag_scope: The scope of parameters to return. """ self.set_parameters(theta, flag_scope)
[docs] def get_dof( self, flag_scope: str = "all", flag_format: str = "list" ) -> torch.Tensor: """Retrieves the degrees of freedom (DoF) of the neural network. Args: flag_scope: The scope of parameters to return. flag_format: The format of the returned parameters. Returns: Tensor containing the DoF of the network. """ return self.parameters(flag_scope=flag_scope, flag_format=flag_format)
[docs] def jacobian(self, x: LabelTensor, mu: LabelTensor) -> torch.Tensor: """Compute the Jacobian of the network with respect to its parameters. Args: x: Input tensor from the spatial domain. mu: Input tensor from the parameter domain. Returns: Jacobian matrix of shape `(num_samples, out_size, num_params)`. """ params = {k: v.detach() for k, v in self.named_parameters()} features = self.pre_processing(x, mu) def fnet(theta, features): return functional_call(self, theta, (features.unsqueeze(0))).squeeze(0) jac = vmap(jacrev(fnet), (None, 0))(params, features).values() jac_m = torch.cat( [j.reshape((features.shape[0], self.out_size, -1)) for j in jac], dim=-1 ) jac_mt = jac_m.transpose(1, 2) return self.post_processing(jac_mt, x, mu)