Source code for scimba_torch.numerical_solvers.preconditioner_deep_ritz

"""Deep Ritz preconditioners."""

from typing import cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    # vectorize_dict_of_func,
    # TYPE_THETA,
    TYPE_ARGS,
    # TYPE_DICT_OF_FUNC_ARGS,
    TYPE_DICT_OF_VMAPS,
    TYPE_FUNC_ARGS,
    TYPE_VMAPS,
    FunctionalOperator,
)

# from scimba_torch.domain.meshless_domain.domain_2d import Square2D
# from torchquad import set_up_backend  # Necessary to enable GPU support
# from torchquad import Trapezoid, Gaussian, Simpson
from scimba_torch.numerical_solvers.preconditioner_solvers import (
    MatrixPreconditionerSolver,
    _mjactheta,
    _transpose_i_j,
)
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    RitzFormEllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
    DivAGradUPDE,
)
from scimba_torch.utils.scimba_tensors import LabelTensor

TYPE_DATA = tuple[LabelTensor, ...] | dict[str, tuple[LabelTensor, ...]]

ACCEPTED_PDE_TYPES = RitzFormEllipticPDE | DivAGradUPDE


[docs] class MatrixPreconditionerDeepRitz(MatrixPreconditionerSolver): """Matrix-based preconditioner for pinns. Args: space: The approximation space. pde: The PDE to be solved. **kwargs: Additional keyword arguments. Raises: AttributeError: If the input PDE does not have the required attributes. """ def __init__( self, space: AbstractApproxSpace, pde: ACCEPTED_PDE_TYPES, **kwargs, ): super().__init__(space, pde, **kwargs) name = "energy_matrix" if not (hasattr(self.pde, name)): raise AttributeError("input PDE must have an attribute %s" % name) assert hasattr(self.pde, name) self.energy_matrix = getattr(self.pde, name) if not callable(self.energy_matrix): raise AttributeError("attribute %s of input PDE must be a method" % name) self.operator_bc: None | FunctionalOperator = None if self.has_bc: self.operator_bc = FunctionalOperator(self.pde, "functional_operator_bc") self.operator_ic: None | FunctionalOperator = None if self.has_ic: self.operator_ic = FunctionalOperator(self.pde, "functional_operator_ic")
[docs] def vectorize_along_physical_variables(self, func: TYPE_FUNC_ARGS) -> TYPE_VMAPS: """Vectorizes a function along physical variables based on the type of space. Args: func: The function to be vectorized. Returns: The vectorized function. Raises: NotImplementedError: If the type of space is not supported. """ scheme: tuple[int | None, ...] = tuple() if self.type_space == "space": scheme = (0, 0, None) elif self.type_space == "phase_space": raise NotImplementedError("phase_space") else: # scheme = (0, 0, 0, None) raise NotImplementedError("time_space") return torch.func.vmap(func, scheme)
[docs] class EnergyNaturalGradientPreconditioner(MatrixPreconditionerDeepRitz): """Energy Natural Gradient preconditioner for Deep Ritz methods. Args: space: The approximation space used in the Deep Ritz method. pde: The elliptic PDE represented as an instance of RitzForm_Elliptic **kwargs: Additional keyword arguments. Raises: ValueError: If the number of unknowns in the space is greater than 1. """ def __init__( self, space: AbstractApproxSpace, pde: RitzFormEllipticPDE | DivAGradUPDE, # is_operator_linear: bool = False, **kwargs, ): super().__init__(space, pde, **kwargs) # self.is_operator_linear = is_operator_linear self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6) if self.space.nb_unknowns > 1: raise ValueError( "EnergyNaturalGradient preconditioner is only implemented for scalar " "problems." ) self.vectorized_Phi = self.vectorize_along_physical_variables( self.eval_and_gradx_and_jactheta ) self.vectorized_Phi_bc: None | TYPE_DICT_OF_VMAPS = None if self.has_bc: self.operator_bc = cast(FunctionalOperator, self.operator_bc) self.linear_Phi_bc = self.operator_bc.apply_func_to_dict_of_func( _transpose_i_j(-1, -2, _mjactheta), self.operator_bc.apply_to_func(self.eval_func), ) self.vectorized_Phi_bc = self.vectorize_along_physical_variables_bc( self.linear_Phi_bc )
[docs] def eval_and_gradx(self, *args: TYPE_ARGS): """Evaluate the function and compute its gradient. Args: *args: Input arguments where the last argument is the parameters of the network. Returns: A tensor containing the function evaluation and its gradient. """ return torch.func.jacrev(self.eval_func, 0)(*args)
[docs] def eval_and_gradx_and_jactheta(self, *args: TYPE_ARGS): """Evaluate the function, compute its gradient, and the Jacobian. Args: *args: Input arguments where the last argument is the parameters of the network. Returns: A tensor containing the function evaluation, its gradient, and the Jacobian """ return _transpose_i_j(-1, -2, _mjactheta)(self.eval_and_gradx, *args)
[docs] def compute_preconditioning_matrix( self, labels: torch.Tensor, *args: torch.Tensor, **kwargs ) -> torch.Tensor: """Compute the preconditioning matrix. Args: labels: Tensor of labels corresponding to the input data. *args: Additional arguments required for computing the matrix. **kwargs: Additional keyword arguments. Returns: The computed preconditioning matrix as a tensor. """ # mu = args[-1] # N = args[0].shape[0] theta = self.get_formatted_current_theta() # if isinstance(self.space.spatial_domain, Square2D): # # # set_up_backend("torch", data_type="float64") # dimx = x.shape[1] # dimmu = mu.shape[1] # # # print("here") # def func_to_integrate( xmu ): # xarg, muarg = xmu[:, 0:dimx], xmu[:, dimx:] # Phi = test_vect(xarg, muarg, theta) # # print("Phi.shape: ", Phi.shape) # Phi2 = Phi*muarg.view(-1, 1, 1) # return torch.einsum("ijl,ikl->ijk", Phi2, Phi) # # x_domain = self.space.spatial_domain.bounds # mu_domain = torch.cat( ( torch.min(mu, 0)[0][None], # torch.max(mu, 0)[0][None] ), dim = 0).transpose(-2, -1) # integration_domain = torch.cat( (x_domain, mu_domain), dim = 0 ) # integrator = Trapezoid() # test = integrator.integrate(func_to_integrate, dim=dimx+dimmu, N=N, # integration_domain=integration_domain) # # else: Phi = self.vectorized_Phi(*args, theta) # Phi2 = Phi * mu.view(-1, 1, 1) # M = torch.einsum("ijl,ikl->jk", Phi2, Phi) / N M = self.energy_matrix({"eval_and_gradx_and_gradtheta": Phi}, args[0], args[1]) return M + self.matrix_regularization * torch.eye(self.ndof)
[docs] def compute_preconditioning_matrix_bc( self, labels: torch.Tensor, *args: torch.Tensor, **kwargs ) -> torch.Tensor: """Compute the boundary condition preconditioning matrix. Args: labels: Tensor of labels corresponding to the boundary condition data. *args: Additional arguments required for computing the matrix. **kwargs: Additional keyword arguments. Returns: The computed boundary condition preconditioning matrix as a tensor. """ N = args[0].shape[0] theta = self.get_formatted_current_theta() self.operator_bc = cast(FunctionalOperator, self.operator_bc) self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc) Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors( self.vectorized_Phi_bc, theta, labels, *args ) # print("Phi.shape: ", Phi.shape) # if not self.is_operator_linear: # Phi = Phi[..., None] # Phi_test = self.operator_bc.apply_dict_of_vmap_to_LabelTensors( # self.vectorized_Phi_bc_test, theta, labels, *args )[..., None] # assert torch.allclose(Phi, Phi_test) M = torch.einsum( "ijk,ilk->jl", Phi, Phi ) / N + self.matrix_regularization * torch.eye(self.ndof) return M
[docs] def compute_preconditioning_matrix_ic( self, labels: torch.Tensor, *args: torch.Tensor, **kwargs ) -> torch.Tensor: """Compute the initial condition preconditioning matrix. Args: labels: Tensor of labels corresponding to the boundary condition data. *args: Additional arguments required for computing the matrix. **kwargs: Additional keyword arguments. Returns: NotImplementedError. """ # N = args[0].shape[0] # theta = self.get_formatted_current_theta() # # self.operator_ic = cast(FunctionalOperator, self.operator_ic) # self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, # self.vectorized_Phi_ic) # # Phi = self.operator_ic.apply_dict_of_vmap_to_LabelTensors( # self.vectorized_Phi_ic, theta, labels, *args # ) # if not self.is_operator_linear: # Phi = Phi[..., None] # # # Phi_test = self.operator_ic.apply_dict_of_vmap_to_LabelTensors( # self.vectorized_Phi_ic_test, theta, labels, *args )[..., None] # # assert torch.allclose(Phi, Phi_test) # # M = torch.einsum( # "ijk,ilk->jl", Phi, Phi # ) / N + self.matrix_regularization * torch.eye(self.ndof) # return M return NotImplementedError
# def assemble_left_member_bc(self, data: tuple | dict, res_l: tuple) # -> torch.Tensor: # # self.operator_bc = cast(FunctionalOperator, self.operator_bc) # # if len(self.operator_bc.dict_of_operators) == 1: # return res_l[0] # # args = self.get_args_for_operator_bc(data) # # return self.operator_bc.cat_tuple_of_tensors(res_l, args[0], args[1]) def __call__( self, epoch: int, data: tuple | dict, grads: torch.Tensor, res_l: tuple, res_r: tuple, **kwargs, ) -> torch.Tensor: """Apply the Energy Natural Gradient preconditioner to the input gradients. Args: epoch: Current training epoch. data: Training data, either as a tuple or a dictionary. grads: Gradient tensor to be preconditioned. res_l: Left residuals tuple. res_r: Right residuals tuple. **kwargs: Additional keyword arguments. Returns: The preconditioned gradient tensor. """ M = self.get_preconditioning_matrix(data, **kwargs) if self.has_bc: Mb = self.get_preconditioning_matrix_bc(data, **kwargs) M += self.bc_weight * Mb # # if self.has_ic: # Mi = self.get_preconditioning_matrix_ic(data, **kwargs) # M += self.ic_weight * Mi preconditioned_grads = torch.linalg.lstsq(M, grads).solution return preconditioned_grads