"""This module defines the PinnsElliptic class.
PinnsElliptic is a subclass of NonlinearGalerkinProjector designed to solve elliptic
partial differential equations (PDEs) using physics-informed neural networks (PINNs).
"""
import warnings
from abc import ABC
from typing import TYPE_CHECKING, Any, cast
import torch
import torch.nn as nn
from scimba_torch.numerical_solvers.abstract_projector import AbstractNonlinearProjector
from scimba_torch.numerical_solvers.collocation_projector import (
CollocationProjector,
)
from scimba_torch.numerical_solvers.pinn_preconditioners import (
AnagramPreconditioner,
EnergyNaturalGradientPreconditioner,
NystromNaturalGradientPreconditioner,
SketchyNaturalGradientPreconditioner,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import (
MatrixPreconditionerPinn,
_is_type_dict_of_weight,
)
from scimba_torch.optimizers.losses import DataLoss, GenericLosses
from scimba_torch.optimizers.optimizers_data import OptimizerData
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
StrongFormEllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
LinearOrder2PDE,
)
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
NG_ALGO_NAME: str = "ng_algo"
ENG_VALUES: list[str] = ["ENG", "energy", "EnergyNaturalGradient"]
ANA_VALUES: list[str] = ["ANaGRAM", "anagram", "ANAGRAM"]
SNG_VALUES: list[str] = ["SNG", "sketchy", "Sketchy", "SketchyNaturalGradient"]
NNG_VALUES: list[str] = [
"NNG",
"nystrom",
"Nystrom",
"NystromNaturalGradient",
"nyström",
"Nyström",
"NyströmNaturalGradient",
]
def _check_and_format_weight_argument(weight: Any) -> list[float]:
"""Format weight argument.
Args:
weight: the weight argument.
Returns:
the formatted weight argument.
Raises:
TypeError: the weight argument has incorrect type
"""
if isinstance(weight, float | int):
res = [float(weight)]
elif isinstance(weight, list) and all(isinstance(wk, float | int) for wk in weight):
res = [float(w) for w in weight]
elif _is_type_dict_of_weight(weight):
res = []
for key in weight:
if isinstance(weight[key], float | int):
res += [float(weight[key])]
elif isinstance(weight[key], list):
res += [float(w) for w in weight[key]]
else:
raise TypeError(
"weight argument must be of type float | int, "
"list[float | int], or "
"OrderedDict[int, float | int | list[float | int]]"
)
return res
[docs]
class PinnsElliptic(CollocationProjector):
"""A class to solve elliptic PDEs using Physics-Informed Neural Networks (PINNs).
Args:
pde: The elliptic PDE to be solved, represented as an instance of
StrongFormEllipticPDE or LinearOrder2PDE.
bc_type: The type of boundary condition to be applied ("strong" or "weak").
(default: "strong")
**kwargs: Additional keyword arguments for losses and optimizers.
Raises:
ValueError: when the lengths of in_weights or bc_weights does not match
residual_size or bc_residual_size
when data_losses and data_weights arguments are not correct
"""
def __init__(
self,
pde: StrongFormEllipticPDE | LinearOrder2PDE,
bc_type: str = "strong",
**kwargs,
):
super().__init__(pde.space, pde.rhs, **kwargs)
self.pde = pde
self.bc_type = bc_type
self.space = pde.space
self.one_loss_by_equation = kwargs.get("one_loss_by_equation", False)
# in/bc balance
self.in_weight = kwargs.get("in_weight", 1.0)
self.bc_weight = kwargs.get("bc_weight", 10.0)
# in case of several equations/labels, balance between equations of residual
in_weights = kwargs.get("in_weights", 1.0)
self.in_weights = _check_and_format_weight_argument(in_weights)
# in case of several equations/labels, balance between equations of bc
bc_weights = kwargs.get("bc_weights", 1.0)
self.bc_weights = _check_and_format_weight_argument(bc_weights)
if self.one_loss_by_equation:
if len(self.in_weights) == 1:
self.in_weights = self.in_weights * self.pde.residual_size
if not (len(self.in_weights) == self.pde.residual_size):
raise ValueError("the length of in_weights must match residual_size")
if self.bc_type == "weak":
if len(self.bc_weights) == 1:
self.bc_weights = self.bc_weights * self.pde.bc_residual_size
if not (len(self.bc_weights) == self.pde.bc_residual_size):
raise ValueError(
"the length of bc_weights must match bc_residual_size"
)
self.in_weights = [self.in_weight * w for w in self.in_weights]
self.bc_weights = [self.bc_weight * w for w in self.bc_weights]
if not self.one_loss_by_equation:
default_list_losses = [("residual", torch.nn.MSELoss(), self.in_weights[0])]
else:
default_list_losses = [
("res " + str(i), torch.nn.MSELoss(), self.in_weights[i])
for i in range(0, self.pde.residual_size)
]
if self.bc_type == "weak":
if not self.one_loss_by_equation:
default_list_losses += [("bc", torch.nn.MSELoss(), self.bc_weights[0])]
else:
default_list_losses += [
("bc " + str(i), torch.nn.MSELoss(), self.bc_weights[i])
for i in range(0, self.pde.bc_residual_size)
]
self.data_losses = kwargs.get("data_losses", [])
if not (
isinstance(self.data_losses, list)
and all(isinstance(dl, DataLoss) for dl in self.data_losses)
):
raise ValueError("data_loss argument must be a list of DataLoss instances")
self.dl_weights = kwargs.get("dl_weights", [1.0] * len(self.data_losses))
if not (
isinstance(self.dl_weights, list)
and all(isinstance(dw, float) for dw in self.dl_weights)
and len(self.dl_weights) == len(self.data_losses)
):
raise ValueError(
"self.dl_weights argument must be a list as many floats as data losses"
)
for i, dl in enumerate(self.data_losses):
default_list_losses += [
("data " + str(i), dl.loss_function, self.dl_weights[i])
]
default_losses = GenericLosses(default_list_losses)
self.losses = kwargs.get("losses", default_losses)
[docs]
def get_dof(
self, flag_scope: str = "all", flag_format: str = "list"
) -> torch.Tensor | list:
"""Gets the parameters of the approximation space in use.
Args:
flag_scope: Scope of the degrees of freedom to retrieve.
flag_format: Format of the output, either "list" or "tensor".
Returns:
torch.Tensor | list: Degrees of freedom in the specified format.
"""
iterator_params = self.pde.space.get_dof(flag_scope, flag_format)
if isinstance(self.pde, nn.Module):
dict_param_withoutspace = {
name: param
for name, param in self.pde.named_parameters()
if not name.startswith("space.")
}
if flag_format == "list":
assert isinstance(iterator_params, list)
iterator_params = iterator_params + list(
dict_param_withoutspace.values()
)
if flag_format == "tensor":
iterator_params2 = torch.nn.utils.parameters_to_vector(
list(dict_param_withoutspace.values())
)
iterator_params = torch.cat(
(cast(torch.Tensor, iterator_params), iterator_params2)
)
return iterator_params
[docs]
def evaluate(self, x: LabelTensor, mu: LabelTensor) -> MultiLabelTensor:
"""Evaluates the approximation at given points.
Args:
x: Input tensor for spatial coordinates.
mu: Input tensor for parameters.
Returns:
The evaluated solution.
"""
return self.space.evaluate(x, mu)
[docs]
def sample_all_vars(self, **kwargs: Any) -> tuple[LabelTensor, ...]:
"""Samples collocation points for the PDE and boundary conditions.
Args:
**kwargs: Additional keyword arguments for sampling.
Returns:
tuple[LabelTensor, ...]: tuple of sampled tensors.
"""
n_collocation = kwargs.get("n_collocation", 1000)
x, mu = self.space.integrator.sample(n_collocation)
xmu: tuple[LabelTensor, ...] = (x, mu)
if self.bc_type == "weak":
n_bc_collocation = kwargs.get("n_bc_collocation", 1000)
xnbc, mubc = self.space.integrator.bc_sample(n_bc_collocation, index_bc=0)
xbc, nbc = xnbc[0], xnbc[1]
if TYPE_CHECKING: # pragma: no cover
xbc = cast(LabelTensor, xbc)
nbc = cast(LabelTensor, nbc)
mubc = cast(LabelTensor, mubc)
xmu = xmu + (xbc, nbc, mubc)
return xmu
[docs]
def assembly_post_sampling(
self, data: tuple[LabelTensor, ...], **kwargs
) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
"""Assembles the system of equations post-sampling.
Args:
data: tuple of sampled tensors.
**kwargs: Additional keyword arguments.
Returns:
tuple: tuple containing the assembled operator and right-hand side.
"""
x, mu = data[0], data[1]
w = self.space.evaluate(x, mu)
Lot = self.pde.operator(w, x, mu)
Ft = self.pde.rhs(w, x, mu)
## Lo is a tuple of tensors or of dict in case of different labels
## F is a tuple of tensors or of dict in case of different labels
errormessage = (
"you must reward one residual tensor (batch,1) or a dict of "
"tensors (batch,1) or a tuple of those"
)
Lores: tuple[torch.Tensor, ...] = tuple()
Fres: tuple[torch.Tensor, ...] = tuple()
Lo = (Lot,) if not isinstance(Lot, tuple) else Lot
F = (Ft,) if not isinstance(Ft, tuple) else Ft
assert len(Lo) == len(F), "lhs and rhs must have the same number of components"
for lo, f in zip(Lo, F):
if isinstance(lo, dict):
assert isinstance(f, dict)
for key in lo:
assert key in f
assert lo[key].shape[1] == 1, errormessage
assert f[key].shape[1] == 1, errormessage
Lores = Lores + (lo[key],)
Fres = Fres + (f[key],)
else: # lo, f, are tensors
assert isinstance(lo, torch.Tensor) and isinstance(f, torch.Tensor)
assert lo.shape[1] == 1, errormessage
assert f.shape[1] == 1, errormessage
Lores = Lores + (lo,)
Fres = Fres + (f,)
if self.bc_type == "weak":
xbc, nbc, mubc = data[2], data[3], data[4]
w = self.space.evaluate(xbc, mubc)
Lbct = self.pde.bc_operator(w, xbc, nbc, mubc) ## Lbc is a tuple
Fbct = self.pde.bc_rhs(w, xbc, nbc, mubc) ## Lbc is a tuple
Lbc = (Lbct,) if not isinstance(Lbct, tuple) else Lbct
Fbc = (Fbct,) if not isinstance(Fbct, tuple) else Fbct
assert len(Lbc) == len(Fbc), (
"lhs and rhs must have the same number of components"
)
for lod, fd in zip(Lbc, Fbc):
if isinstance(lod, dict):
assert isinstance(fd, dict)
for key in lod:
assert key in fd
assert lod[key].shape[1] == 1, errormessage
assert fd[key].shape[1] == 1, errormessage
Lores = Lores + (lod[key],)
Fres = Fres + (fd[key],)
else: # lo, f, are tensors
assert isinstance(lo, torch.Tensor) and isinstance(fd, torch.Tensor)
assert lod.shape[1] == 1, errormessage
assert fd.shape[1] == 1, errormessage
Lores = Lores + (lod,)
Fres = Fres + (fd,)
for dl in self.data_losses:
Lores += (self.space.evaluate(*(dl.args)).w,)
Fres += (dl.vals,)
return Lores, Fres
[docs]
def assembly(
self, **kwargs: Any
) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
"""Assembles the system of equations for the PDE.
Assembles also weak boundary conditions if needed.
Args:
**kwargs: Additional keyword arguments.
Returns:
tuple containing the assembled operator and right-hand side.
"""
xmu = self.sample_all_vars(**kwargs)
return self.assembly_post_sampling(xmu, **kwargs)
[docs]
class PreconditionedPinnsElliptic(ABC):
"""A class extending PinnsElliptic with preconditioning.
Args:
**kwargs: Additional keyword arguments.
Keyword Args:
`default_lr` (:code:`float`): The default learning rate used when
linesearch fails. Default : 1e-2.
`type_linesearch` (:code:`str`): The linesearch algorithm:
either "armijo" or "logarithmic_grid". Default: "armijo"
`data_linesearch` (:code:`dict`): optional parameters for the linesearch.
For logarithmic grid: "m" (nb nodes in the grid),
"interval" (min max values of the grid),
"log_basis" the logarithmic basis.
For armijo: "n_step_max" (the max number of steps),
"alpha" and "beta" (the alpha and beta parameters).
"""
def __init__(self, **kwargs: Any):
self.default_lr: float = kwargs.get("default_lr", 1e-2)
opt_1 = {
"name": "sgd",
"optimizer_args": {"lr": self.default_lr},
}
self.optimizer = OptimizerData(opt_1)
self.bool_linesearch = True
self.type_linesearch = kwargs.get("type_linesearch", "armijo")
self.data_linesearch = kwargs.get("data_linesearch", {})
self.data_linesearch.setdefault("m", 10)
self.data_linesearch.setdefault("interval", [0.0, 2.0])
self.data_linesearch.setdefault("log_basis", 2.0)
self.data_linesearch.setdefault("n_step_max", 10)
self.data_linesearch.setdefault("alpha", 0.01)
self.data_linesearch.setdefault("beta", 0.5)
self.bool_preconditioner = True
self.nb_epoch_preconditioner_computing = 1
self.projection_data = {"nonlinear": True, "linear": False, "nb_step": 1}
[docs]
class NaturalGradientPinnsElliptic(PinnsElliptic, PreconditionedPinnsElliptic):
"""A class extending PinnsElliptic with natural gradient preconditioning.
Args:
pde: The elliptic PDE to be solved, represented as an instance of
StrongFormEllipticPDE or LinearOrder2PDE.
bc_type: The type of boundary condition to be applied ("strong" or "weak").
(default: "strong")
**kwargs: Additional keyword arguments.
Keyword Args:
`ng_algo` (:code:`str`): The algorithm for computing the natural gradient
preconditioning matrix. Default : "ENG".
Raises:
ValueError: value for algo keyword argument is not correct.
"""
def __init__(
self,
pde: StrongFormEllipticPDE | LinearOrder2PDE,
bc_type: str = "strong",
**kwargs,
):
# first initialize the PinnsElliptic part
super().__init__(pde, bc_type, **kwargs)
# then initialize the PreconditionedPinnsElliptic part
super(AbstractNonlinearProjector, self).__init__(**kwargs)
default_algo = "ENG"
algo = kwargs.get(NG_ALGO_NAME, default_algo)
# finally initialize the preconditioner
def preconditioner_factory(classname: type):
return classname(
pde.space,
pde,
has_bc=(bc_type == "weak"),
args_for_dl=[dl.args for dl in self.data_losses],
**kwargs,
)
if algo in ENG_VALUES:
self.preconditioner: MatrixPreconditionerPinn = preconditioner_factory(
EnergyNaturalGradientPreconditioner
)
elif algo in ANA_VALUES:
self.preconditioner = preconditioner_factory(AnagramPreconditioner)
elif algo in SNG_VALUES:
self.preconditioner = preconditioner_factory(
SketchyNaturalGradientPreconditioner
)
elif algo in NNG_VALUES:
self.preconditioner = preconditioner_factory(
NystromNaturalGradientPreconditioner
)
else:
raise ValueError(
'value "%s" for optional argument "%s" is not accepted; '
"possible values are: "
'"ENG" or "EnergyNaturalGradient", '
'"ANaGRAM", '
'"NNG" or "NyströmNaturalGradient", '
'"SNG" or "SketchyNaturalGradient".' % (algo, NG_ALGO_NAME)
)
[docs]
class AnagramPinnsElliptic(PinnsElliptic, PreconditionedPinnsElliptic):
"""A class extending PinnsElliptic with Anagram preconditioning.
Args:
pde: The elliptic PDE to be solved, represented as an instance of
StrongFormEllipticPDE or LinearOrder2PDE.
bc_type: The type of boundary condition to be applied ("strong" or "weak").
(default: "strong")
**kwargs: Additional keyword arguments for customization.
"""
def __init__(
self,
pde: StrongFormEllipticPDE | LinearOrder2PDE,
bc_type: str = "strong",
**kwargs,
):
warnings.warn(
"class %s will be deprecated in future versions; "
"please use class NaturalGradientPinnsElliptic"
'with keyword argument %s = "ANaGRAM" '
"instead" % (self.__class__.__name__, NG_ALGO_NAME),
FutureWarning,
)
# first initialize the PinnsElliptic part
super().__init__(pde, bc_type, **kwargs)
# then initialize the PreconditionedPinnsElliptic part
super(AbstractNonlinearProjector, self).__init__(**kwargs)
# finally initialize the preconditioner
self.preconditioner = AnagramPreconditioner(
pde.space,
pde,
has_bc=(bc_type == "weak"),
args_for_dl=[dl.args for dl in self.data_losses],
**kwargs,
)
[docs]
class SketchyNaturalGradientPinnsElliptic(PinnsElliptic, PreconditionedPinnsElliptic):
"""A class extending PinnsElliptic with sketchy natural gradient preconditioning.
Args:
pde: The elliptic PDE to be solved, represented as an instance of
StrongFormEllipticPDE or LinearOrder2PDE.
bc_type: The type of boundary condition to be applied ("strong" or "weak").
(default: "strong")
**kwargs: Additional keyword arguments for customization.
"""
def __init__(
self,
pde: StrongFormEllipticPDE | LinearOrder2PDE,
bc_type: str = "strong",
**kwargs,
):
warnings.warn(
"class %s will be deprecated in future versions; "
"please use class NaturalGradientPinnsElliptic"
'with keyword argument %s = "SketchyNaturalGradient" '
"instead" % (self.__class__.__name__, NG_ALGO_NAME),
FutureWarning,
)
# first initialize the PinnsElliptic part
super().__init__(pde, bc_type, **kwargs)
# then initialize the PreconditionedPinnsElliptic part
super(AbstractNonlinearProjector, self).__init__(**kwargs)
# finally initialize the preconditioner
self.preconditioner = SketchyNaturalGradientPreconditioner(
pde.space,
pde,
has_bc=(bc_type == "weak"),
args_for_dl=[dl.args for dl in self.data_losses],
**kwargs,
)
[docs]
class NystromNaturalGradientPinnsElliptic(PinnsElliptic, PreconditionedPinnsElliptic):
"""A class extending PinnsElliptic with natural gradient preconditioning.
Args:
pde: The elliptic PDE to be solved, represented as an instance of
StrongFormEllipticPDE or LinearOrder2PDE.
bc_type: The type of boundary condition to be applied ("strong" or "weak").
(default: "strong")
**kwargs: Additional keyword arguments for customization.
"""
def __init__(
self,
pde: StrongFormEllipticPDE | LinearOrder2PDE,
bc_type: str = "strong",
**kwargs,
):
warnings.warn(
"class %s will be deprecated in future versions; "
"please use class NaturalGradientPinnsElliptic"
'with keyword argument %s = "NyströmNaturalGradient" '
"instead" % (self.__class__.__name__, NG_ALGO_NAME),
FutureWarning,
)
# first initialize the PinnsElliptic part
super().__init__(pde, bc_type, **kwargs)
# then initialize the PreconditionedPinnsElliptic part
super(AbstractNonlinearProjector, self).__init__(**kwargs)
# finally initialize the preconditioner
self.preconditioner = NystromNaturalGradientPreconditioner(
pde.space,
pde,
has_bc=(bc_type == "weak"),
args_for_dl=[dl.args for dl in self.data_losses],
**kwargs,
)