Source code for scimba_torch.numerical_solvers.preconditioner_projector

"""Preconditioner projectors and their components."""

from abc import abstractmethod

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.abstract_preconditioner import (
    AbstractPreconditioner,
)
from scimba_torch.utils.scimba_tensors import LabelTensor


class MatrixPreconditionerProjector(AbstractPreconditioner):
    """Abstract base class for matrix-based preconditioner projectors.

    Concrete subclasses provide :meth:`compute_preconditioning_matrix`; the
    two ``get_*`` helpers defined here only pick which entries of the data
    tuple are forwarded to it, depending on ``self.space.type_space``.

    Args:
        space: The approximation space where the projection takes place.
        **kwargs: Additional keyword arguments for configuring the
            preconditioner.
    """

    @abstractmethod
    def compute_preconditioning_matrix(
        self, *args: LabelTensor, **kwargs
    ) -> torch.Tensor:
        """Build the preconditioning matrix from the given tensors.

        Args:
            *args: Input tensors for computing the preconditioning matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed preconditioning matrix.
        """

    def get_preconditioning_matrix(
        self, data: tuple[LabelTensor, ...], **kwargs
    ) -> torch.Tensor:
        """Build the preconditioning matrix from the interior data.

        Args:
            data: The input data for computing the preconditioning matrix.
            **kwargs: Additional keyword arguments, forwarded unchanged.

        Returns:
            The preconditioning matrix.
        """
        space_kind = self.space.type_space
        if space_kind in ("flow", "space"):
            # Two leading entries only. For "flow" this means the matrix is
            # computed on the data at time t^n only.
            selected = (data[0], data[1])
        else:
            # "phase_space" and every other kind use the three leading entries.
            selected = (data[0], data[1], data[2])
        return self.compute_preconditioning_matrix(*selected, **kwargs)

    def get_preconditioning_matrix_bc(self, data: tuple, **kwargs) -> torch.Tensor:
        """Build the boundary-condition preconditioning matrix.

        Args:
            data: The input data for computing the boundary condition
                preconditioning matrix.
            **kwargs: Additional keyword arguments, forwarded unchanged.

        Returns:
            The boundary condition preconditioning matrix.
        """
        # The normals stored in ``data`` are not needed, hence the gaps in the
        # index lists below.
        if self.space.type_space == "space":
            picked = (2, 4)
        else:
            # "phase_space" and every other kind share the same layout.
            picked = (3, 4, 6)
        selected = [data[position] for position in picked]
        return self.compute_preconditioning_matrix(*selected, **kwargs)
class EnergyNaturalGradientPreconditionerProjector(MatrixPreconditionerProjector):
    """Energy natural gradient preconditioner projector.

    This class implements a preconditioner using the energy natural gradient
    method.

    Args:
        space: The approximation space where the projection takes place.
        **kwargs: Additional keyword arguments for configuring the
            preconditioner.

    Keyword Args:
        `matrix_regularization` (:code:`float`, default=1e-6):
            Regularization parameter for the preconditioning matrix.
        `adaptive_matrix_regularization` (:code:`bool`, default=False):
            If True, adaptively adjusts the regularization parameter during
            training (in the spirit of Levenberg-Marquardt). If False, uses
            the fixed regularization parameter specified by
            `matrix_regularization`.
        `adaptive_matrix_regularization_increase` (:code:`float`, default=10.0):
            Factor by which to increase the regularization parameter if the
            line search fails to find a suitable step size.
        `adaptive_matrix_regularization_decrease` (:code:`float`, default=0.5):
            Factor by which to decrease the regularization parameter if the
            line search quickly succeeds in finding a suitable step size.
        `adaptive_matrix_regularization_min` (:code:`float`, default=1e-12):
            Minimum value for the regularization parameter when using
            adaptive adjustment.
        `use_lstsq` (:code:`bool`, default=True):
            Whether to use a least squares solver for computing the
            preconditioned gradients. If False, uses a direct linear solve.
    """

    def __init__(self, space: AbstractApproxSpace, **kwargs):
        super().__init__(space, **kwargs)
        self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6)
        self.adaptive_matrix_regularization = kwargs.get(
            "adaptive_matrix_regularization", False
        )
        self.adaptive_matrix_regularization_increase = kwargs.get(
            "adaptive_matrix_regularization_increase", 10.0
        )
        self.adaptive_matrix_regularization_decrease = kwargs.get(
            "adaptive_matrix_regularization_decrease", 0.5
        )
        self.adaptive_matrix_regularization_min = kwargs.get(
            "adaptive_matrix_regularization_min", 1e-12
        )
        # Remembered so that reset_matrix_regularization() can undo any
        # adaptive updates.
        self.initial_matrix_regularization = self.matrix_regularization
        self.use_lstsq = kwargs.get("use_lstsq", True)

    def metric_matrix(self, *args: LabelTensor, **kwargs) -> torch.Tensor:
        """Computes the regularized metric matrix for the given input tensors.

        The Jacobian of the space is contracted with itself over its first and
        last axes, divided by the size of the first axis of ``args[0]``, and a
        scaled identity of size ``self.space.ndof`` is added.

        Args:
            *args: Input tensors for computing the metric matrix.
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The computed metric matrix.
        """
        n_points = args[0].shape[0]
        jacobian = self.space.jacobian(*args)
        # Build the identity with the Jacobian's dtype and device so the sum
        # below neither fails on a device mismatch nor silently promotes dtype.
        identity = torch.eye(
            self.space.ndof, dtype=jacobian.dtype, device=jacobian.device
        )
        gram = torch.einsum("ijk,ilk->jl", jacobian, jacobian) / n_points
        return gram + self.matrix_regularization * identity

    def compute_preconditioning_matrix(
        self, *args: LabelTensor, **kwargs
    ) -> torch.Tensor:
        """Computes the preconditioning matrix using the metric matrix.

        Args:
            *args: Input tensors for computing the preconditioning matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed preconditioning matrix.
        """
        return self.metric_matrix(*args, **kwargs)

    def __call__(
        self,
        epoch: int,
        data: tuple,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Applies the energy natural gradient preconditioner to the gradients.

        Args:
            epoch: The current epoch number (not used by this implementation).
            data: The data used for computing the preconditioner.
            grads: The gradients to precondition.
            res_l: The left residuals (not used by this implementation).
            res_r: The right residuals (not used by this implementation).
            **kwargs: Additional keyword arguments, forwarded to the matrix
                computation.

        Returns:
            The preconditioned gradients.
        """
        M = self.get_preconditioning_matrix(data, **kwargs)
        if self.has_bc:
            Mb = self.get_preconditioning_matrix_bc(data, **kwargs)
            # Out-of-place on purpose: never mutate a tensor handed back by an
            # overridable method.
            M = M + self.bc_weight * Mb
        if self.use_lstsq:
            return torch.linalg.lstsq(M, grads).solution
        return torch.linalg.solve(M, grads)

    def update_matrix_regularization(self, n_steps: int, loss_has_decreased: bool):
        """Updates the regularization parameter for the preconditioning matrix.

        This method adaptively adjusts the regularization parameter during
        training based on the number of line search steps taken and whether
        the loss has decreased. It does nothing when
        ``adaptive_matrix_regularization`` is False.

        Args:
            n_steps: The number of line search steps taken.
            loss_has_decreased: Whether the loss has decreased after the line
                search.
        """
        if not self.adaptive_matrix_regularization:
            return
        # NOTE(review): the increase branch is taken when the loss did not
        # decrease, matching the class docstring ("if the line search fails").
        # Confirm against the upstream indentation.
        if loss_has_decreased:
            if n_steps <= 1:
                # The line search succeeded quickly: relax the regularization.
                self.matrix_regularization *= (
                    self.adaptive_matrix_regularization_decrease
                )
        else:
            self.matrix_regularization *= self.adaptive_matrix_regularization_increase
        # Never let the regularization drop below the configured floor.
        self.matrix_regularization = max(
            self.matrix_regularization, self.adaptive_matrix_regularization_min
        )

    def reset_matrix_regularization(self):
        """Resets the regularization parameter to its initial value."""
        self.matrix_regularization = self.initial_matrix_regularization
class AnagramPreconditionerProjector(MatrixPreconditionerProjector):
    """Anagram preconditioner projector.

    This class implements a preconditioner using the Anagram method: the
    update is obtained by applying a truncated-SVD pseudo-inverse of the
    Jacobian of the space to the residuals.

    Args:
        space: The approximation space where the projection takes place.
        **kwargs: Additional keyword arguments for configuring the
            preconditioner.

    Keyword Args:
        `svd_threshold` (:code:`float`, default=1e-6):
            Absolute threshold; singular values that are not strictly greater
            are discarded when building the pseudo-inverse.
        `nb_components` (:code:`int`, default=1):
            Number of leading entries of the residual tuples assembled into
            the interior right-hand side; the next `nb_components` entries
            are used for the boundary part when boundary conditions are
            enabled.
    """

    def __init__(self, space: AbstractApproxSpace, **kwargs):
        super().__init__(space, **kwargs)
        self.svd_threshold = kwargs.get("svd_threshold", 1e-6)
        self.nb_components = kwargs.get("nb_components", 1)

    def compute_preconditioning_matrix(
        self, *args: LabelTensor, **kwargs
    ) -> torch.Tensor:
        """Computes the preconditioning matrix using the Jacobian of the space.

        Args:
            *args: Input tensors for computing the preconditioning matrix.
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The Jacobian of the space, with its last axis removed when that
            axis has size one.
        """
        return self.space.jacobian(*args).squeeze(-1)

    def assemble_right_member(
        self, data: tuple, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assembles the right member for the preconditioning.

        Args:
            data: The input data (not used by this implementation).
            res_l: The left residuals.
            res_r: The right residuals.

        Returns:
            The pairwise differences ``left - right`` concatenated along the
            first axis.
        """
        differences = tuple(left - right for left, right in zip(res_l, res_r))
        return torch.cat(differences, dim=0)

    @staticmethod
    def _stack_components(matrix: torch.Tensor) -> torch.Tensor:
        """Stack the slices along axis 2 on top of each other along axis 0.

        Tensors with at most two axes are returned unchanged.
        """
        if matrix.ndim > 2:
            return torch.cat(matrix.unbind(dim=2), dim=0)
        return matrix

    def __call__(
        self,
        epoch: int,
        data: tuple,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Applies the Anagram preconditioner.

        Args:
            epoch: The current epoch number (not used by this implementation).
            data: The data used for computing the preconditioner.
            grads: The gradients to precondition (not used by this
                implementation; the result is built from the residuals).
            res_l: The left residuals.
            res_r: The right residuals.
            **kwargs: Additional keyword arguments, forwarded to the matrix
                computation.

        Returns:
            The preconditioned gradients.
        """
        phi = self._stack_components(self.get_preconditioning_matrix(data, **kwargs))
        if self.has_bc:
            phib = self._stack_components(
                self.get_preconditioning_matrix_bc(data, **kwargs)
            )
            phi = torch.cat((phi, phib), dim=0)

        # Pseudo-inverse via a reduced SVD. The original author noted that
        # full_matrices=True is slower, and that
        # torch.linalg.pinv(phi, atol=self.svd_threshold) is an alternative.
        U, sigma, Vt = torch.linalg.svd(phi, full_matrices=False)

        # Invert only the singular values above the threshold; the others
        # contribute nothing to the pseudo-inverse.
        keep = sigma > self.svd_threshold
        sigma_inv = torch.zeros_like(sigma)
        sigma_inv[keep] = 1.0 / sigma[keep]

        # Scaling the columns of Vt.T equals Vt.T @ diag(sigma_inv), without
        # materializing the dense diagonal matrix.
        phi_plus = (Vt.T * sigma_inv) @ U.T

        nb = self.nb_components
        res = self.assemble_right_member(data, res_l[0:nb], res_r[0:nb])
        if self.has_bc:
            resb = self.assemble_right_member(
                data, res_l[nb : 2 * nb], res_r[nb : 2 * nb]
            )
            res = torch.cat((res, resb), dim=0)

        return phi_plus @ res