Source code for scimba_torch.numerical_solvers.pinn_preconditioners.energy_ng

"""Preconditioner for pinns."""

from typing import cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    ACCEPTED_PDE_TYPES,
    TYPE_DICT_OF_VMAPS,
    FunctionalOperator,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn


class EnergyNaturalGradientPreconditioner(MatrixPreconditionerPinn):
    """Energy-based natural gradient preconditioner.

    The preconditioning matrix is assembled from these parts:

    - regularized Gram matrices of the vectorized operator outputs ``Phi``
      (main operator and, when present, boundary and initial conditions),
      each scaled by its weight;
    - regularized Gram matrices of the approximation-space Jacobian for the
      data-loss terms.

    Gradients are then preconditioned by solving a linear system with this
    matrix.

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments, also forwarded to the parent
            class constructor.

    Keyword Args:
        matrix_regularization (:code:`float`, default=1e-6): Regularization
            parameter added to the diagonal of each preconditioning matrix.
        adaptive_matrix_regularization (:code:`bool`, default=False): If True,
            adaptively adjusts the regularization parameter during training
            (à la Levenberg-Marquardt) via
            :meth:`update_matrix_regularization`. If False, uses the fixed
            value given by ``matrix_regularization``.
        adaptive_matrix_regularization_increase (:code:`float`, default=10.0):
            Factor by which the regularization parameter is multiplied when
            the loss has not decreased after the line search.
        adaptive_matrix_regularization_decrease (:code:`float`, default=0.5):
            Factor by which the regularization parameter is multiplied when
            the loss has decreased within at most one line-search step.
        adaptive_matrix_regularization_min (:code:`float`, default=1e-12):
            Lower bound for the regularization parameter when using adaptive
            adjustment.
        use_lstsq (:code:`bool`, default=True): If True, preconditioned
            gradients are computed with :func:`torch.linalg.lstsq`. If False,
            the linear system is solved directly with
            :func:`torch.linalg.solve`.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6)
        self.adaptive_matrix_regularization = kwargs.get(
            "adaptive_matrix_regularization", False
        )
        self.adaptive_matrix_regularization_increase = kwargs.get(
            "adaptive_matrix_regularization_increase", 10.0
        )
        self.adaptive_matrix_regularization_decrease = kwargs.get(
            "adaptive_matrix_regularization_decrease", 0.5
        )
        self.adaptive_matrix_regularization_min = kwargs.get(
            "adaptive_matrix_regularization_min", 1e-12
        )
        # Remembered so that reset_matrix_regularization can restore it after
        # adaptive updates.
        self.initial_matrix_regularization = self.matrix_regularization
        self.use_lstsq = kwargs.get("use_lstsq", True)

    def _regularization_matrix(self, ref: torch.Tensor, ndof: int) -> torch.Tensor:
        """Return ``matrix_regularization * I`` (size ``ndof``) matching ``ref``.

        The identity is created with the dtype and device of ``ref`` so that it
        can be added to a Gram matrix living on any device (e.g. a GPU), rather
        than on the default device.
        """
        return self.matrix_regularization * torch.eye(
            ndof, dtype=ref.dtype, device=ref.device
        )

    @staticmethod
    def _scale_by_label_weights(
        Phi: torch.Tensor, labels: torch.Tensor, weights
    ) -> None:
        """Scale ``Phi`` in place by per-label weights.

        Args:
            Phi: Operator outputs; indexed as a 3-D tensor whose first axis is
                matched against ``labels``.
            labels: Label of each sample along the first axis of ``Phi``.
            weights: Label -> weight mapping (presumably a dict; only iteration
                and ``__getitem__`` are used).
        """
        if len(weights) == 1:
            # A single weight applies to every sample, whatever its label.
            Phi[:, :, :] *= weights[next(iter(weights))]
        else:
            for key in weights:
                Phi[labels == key, :, :] *= weights[key]

    def _regularized_gram(
        self, Phi: torch.Tensor, labels: torch.Tensor, weights, N: int
    ) -> torch.Tensor:
        """Weight ``Phi`` by label, then return its regularized Gram matrix.

        Computes ``sum_{i,k} Phi[i, j, k] * Phi[i, l, k] / N +
        matrix_regularization * I`` (size ``self.ndof``).

        Note: ``Phi`` is modified in place by the label weighting.
        """
        self._scale_by_label_weights(Phi, labels, weights)
        M = torch.einsum("ijk,ilk->jl", Phi, Phi) / N
        M += self._regularization_matrix(M, self.ndof)
        return M

    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; ``args[0].shape[0]`` is used as the
                number of samples for normalization.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix, scaled by ``self.in_weight``.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        )
        M = self._regularized_gram(Phi, labels, self.in_weights, N)
        return self.in_weight * M

    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; ``args[0].shape[0]`` is used as the
                number of samples for normalization.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix, scaled by
            ``self.bc_weight``.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        )
        M = self._regularized_gram(Phi, labels, self.bc_weights, N)
        return self.bc_weight * M

    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; ``args[0].shape[0]`` is used as the
                number of samples for normalization.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix, scaled by
            ``self.ic_weight``.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        )
        M = self._regularized_gram(Phi, labels, self.ic_weights, N)
        return self.ic_weight * M

    def compute_preconditioning_matrix_dl(
        self, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Computes the Gram matrix of the network for the given input tensors.

        Args:
            *args: Input tensors for computing the Gram matrix;
                ``args[0].shape[0]`` is used as the number of samples.
            **kwargs: Additional keyword arguments.

        Returns:
            The regularized Gram matrix ``J^T J / N + matrix_regularization * I``
            of the approximation-space Jacobian.
        """
        N = args[0].shape[0]
        jacobian = self.space.jacobian(*args)
        M = torch.einsum("ijk,ilk->jl", jacobian, jacobian) / N
        return M + self._regularization_matrix(M, self.space.ndof)

    def compute_full_preconditioning_matrix(
        self, data: tuple | dict, **kwargs
    ) -> torch.Tensor:
        """Compute the full preconditioning matrix.

        Include contributions from the main operator, boundary conditions,
        initial conditions and the data-loss terms (each weighted by its
        coefficient in ``self.dl_weights``).

        NOTE(review): each component carries its own
        ``matrix_regularization * I`` term, so the total diagonal shift is the
        sum of the component weights times ``matrix_regularization`` —
        assuming the parent's ``get_*`` accessors return the corresponding
        ``compute_*`` results; confirm this is intended.

        Args:
            data: Input data, either as a tuple or a dictionary.
            **kwargs: Additional keyword arguments.

        Returns:
            The full preconditioning matrix.
        """
        M = self.get_preconditioning_matrix(data, **kwargs)
        if self.has_bc:
            M += self.get_preconditioning_matrix_bc(data, **kwargs)
        if self.has_ic:
            M += self.get_preconditioning_matrix_ic(data, **kwargs)
        for index, coeff in enumerate(self.dl_weights):
            M += coeff * self.get_preconditioning_matrix_dl(
                self.args_for_dl[index], **kwargs
            )
        return M

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the preconditioner to the input gradients.

        Args:
            epoch: Current training epoch (unused here).
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals (unused here).
            res_r: Right residuals (unused here).
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor, i.e. the solution ``x`` of
            ``M x = grads`` (least-squares solution if ``use_lstsq``).
        """
        M = self.compute_full_preconditioning_matrix(data, **kwargs)
        if self.use_lstsq:
            return torch.linalg.lstsq(M, grads).solution
        return torch.linalg.solve(M, grads)

    def update_matrix_regularization(self, n_steps: int, loss_has_decreased: bool):
        """Updates the regularization parameter for the preconditioning matrix.

        Only active when ``adaptive_matrix_regularization`` is True. If the loss
        did not decrease, the parameter is multiplied by
        ``adaptive_matrix_regularization_increase``. If it decreased within at
        most one line-search step, it is multiplied by
        ``adaptive_matrix_regularization_decrease``. It is then clamped from
        below by ``adaptive_matrix_regularization_min``.

        Args:
            n_steps: The number of line search steps taken.
            loss_has_decreased: Whether the loss has decreased after the line
                search.
        """
        if not self.adaptive_matrix_regularization:
            return
        if not loss_has_decreased:
            self.matrix_regularization *= self.adaptive_matrix_regularization_increase
        elif n_steps <= 1:
            self.matrix_regularization *= self.adaptive_matrix_regularization_decrease
        self.matrix_regularization = max(
            self.matrix_regularization, self.adaptive_matrix_regularization_min
        )

    def reset_matrix_regularization(self):
        """Resets the regularization parameter to its initial value."""
        self.matrix_regularization = self.initial_matrix_regularization