Source code for scimba_torch.numerical_solvers.preconditioner_projector

"""Preconditioner projectors and their components."""

from abc import abstractmethod

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.abstract_preconditioner import (
    AbstractPreconditioner,
)
from scimba_torch.utils.scimba_tensors import LabelTensor


class MatrixPreconditionerProjector(AbstractPreconditioner):
    """Abstract base class for matrix-based preconditioner projectors.

    This class provides a structure for implementing preconditioners that use a
    matrix to precondition the gradients in projection problems.

    Args:
        space: The approximation space where the projection takes place.
        **kwargs: Additional keyword arguments for configuring the preconditioner.
    """
    @abstractmethod
    def compute_preconditioning_matrix(
        self, *args: LabelTensor, **kwargs
    ) -> torch.Tensor:
        """Computes the preconditioning matrix.

        Args:
            *args: Input tensors for computing the preconditioning matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed preconditioning matrix.
        """
    def get_preconditioning_matrix(
        self, data: tuple[LabelTensor, ...], **kwargs
    ) -> torch.Tensor:
        """Retrieves the preconditioning matrix based on the input data.

        Args:
            data: The input data for computing the preconditioning matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        if self.space.type_space == "flow":
            # only compute the matrix on the data at time t^n
            args = [data[0], data[1]]
        elif self.space.type_space == "space":
            args = [data[0], data[1]]
        elif self.space.type_space == "phase_space":
            args = [data[0], data[1], data[2]]
        else:
            args = [data[0], data[1], data[2]]
        return self.compute_preconditioning_matrix(*args, **kwargs)
    def get_preconditioning_matrix_bc(self, data: tuple, **kwargs) -> torch.Tensor:
        """Retrieves the BC preconditioning matrix based on the input data.

        Args:
            data: The input data for computing the boundary condition
                preconditioning matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        if self.space.type_space == "space":
            args = [data[2], data[4]]  # do not need the normals
        elif self.space.type_space == "phase_space":
            args = [data[3], data[4], data[6]]  # do not need the normals
        else:
            args = [data[3], data[4], data[6]]  # do not need the normals
        return self.compute_preconditioning_matrix(*args, **kwargs)
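
# The concrete projectors below follow the pattern described in the base class:
# implement ``compute_preconditioning_matrix`` and apply the resulting matrix to
# the gradients in ``__call__``. The following is a minimal sketch of such a
# subclass, kept purely for illustration (the class name and the identity matrix
# are assumptions, not part of this module's API); a usable preconditioner would
# also implement ``__call__``, as the classes below do.
class _IdentityPreconditionerProjectorSketch(MatrixPreconditionerProjector):
    """Illustrative example only: the preconditioning matrix is the identity."""

    def compute_preconditioning_matrix(
        self, *args: LabelTensor, **kwargs
    ) -> torch.Tensor:
        # An identity matrix of size ndof x ndof leaves the gradients unchanged.
        return torch.eye(self.space.ndof)
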
class EnergyNaturalGradientPreconditionerProjector(MatrixPreconditionerProjector):
    """Energy natural gradient preconditioner projector.

    This class implements a preconditioner using the energy natural gradient method.

    Args:
        space: The approximation space where the projection takes place.
        **kwargs: Additional keyword arguments for configuring the preconditioner.
    """

    def __init__(self, space: AbstractApproxSpace, **kwargs):
        super().__init__(space, **kwargs)
        self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6)
    def metric_matrix(self, *args: LabelTensor, **kwargs) -> torch.Tensor:
        """Computes the metric matrix for the given input tensors.

        Args:
            *args: Input tensors for computing the metric matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed metric matrix.
        """
        N = args[0].shape[0]
        jacobian = self.space.jacobian(*args)
        # M = (1 / N) * sum_i J_i J_i^T + regularization * I, i.e. the Gram matrix
        # of the Jacobian averaged over the samples, with a small regularization.
        M = torch.einsum(
            "ijk,ilk->jl", jacobian, jacobian
        ) / N + self.matrix_regularization * torch.eye(self.space.ndof)
        return M
    def compute_preconditioning_matrix(
        self, *args: LabelTensor, **kwargs
    ) -> torch.Tensor:
        """Computes the preconditioning matrix using the metric matrix.

        Args:
            *args: Input tensors for computing the preconditioning matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed preconditioning matrix.
        """
        return self.metric_matrix(*args, **kwargs)
    def __call__(
        self,
        epoch: int,
        data: tuple,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Applies the energy natural gradient preconditioner to the gradients.

        Args:
            epoch: The current epoch number.
            data: The data used for computing the preconditioner.
            grads: The gradients to precondition.
            res_l: The left residuals.
            res_r: The right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradients.
        """
        M = self.get_preconditioning_matrix(data, **kwargs)
        if self.has_bc:
            Mb = self.get_preconditioning_matrix_bc(data, **kwargs)
            M += self.bc_weight * Mb
        preconditioned_grads = torch.linalg.lstsq(M, grads).solution
        return preconditioned_grads
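
# A self-contained sketch of the linear algebra performed by
# ``EnergyNaturalGradientPreconditionerProjector``: build the regularized Gram
# matrix of a Jacobian, M = (1/N) * sum_i J_i J_i^T + eps * I, then solve
# M x = grads in the least-squares sense. The random Jacobian, the sizes and the
# function name are illustrative assumptions, not part of this module.
def _energy_natural_gradient_sketch() -> torch.Tensor:
    N, ndof, nb_unknowns = 128, 50, 1  # hypothetical sizes
    eps = 1e-6  # plays the role of ``matrix_regularization``
    jacobian = torch.randn(N, ndof, nb_unknowns)  # stands in for space.jacobian(...)
    grads = torch.randn(ndof, 1)
    # Same contraction as in ``metric_matrix``: sum over samples and unknowns.
    M = torch.einsum("ijk,ilk->jl", jacobian, jacobian) / N + eps * torch.eye(ndof)
    # Same solve as in ``__call__``: least-squares solution of M x = grads.
    return torch.linalg.lstsq(M, grads).solution
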
class AnagramPreconditionerProjector(MatrixPreconditionerProjector):
    """Anagram preconditioner projector.

    This class implements a preconditioner using the Anagram method.

    Args:
        space: The approximation space where the projection takes place.
        **kwargs: Additional keyword arguments for configuring the preconditioner.
    """

    def __init__(self, space: AbstractApproxSpace, **kwargs):
        super().__init__(space, **kwargs)
        self.svd_threshold = kwargs.get("svd_threshold", 1e-6)
        self.nb_components = kwargs.get("nb_components", 1)
        # if self.space.nb_unknowns > 1:
        #     raise ValueError(
        #         "Anagram preconditioner is only implemented for scalar problems."
        #     )
    def compute_preconditioning_matrix(
        self, *args: LabelTensor, **kwargs
    ) -> torch.Tensor:
        """Computes the preconditioning matrix using the Jacobian of the space.

        Args:
            *args: Input tensors for computing the preconditioning matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed preconditioning matrix.
        """
        return self.space.jacobian(*args).squeeze(-1)
    def assemble_right_member(
        self, data: tuple, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assembles the right member for the preconditioning.

        Args:
            data: The input data for assembling the right member.
            res_l: The left residuals.
            res_r: The right residuals.

        Returns:
            The assembled right member.
        """
        in_tup = tuple(left - right for left, right in zip(res_l, res_r))
        return torch.cat(in_tup, dim=0)
    def __call__(
        self,
        epoch: int,
        data: tuple,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Applies the Anagram preconditioner to the gradients.

        Args:
            epoch: The current epoch number.
            data: The data used for computing the preconditioner.
            grads: The gradients to precondition.
            res_l: The left residuals.
            res_r: The right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradients.
        """
        phi = self.get_preconditioning_matrix(data, **kwargs)
        if phi.ndim > 2:
            # stack the per-component blocks along the sample axis
            phi = torch.cat(
                tuple(phi[:, :, i, ...] for i in range(phi.shape[2])), dim=0
            )
        if self.has_bc:
            phib = self.get_preconditioning_matrix_bc(data, **kwargs)
            if phib.ndim > 2:
                phib = torch.cat(
                    tuple(phib[:, :, i, ...] for i in range(phib.shape[2])), dim=0
                )
            phi = torch.cat((phi, phib), dim=0)

        # Compute the pseudo-inverse of phi via a reduced SVD
        # (full_matrices=False; the full SVD gives the same result but is slower).
        U, Delta, Vt = torch.linalg.svd(phi, full_matrices=False)
        # Keep only the singular values above the threshold.
        mask = Delta > self.svd_threshold
        Delta_inv = torch.zeros_like(Delta)
        Delta_inv[mask] = 1.0 / Delta[mask]
        phi_plus = Vt.T @ torch.diag(Delta_inv) @ U.T
        # Alternative: phi_plus = torch.linalg.pinv(phi, atol=self.svd_threshold)

        res = self.assemble_right_member(
            data, res_l[0 : self.nb_components], res_r[0 : self.nb_components]
        )
        if self.has_bc:
            resb = self.assemble_right_member(
                data,
                res_l[self.nb_components : 2 * self.nb_components],
                res_r[self.nb_components : 2 * self.nb_components],
            )
            res = torch.cat((res, resb), dim=0)
        preconditioned_grads = phi_plus @ res
        return preconditioned_grads
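
# A self-contained sketch of the thresholded pseudo-inverse used by
# ``AnagramPreconditionerProjector.__call__``: invert only the singular values
# above a threshold and apply the resulting pseudo-inverse to a residual. The
# matrix, the sizes and the function name are illustrative assumptions, not
# part of this module.
def _thresholded_pseudo_inverse_sketch() -> torch.Tensor:
    N, ndof = 200, 50  # hypothetical sizes: rows ~ collocation points, columns ~ dofs
    threshold = 1e-6  # plays the role of ``svd_threshold``
    phi = torch.randn(N, ndof)  # stands in for the flattened Jacobian
    res = torch.randn(N, 1)  # stands in for the assembled right member
    U, Delta, Vt = torch.linalg.svd(phi, full_matrices=False)
    mask = Delta > threshold  # keep only the singular values above the threshold
    Delta_inv = torch.zeros_like(Delta)
    Delta_inv[mask] = 1.0 / Delta[mask]
    phi_plus = Vt.T @ torch.diag(Delta_inv) @ U.T  # truncated Moore-Penrose pseudo-inverse
    # torch.linalg.pinv(phi, atol=threshold) is a more direct alternative.
    return phi_plus @ res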