Source code for scimba_torch.numerical_solvers.elliptic_pde.deep_ritz

"""Deep Ritz method for solving elliptic PDEs using deep learning."""

from typing import Any

import torch
import torch.nn as nn

from scimba_torch.numerical_solvers.collocation_projector import (
    CollocationProjector,
)
from scimba_torch.numerical_solvers.preconditioner_deep_ritz import (
    # AnagramPreconditioner,
    EnergyNaturalGradientPreconditioner,
)
from scimba_torch.optimizers.losses import GenericLosses, MassLoss
from scimba_torch.optimizers.optimizers_data import OptimizerData
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    RitzFormEllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.laplacians import (
    Laplacian2DDirichletRitzForm,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
    DivAGradUPDE,
)
from scimba_torch.utils.scimba_tensors import MultiLabelTensor


[docs] class DeepRitzElliptic(CollocationProjector): """A deep learning-based solver for elliptic PDEs using the Ritz method. The class inherits from AbstractNonlinearProjector and provides methods for assembling the PDE system and computing the metric matrix. Args: pde: The elliptic PDE to be solved, represented as an instance of RitzForm_EllipticPDE. bc_type: The type of boundary condition to be applied ("strong" or "weak"). (default: "strong") **kwargs: Additional keyword arguments for losses and optimizers. """ def __init__( self, pde: RitzFormEllipticPDE | DivAGradUPDE, bc_type: str = "strong", **kwargs, ): super().__init__(pde.space, pde.linearform, **kwargs) self.pde = pde self.bc_type = bc_type bc_weight = kwargs.get("bc_weight", 10.0) default_list_losses = [("energy", MassLoss(), 1.0)] if self.bc_type == "weak": default_list_losses += [("bc", torch.nn.MSELoss(), bc_weight)] default_losses = GenericLosses(default_list_losses) self.losses = kwargs.get("losses", default_losses) opt_1 = { "name": "adam", "optimizer_args": {"lr": 1e-3, "betas": (0.9, 0.999)}, } default_opt = OptimizerData(opt_1) self.optimizer = kwargs.get("optimizers", default_opt) # self.xmu = None
[docs] def get_dof(self, flag_scope: str = "all", flag_format: str = "list"): """Gets the degrees of freedom for the solver, including those from the PDE. Args: flag_scope: Scope flag for getting degrees of freedom. (default: "all") flag_format: Format flag for the degrees of freedom ("list" or "tensor"). (default: "list") Returns: The degrees of freedom tensor or list. """ iterator_params = self.pde.space.get_dof(flag_scope, flag_format) if isinstance(self.pde, nn.Module): if flag_format == "list": iterator_params2 = list(self.pde.parameters()) iterator_params = iterator_params + iterator_params2 if flag_format == "tensor": iterator_params2 = torch.nn.utils.parameters_to_vector( self.pde.parameters() ) iterator_params = torch.cat((iterator_params, iterator_params2)) return iterator_params
[docs] def evaluate(self, x: torch.Tensor, mu: torch.Tensor) -> MultiLabelTensor: """Evaluates the solution at given spatial and parameter points. Args: x: Spatial points where the solution is evaluated. mu: Parameter points where the solution is evaluated. Returns: The evaluated solution as a MultiLabelTensor. """ return self.space.evaluate(x, mu)
[docs] def metric_matrix(self, **kwargs: Any) -> torch.Tensor: """Placeholder method for computing the metric matrix. Args: **kwargs: Additional keyword arguments. Returns: torch.Tensor: The metric matrix. Raises: NotImplementedError: This method must be implemented in subclasses. """ raise NotImplementedError( "You must implement the metric_matrix method for the DeepRitzElliptic " "class." )
[docs] def sample_all_vars(self, **kwargs: Any) -> tuple: """Samples all necessary variables for assembling the PDE system. Args: **kwargs: Additional keyword arguments including: - n_collocation: Number of collocation points for the PDE. Defaults to 1000. - n_bc_collocation: Number of collocation points for the boundary conditions. Defaults to 1000. Returns: A tuple containing sampled spatial points, parameter points, and, if applicable, boundary points and normals. """ n_collocation = kwargs.get("n_collocation", 1000) x, mu = self.space.integrator.sample(n_collocation) xmu = (x, mu) if self.bc_type == "weak": n_bc_collocation = kwargs.get("n_bc_collocation", 1000) xnbc, mubc = self.space.integrator.bc_sample(n_bc_collocation, index_bc=0) xbc, nbc = xnbc[0], xnbc[1] xmu = xmu + (xbc, nbc, mubc) return xmu
[docs] def assembly_post_sampling(self, xmu: tuple, **kwargs) -> tuple: """Assembles the PDE system after sampling all necessary variables. Args: xmu: A tuple containing sampled spatial points, parameter points, and, if applicable, boundary points and normals. **kwargs: Additional keyword arguments. Returns: A tuple containing the assembled system of equations (Lo, f). """ x, mu = xmu[0], xmu[1] w = self.space.evaluate(x, mu) Lo = self.pde.quadraticform(w, x, mu) ## L is a tuple f = self.pde.linearform(w, x, mu) ## f is a tuple if not isinstance(Lo, tuple): assert Lo.shape[1] == 1, ( "you must reward a tuple of residual (batch,1) or one residual tensor " "(batch,1)" ) Lo = (Lo,) if not isinstance(f, tuple): assert f.shape[1] == 1, ( "you must reward a tuple of residual (batch,1) or one residual tensor " "(batch,1)" ) f = (f,) if self.bc_type == "weak": xbc, nbc, mubc = xmu[2], xmu[3], xmu[4] w = self.space.evaluate(xbc, mubc) Lbc = self.pde.bc_operator(w, xbc, nbc, mubc) ## Lbc is a tuple fbc = self.pde.bc_rhs(w, xbc, nbc, mubc) ## Lbc is a tuple if not isinstance(Lbc, tuple): assert Lbc.shape[1] == 1, ( "you must reward a tuple of residual (batch,1) or one residual " "tensor (batch,1)" ) Lbc = (Lbc,) if not isinstance(fbc, tuple): assert fbc.shape[1] == 1, ( "you must reward a tuple of residual (batch,1) or one residual " "tensor (batch,1)" ) fbc = (fbc,) Lo = Lo + Lbc f = f + fbc return Lo, f
[docs] def assembly(self, **kwargs: Any) -> tuple: """Assembles the system of equations for the PDE. (and weak boundary conditions if needed). Args: **kwargs: Additional keyword arguments including: - n_collocation (int): Number of collocation points for the PDE. Defaults to 1000. - n_bc_collocation (int): Number of collocation points for the boundary conditions. Defaults to 1000. Returns: A tuple containing the assembled system of equations (Lo, f). """ xmu = self.sample_all_vars(**kwargs) return self.assembly_post_sampling(xmu, **kwargs)
[docs] class NaturalGradientDeepRitzElliptic(DeepRitzElliptic): """A Deep Ritz solver for elliptic PDEs using natural gradient descent. Args: pde: The elliptic PDE to be solved, represented as an instance of RitzForm_EllipticPDE. bc_type: The type of boundary condition to be applied ("strong" or "weak"). (default: "strong") **kwargs: Additional keyword arguments for losses, optimizers, and preconditioners. """ def __init__( self, pde: Laplacian2DDirichletRitzForm | DivAGradUPDE, bc_type: str = "strong", **kwargs, ): super().__init__(pde, bc_type, **kwargs) self.default_lr: float = kwargs.get("default_lr", 1e-2) opt_1 = { "name": "sgd", "optimizer_args": {"lr": self.default_lr}, } self.optimizer = OptimizerData(opt_1) self.bool_linesearch = True self.type_linesearch = kwargs.get("type_linesearch", "armijo") # self.type_linesearch = kwargs.get("type_linesearch", "logarithmic_grid") self.data_linesearch = kwargs.get("data_linesearch", {}) self.data_linesearch.setdefault("M", 10) self.data_linesearch.setdefault("interval", [0.0, 2.0]) self.data_linesearch.setdefault("log_basis", 2.0) self.data_linesearch.setdefault("n_step_max", 10) self.data_linesearch.setdefault( "alpha", 0.01, ) self.data_linesearch.setdefault( "beta", 0.5, ) self.bool_preconditioner = True self.preconditioner = EnergyNaturalGradientPreconditioner( pde.space, pde, # is_operator_linear=False, has_bc=(bc_type == "weak"), # in_lhs_name = "functional_quadraticform", **kwargs, ) self.nb_epoch_preconditioner_computing = 1 self.projection_data = {"nonlinear": True, "linear": False, "nb_step": 1}