Source code for scimba_torch.numerical_solvers.abstract_preconditioner

"""Preconditioners for projectors."""

from abc import ABC, abstractmethod

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import ACCEPTED_PDE_TYPES


[docs] class AbstractPreconditioner(ABC): """Abstract class for preconditioner solvers. Args: space: The approximation space. pde: The PDE to be solved, None by default for projection. **kwargs: Additional keyword arguments: - `has_bc` (:code:`bool`): Whether the PDE has boundary conditions. (default: False) - `bc_weight` (:code:`float`): Weight for the boundary conditions. (default: 10.0) - `has_ic` (:code:`bool`): Whether the PDE has initial conditions. (default: False) - `ic_weight` (:code:`float`): Weight for the initial conditions. (default: 10.0) """ def __init__( self, space: AbstractApproxSpace, pde: ACCEPTED_PDE_TYPES | None = None, **kwargs, ): self.space = space self.pde = pde # used to balance the losses between in/bc equations self.in_weight = kwargs.get("in_weight", 1.0) self.bc_weight = kwargs.get("bc_weight", 10.0) self.ic_weight = kwargs.get("ic_weight", 10.0) # used to balance the losses between equations self.in_weights = kwargs.get("in_weights", 1.0) self.bc_weights = kwargs.get("bc_weights", 1.0) self.ic_weights = kwargs.get("ic_weights", 1.0) self.has_bc = kwargs.get("has_bc", False) self.has_ic = kwargs.get("has_ic", False) @abstractmethod def __call__( self, epoch: int, data: tuple | dict, grads: torch.Tensor, res_l: tuple, res_r: tuple, **kwargs, ) -> torch.Tensor: """Abstract method for preconditioner call. Args: epoch: Current training epoch. data: Input data, either as a tuple or a dictionary. grads: Gradient tensor to be preconditioned. res_l: Left residuals. res_r: Right residuals. **kwargs: Additional keyword arguments. Returns: The preconditioned gradient tensor. """
[docs] def update_matrix_regularization(self, n_steps: int, loss_has_decreased: bool): """Updates the regularization parameter for the preconditioning matrix. This method can be overridden by subclasses to implement adaptive regularization strategies based on the number of line search steps taken and whether the loss has decreased. By default, it does nothing. Subclasses that implement adaptive regularization can override this method to update the regularization parameter during training. Args: n_steps: The number of line search steps taken. loss_has_decreased: Whether the loss has decreased after the line search. """
[docs] def reset_matrix_regularization(self): """Resets the regularization parameter to its initial value. By default, it does nothing. Subclasses that implement adaptive regularization can override this method to reset the regularization parameter when needed. """
[docs] class IdPreconditioner(AbstractPreconditioner): """Identity preconditioner that returns the input gradients unchanged. Args: space: The approximation space. pde: The PDE to be solved, None by default for projection. **kwargs: Additional keyword arguments. """ def __init__( self, space: AbstractApproxSpace, pde: ACCEPTED_PDE_TYPES | None = None, **kwargs, ): super().__init__(space, pde, **kwargs) def __call__( self, epoch: int, data: tuple | dict, grads: torch.Tensor, res_l: tuple, res_r: tuple, **kwargs, ) -> torch.Tensor: """Return the input gradients unchanged. Args: epoch: Current training epoch. data: Input data, either as a tuple or a dictionary. grads: Gradient tensor to be preconditioned. res_l: Left residuals. res_r: Right residuals. **kwargs: Additional keyword arguments. Returns: The unmodified gradient tensor. """ return grads