Source code for scimba_torch.numerical_solvers.pinn_preconditioners.energy_ng

"""Preconditioner for pinns."""

from typing import cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    ACCEPTED_PDE_TYPES,
    TYPE_DICT_OF_VMAPS,
    FunctionalOperator,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn


class EnergyNaturalGradientPreconditioner(MatrixPreconditionerPinn):
    """Energy-based natural gradient preconditioner.

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments.

    Keyword Args:
        matrix_regularization (float): Regularization parameter for the
            preconditioning matrix (default: 1e-6).
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6)
    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        )
        if len(self.in_weights) == 1:
            # apply the same weight to all labels
            for key in self.in_weights:  # dummy loop
                Phi[:, :, :] *= self.in_weights[key]
        else:
            # apply the weight associated with each label
            for key in self.in_weights:
                Phi[labels == key, :, :] *= self.in_weights[key]
        M = torch.einsum("ijk,ilk->jl", Phi, Phi) / N
        M += self.matrix_regularization * torch.eye(self.ndof)
        return self.in_weight * M
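    # Note on the assembly above (descriptive sketch, assuming Phi has shape
    # (N, n_dof, n_out)): the einsum builds the Gram matrix of the per-sample
    # parameter Jacobians,
    #
    #     M[j, l] = (1/N) * sum_{i, k} Phi[i, j, k] * Phi[i, l, k]
    #             = (1/N) * sum_i (Phi[i] @ Phi[i].T)[j, l],
    #
    # which is equivalent to
    #
    #     M = torch.bmm(Phi, Phi.transpose(1, 2)).sum(dim=0) / N
    #
    # before the Tikhonov term matrix_regularization * I is added to keep the
    # system well conditioned.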
    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        )
        if len(self.bc_weights) == 1:
            # apply the same weight to all labels
            for key in self.bc_weights:  # dummy loop
                Phi[:, :, :] *= self.bc_weights[key]
        else:
            # apply the weight associated with each label
            for key in self.bc_weights:
                Phi[labels == key, :, :] *= self.bc_weights[key]
        M = torch.einsum("ijk,ilk->jl", Phi, Phi) / N
        M += self.matrix_regularization * torch.eye(self.ndof)
        return self.bc_weight * M
    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        )
        if len(self.ic_weights) == 1:
            # apply the same weight to all labels
            for key in self.ic_weights:  # dummy loop
                Phi[:, :, :] *= self.ic_weights[key]
        else:
            # apply the weight associated with each label
            for key in self.ic_weights:
                Phi[labels == key, :, :] *= self.ic_weights[key]
        M = torch.einsum("ijk,ilk->jl", Phi, Phi) / N
        M += self.matrix_regularization * torch.eye(self.ndof)
        return self.ic_weight * M
    def compute_preconditioning_matrix_dl(
        self, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Computes the Gram matrix of the network for the given input tensors.

        Args:
            *args: Input tensors for computing the Gram matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed Gram matrix.
        """
        N = args[0].shape[0]
        jacobian = self.space.jacobian(*args)
        M = torch.einsum(
            "ijk,ilk->jl", jacobian, jacobian
        ) / N + self.matrix_regularization * torch.eye(self.space.ndof)
        return M
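    # The "_dl" term follows the same Gram-matrix pattern as above, but is
    # built from the Jacobian of the approximation space itself (derivatives
    # of the network outputs with respect to its parameters) rather than from
    # the operator applied to it, again regularized with
    # matrix_regularization * I.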
    def compute_full_preconditioning_matrix(
        self, data: tuple | dict, **kwargs
    ) -> torch.Tensor:
        """Compute the full preconditioning matrix.

        Includes contributions from the main operator, boundary conditions,
        and initial conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            **kwargs: Additional keyword arguments.

        Returns:
            The full preconditioning matrix.
        """
        M = self.get_preconditioning_matrix(data, **kwargs)
        if self.has_bc:
            M += self.get_preconditioning_matrix_bc(data, **kwargs)
        if self.has_ic:
            M += self.get_preconditioning_matrix_ic(data, **kwargs)
        for index, coeff in enumerate(self.dl_weights):
            M += coeff * self.get_preconditioning_matrix_dl(
                self.args_for_dl[index], **kwargs
            )
        return M
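    # In other words, the full preconditioner is assembled additively,
    #     M = M_main + M_bc + M_ic + sum_k dl_weights[k] * M_dl_k,
    # where the bc/ic terms are only added when the problem defines them.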
    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the preconditioner to the input gradients.

        Args:
            epoch: Current training epoch.
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals.
            res_r: Right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor.
        """
        M = self.compute_full_preconditioning_matrix(data, **kwargs)
        preconditioned_grads = torch.linalg.lstsq(M, grads).solution
        return preconditioned_grads
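
# Hypothetical usage sketch (not part of this module): inside a training loop
# the preconditioner is typically called on the flattened gradient, e.g.
#
#     precond = EnergyNaturalGradientPreconditioner(
#         space, pde, matrix_regularization=1e-6
#     )
#     natural_grads = precond(epoch, data, grads, res_l, res_r)
#
# where `space`, `pde`, `data`, `grads`, `res_l` and `res_r` are assumed to be
# supplied by the surrounding solver; __call__ then solves the least-squares
# system M @ p = grads with torch.linalg.lstsq.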