# Source code for scimba_torch.numerical_solvers.preconditioner_deep_ritz

"""Deep Ritz preconditioners."""

from typing import cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    TYPE_ARGS,
    TYPE_DICT_OF_VMAPS,
    TYPE_FUNC_ARGS,
    TYPE_VMAPS,
    FunctionalOperator,
)
from scimba_torch.numerical_solvers.preconditioner_solvers import (
    MatrixPreconditionerSolver,
    _mjactheta,
    _transpose_i_j,
)
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    RitzFormEllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
    DivAGradUPDE,
)
from scimba_torch.utils.scimba_tensors import LabelTensor

# Type alias for data arguments: either a flat tuple of labelled tensors, or a
# dict mapping string keys to such tuples.
TYPE_DATA = (
    tuple[LabelTensor, ...]
    | dict[str, tuple[LabelTensor, ...]]
)

# PDE classes accepted (as the ``pde`` argument) by the Deep Ritz
# preconditioners defined in this module.
ACCEPTED_PDE_TYPES = (
    RitzFormEllipticPDE
    | DivAGradUPDE
)


class MatrixPreconditionerDeepRitz(MatrixPreconditionerSolver):
    """Matrix-based preconditioner for Deep Ritz methods.

    Args:
        space: The approximation space.
        pde: The PDE to be solved. It must expose a callable ``energy_matrix``
            attribute.
        **kwargs: Additional keyword arguments, forwarded to
            :class:`MatrixPreconditionerSolver`.

    Raises:
        AttributeError: If the input PDE has no ``energy_matrix`` attribute, or
            if that attribute is not callable.
        NotImplementedError: When used with data loss.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)

        # The PDE must provide the energy (Gram) matrix assembly routine.
        name = "energy_matrix"
        if not hasattr(self.pde, name):
            raise AttributeError("input PDE must have an attribute %s" % name)
        self.energy_matrix = getattr(self.pde, name)
        if not callable(self.energy_matrix):
            raise AttributeError("attribute %s of input PDE must be a method" % name)

        # Functional operators for boundary / initial conditions, only built
        # when the corresponding conditions are present.
        self.operator_bc: None | FunctionalOperator = None
        if self.has_bc:
            self.operator_bc = FunctionalOperator(self.pde, "functional_operator_bc")

        self.operator_ic: None | FunctionalOperator = None
        if self.has_ic:
            self.operator_ic = FunctionalOperator(self.pde, "functional_operator_ic")

        if self.nb_dl > 0:
            raise NotImplementedError(
                "DeepRitz preconditioners + data loss is not implemented"
            )

    def vectorize_along_physical_variables(self, func: TYPE_FUNC_ARGS) -> TYPE_VMAPS:
        """Vectorize a function along the physical variables of the space.

        Only ``type_space == "space"`` is supported. In that case ``func`` is
        vmapped over dimension 0 of its first two positional arguments, while
        its third argument (the network parameters) is shared across the batch.

        Args:
            func: The function to be vectorized.

        Returns:
            The vectorized function.

        Raises:
            NotImplementedError: If the type of space is not ``"space"``.
        """
        if self.type_space == "space":
            return torch.func.vmap(func, (0, 0, None))
        if self.type_space == "phase_space":
            raise NotImplementedError("phase_space")
        raise NotImplementedError("time_space")
class EnergyNaturalGradientPreconditioner(MatrixPreconditionerDeepRitz):
    """Energy Natural Gradient preconditioner for Deep Ritz methods.

    Args:
        space: The approximation space used in the Deep Ritz method.
        pde: The elliptic PDE, an instance of :class:`RitzFormEllipticPDE` or
            :class:`DivAGradUPDE`.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If the number of unknowns in the space is greater than 1.

    Keyword Args:
        matrix_regularization (float): Regularization parameter for the
            preconditioning matrix (default: 1e-6).
        `adaptive_matrix_regularization` (:code:`bool`, default=False): If True,
            adaptively adjusts the regularization parameter during training
            (à la Levenberg-Marquardt). If False, uses the fixed regularization
            parameter specified by `matrix_regularization`.
        `adaptive_matrix_regularization_increase` (:code:`float`, default=10.0):
            Factor by which to increase the regularization parameter if the line
            search fails to find a suitable step size.
        `adaptive_matrix_regularization_decrease` (:code:`float`, default=0.5):
            Factor by which to decrease the regularization parameter if the line
            search quickly succeeds in finding a suitable step size.
        `adaptive_matrix_regularization_min` (:code:`float`, default=1e-12):
            Minimum value for the regularization parameter when using adaptive
            adjustment.
        use_lstsq (bool): Whether to use a least squares solver for computing
            the preconditioned gradient (default: True). If False, uses a direct
            linear solve.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: RitzFormEllipticPDE | DivAGradUPDE,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)

        self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6)
        self.adaptive_matrix_regularization = kwargs.get(
            "adaptive_matrix_regularization", False
        )
        self.adaptive_matrix_regularization_increase = kwargs.get(
            "adaptive_matrix_regularization_increase", 10.0
        )
        self.adaptive_matrix_regularization_decrease = kwargs.get(
            "adaptive_matrix_regularization_decrease", 0.5
        )
        self.adaptive_matrix_regularization_min = kwargs.get(
            "adaptive_matrix_regularization_min", 1e-12
        )
        # Remembered so that reset_matrix_regularization can restore it.
        self.initial_matrix_regularization = self.matrix_regularization
        self.use_lstsq = kwargs.get("use_lstsq", True)

        if self.space.nb_unknowns > 1:
            raise ValueError(
                "EnergyNaturalGradient preconditioner is only implemented for scalar "
                "problems."
            )

        # Batched map (points -> features used by energy_matrix) for the interior.
        self.vectorized_Phi = self.vectorize_along_physical_variables(
            self.eval_and_gradx_and_jactheta
        )

        # Boundary counterpart, built by applying the BC functional operator to
        # eval_func and then taking the (transposed) parameter Jacobian.
        self.vectorized_Phi_bc: None | TYPE_DICT_OF_VMAPS = None
        if self.has_bc:
            self.operator_bc = cast(FunctionalOperator, self.operator_bc)
            self.linear_Phi_bc = self.operator_bc.apply_func_to_dict_of_func(
                _transpose_i_j(-1, -2, _mjactheta),
                self.operator_bc.apply_to_func(self.eval_func),
            )
            self.vectorized_Phi_bc = self.vectorize_along_physical_variables_bc(
                self.linear_Phi_bc
            )

    def eval_and_gradx(self, *args: TYPE_ARGS):
        """Differentiate ``eval_func`` with respect to its first argument.

        Args:
            *args: Input arguments where the last argument is the parameters of
                the network.

        Returns:
            The Jacobian of ``eval_func`` with respect to ``args[0]``, computed
            with :func:`torch.func.jacrev`. Despite the method name, the function
            value itself is not returned.
        """
        return torch.func.jacrev(self.eval_func, 0)(*args)

    def eval_and_gradx_and_jactheta(self, *args: TYPE_ARGS):
        """Apply ``_mjactheta`` to :meth:`eval_and_gradx` and swap its last two dims.

        Args:
            *args: Input arguments where the last argument is the parameters of
                the network.

        Returns:
            The output of ``_mjactheta(self.eval_and_gradx, *args)`` with its last
            two dimensions transposed (presumably the parameter Jacobian of the
            spatial gradient).
        """
        return _transpose_i_j(-1, -2, _mjactheta)(self.eval_and_gradx, *args)

    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the regularized interior (energy) preconditioning matrix.

        Args:
            labels: Tensor of labels corresponding to the input data (unused).
            *args: Additional arguments required for computing the matrix; the
                first two are forwarded to ``energy_matrix``.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            ``energy_matrix(...) + matrix_regularization * I`` of size
            ``ndof x ndof``.
        """
        theta = self.get_formatted_current_theta()
        Phi = self.vectorized_Phi(*args, theta)
        M = self.energy_matrix(
            {"eval_and_gradx_and_gradtheta": Phi}, args[0], args[1]
        )
        # Build the identity with M's dtype/device so regularization neither
        # promotes M to the default dtype nor fails on a non-default device.
        return M + self.matrix_regularization * torch.eye(
            self.ndof, dtype=M.dtype, device=M.device
        )

    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the regularized boundary condition preconditioning matrix.

        Args:
            labels: Tensor of labels corresponding to the boundary condition data.
            *args: Additional arguments required for computing the matrix; the
                number of boundary points is read from ``args[0].shape[0]``.
            **kwargs: Additional keyword arguments (unused).

        Returns:
            The averaged Gram matrix of the boundary features plus
            ``matrix_regularization * I``.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        )
        # Contract over the first and last axes of Phi (presumably points and
        # output components), leaving a square matrix over the middle axis.
        M = torch.einsum("ijk,ilk->jl", Phi, Phi) / N
        return M + self.matrix_regularization * torch.eye(
            self.ndof, dtype=M.dtype, device=M.device
        )

    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix (not implemented).

        Args:
            labels: Tensor of labels corresponding to the initial condition data.
            *args: Additional arguments required for computing the matrix.
            **kwargs: Additional keyword arguments.

        Raises:
            NotImplementedError: Always; initial conditions are not supported.
        """
        raise NotImplementedError(
            "EnergyNaturalGradient preconditioner does not support initial "
            "conditions"
        )

    def compute_preconditioning_matrix_dl(
        self, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the Gram matrix for data loss (not implemented).

        Args:
            *args: Input tensors for computing the Gram matrix.
            **kwargs: Additional keyword arguments.

        Raises:
            NotImplementedError: Always; data loss is not supported.
        """
        raise NotImplementedError(
            "DeepRitz preconditioners + data loss is not implemented"
        )

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the Energy Natural Gradient preconditioner to the input gradients.

        Args:
            epoch: Current training epoch.
            data: Training data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals tuple.
            res_r: Right residuals tuple.
            **kwargs: Additional keyword arguments.

        Returns:
            The solution ``x`` of ``M x = grads``, where ``M`` is the interior
            matrix plus ``bc_weight`` times the boundary matrix (if any).
        """
        M = self.get_preconditioning_matrix(data, **kwargs)
        if self.has_bc:
            Mb = self.get_preconditioning_matrix_bc(data, **kwargs)
            # Out-of-place: do not mutate a matrix the getter may have cached.
            M = M + self.bc_weight * Mb
        # NOTE: no initial-condition term is assembled, since
        # compute_preconditioning_matrix_ic is not implemented.

        if self.use_lstsq:
            return torch.linalg.lstsq(M, grads).solution
        return torch.linalg.solve(M, grads)

    def update_matrix_regularization(self, n_steps: int, loss_has_decreased: bool):
        """Update the regularization parameter for the preconditioning matrix.

        When adaptive regularization is enabled, the parameter is increased if
        the loss did not decrease, decreased if the line search succeeded within
        one step, and is always clamped from below by
        ``adaptive_matrix_regularization_min``.

        Args:
            n_steps: The number of line search steps taken.
            loss_has_decreased: Whether the loss has decreased after the line search.
        """
        if not self.adaptive_matrix_regularization:
            return
        if not loss_has_decreased:
            # Line search failed: regularize more (Levenberg-Marquardt style).
            self.matrix_regularization *= self.adaptive_matrix_regularization_increase
        elif n_steps <= 1:
            # Line search succeeded immediately: regularize less.
            self.matrix_regularization *= self.adaptive_matrix_regularization_decrease
        self.matrix_regularization = max(
            self.matrix_regularization, self.adaptive_matrix_regularization_min
        )

    def reset_matrix_regularization(self):
        """Reset the regularization parameter to its initial value."""
        self.matrix_regularization = self.initial_matrix_regularization