"""Define the ODEPinns class, which is a subclass of CollocationProjector.
It is designed to solve ordinary differential equations (ODEs)
using physics-informed neural networks (PINNs).
"""
from abc import ABC
from typing import Any
import torch
import torch.nn as nn
from scimba_torch.numerical_solvers.abstract_projector import AbstractNonlinearProjector
from scimba_torch.numerical_solvers.collocation_projector import (
CollocationProjector,
)
from scimba_torch.numerical_solvers.elliptic_pde.pinns import (
ANA_VALUES,
ENG_VALUES,
NG_ALGO_NAME,
NNG_VALUES,
SNG_VALUES,
_check_and_format_weight_argument,
)
from scimba_torch.numerical_solvers.pinn_preconditioners import (
EnergyNaturalGradientPreconditioner,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import (
MatrixPreconditionerPinn,
)
from scimba_torch.optimizers.losses import DataLoss, GenericLosses
from scimba_torch.optimizers.optimizers_data import OptimizerData
from scimba_torch.physical_models.ode.abstract_ode import AbstractODE
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
class ODEPinns(CollocationProjector):
    """A class to solve ODEs using Physics-Informed Neural Networks.

    Args:
        ode: The ODE to be solved.
        ic_type: The type of initial condition to be applied ("strong" or "weak").
            (default: "strong")
        **kwargs: Additional keyword arguments for customization.

    Keyword Args:
        `one_loss_by_equation` (:code:`bool`): if True, one loss term is built per
            residual component (and per initial-condition component when
            ``ic_type == "weak"``) instead of a single aggregated term.
            Default: False.
        `in_weight` (:code:`float`): global weight of the residual loss(es).
            Default: 1.0.
        `ic_weight` (:code:`float`): global weight of the weak initial-condition
            loss(es). Default: 10.0.
        `in_weights`: per-equation weights of the residual. Default: 1.0.
        `ic_weights`: per-equation weights of the initial condition. Default: 1.0.
        `data_losses` (:code:`list[DataLoss]`): additional data-fitting losses.
            Default: [].
        `dl_weights` (:code:`list[float]`): one weight per data loss (ints are
            accepted and converted to float). Default: 1.0 for each data loss.
        `losses` (:code:`GenericLosses`): if given, replaces the default losses.

    Raises:
        ValueError: when the lengths of in_weights or ic_weights
            do not match residual_size or ic_residual_size, when data_losses is
            not a list of DataLoss instances, or when dl_weights is not a list
            holding one number per data loss.
    """

    def __init__(self, ode: AbstractODE, ic_type: str = "strong", **kwargs):
        super().__init__(ode.space, ode.rhs, **kwargs)
        self.ode = ode
        self.ic_type = ic_type
        self.space = ode.space
        self.one_loss_by_equation = kwargs.get("one_loss_by_equation", False)

        # in/ic balance
        self.in_weight = kwargs.get("in_weight", 1.0)
        self.ic_weight = kwargs.get("ic_weight", 10.0)

        # in case of several equations/labels, balance between equations of residual
        in_weights = kwargs.get("in_weights", 1.0)
        self.in_weights = _check_and_format_weight_argument(in_weights)
        # in case of several equations/labels, balance between equations of ic
        ic_weights = kwargs.get("ic_weights", 1.0)
        self.ic_weights = _check_and_format_weight_argument(ic_weights)

        if self.one_loss_by_equation:
            # a single weight is broadcast to every residual component
            if len(self.in_weights) == 1:
                self.in_weights = self.in_weights * self.ode.residual_size
            if not (len(self.in_weights) == self.ode.residual_size):
                raise ValueError("the length of in_weights must match residual_size")
            if self.ic_type == "weak":
                if len(self.ic_weights) == 1:
                    self.ic_weights = self.ic_weights * self.ode.ic_residual_size
                if not (len(self.ic_weights) == self.ode.ic_residual_size):
                    raise ValueError(
                        "the length of ic_weights must match ic_residual_size"
                    )

        # fold the global balance into the per-equation weights
        self.in_weights = [self.in_weight * w for w in self.in_weights]
        self.ic_weights = [self.ic_weight * w for w in self.ic_weights]

        if not self.one_loss_by_equation:
            default_list_losses = [("residual", torch.nn.MSELoss(), self.in_weights[0])]
        else:
            default_list_losses = [
                ("res " + str(i), torch.nn.MSELoss(), self.in_weights[i])
                for i in range(0, self.ode.residual_size)
            ]
        if self.ic_type == "weak":
            if not self.one_loss_by_equation:
                default_list_losses += [("ic", torch.nn.MSELoss(), self.ic_weights[0])]
            else:
                default_list_losses += [
                    ("ic " + str(i), torch.nn.MSELoss(), self.ic_weights[i])
                    for i in range(0, self.ode.ic_residual_size)
                ]

        # optional data-fitting losses
        self.data_losses = kwargs.get("data_losses", [])
        if not (
            isinstance(self.data_losses, list)
            and all(isinstance(dl, DataLoss) for dl in self.data_losses)
        ):
            raise ValueError("data_loss argument must be a list of DataLoss instances")
        self.dl_weights = kwargs.get("dl_weights", [1.0] * len(self.data_losses))
        if not (
            isinstance(self.dl_weights, list)
            # accept ints as well as floats; bool is excluded on purpose
            and all(
                isinstance(dw, (int, float)) and not isinstance(dw, bool)
                for dw in self.dl_weights
            )
            and len(self.dl_weights) == len(self.data_losses)
        ):
            raise ValueError(
                "self.dl_weights argument must be a list as many floats as data losses"
            )
        self.dl_weights = [float(dw) for dw in self.dl_weights]
        for i, (dl, dw) in enumerate(zip(self.data_losses, self.dl_weights)):
            default_list_losses += [("data " + str(i), dl.loss_function, dw)]

        default_losses = GenericLosses(default_list_losses)
        self.losses = kwargs.get("losses", default_losses)

    def get_dof(
        self, flag_scope: str = "all", flag_format: str = "list"
    ) -> torch.Tensor | list:
        """Gets the parameters of the approximation space in use.

        If the ODE itself is an ``nn.Module``, its own parameters (those whose
        name does not start with ``"space."``) are appended after the
        parameters of the approximation space.

        Args:
            flag_scope: Scope of the degrees of freedom to retrieve.
            flag_format: Format of the output, either "list" or "tensor".

        Returns:
            Degrees of freedom in the specified format.
        """
        iterator_params = self.ode.space.get_dof(flag_scope, flag_format)
        if isinstance(self.ode, nn.Module):
            # named_parameters() yields unique names, in registration order
            extra_params = [
                param
                for name, param in self.ode.named_parameters()
                if not name.startswith("space.")
            ]
            if flag_format == "list":
                iterator_params = iterator_params + extra_params
            elif flag_format == "tensor" and extra_params:
                # guard: parameters_to_vector([]) calls torch.cat([]) which raises
                extra_vector = torch.nn.utils.parameters_to_vector(extra_params)
                iterator_params = torch.cat((iterator_params, extra_vector))
        return iterator_params

    def evaluate(self, t: torch.Tensor, mu: torch.Tensor) -> MultiLabelTensor:
        """Evaluates the approximation at given points.

        Args:
            t: Input tensor for time coordinates.
            mu: Input tensor for parameters.

        Returns:
            The evaluated solution.
        """
        return self.space.evaluate(t, mu)

    def sample_all_vars(self, **kwargs: Any) -> dict[str, tuple[LabelTensor, ...]]:
        """Samples collocation points for the ODE and initial conditions.

        Args:
            **kwargs: Additional keyword arguments for sampling:
                ``n_collocation`` (inner points, default 1000) and
                ``n_ic_collocation`` (initial-condition points, default 1000,
                only used when ``ic_type == "weak"``).

        Returns:
            Dictionary of sampled tensors: ``"inner"`` -> ``(t, mu)`` and, for a
            weak initial condition, ``"ic"`` -> ``(mu,)``.
        """
        # initialize dictionary of sampled points
        tmu = {}
        # sample inner points
        n_collocation = kwargs.get("n_collocation", 1000)
        t, mu = self.space.integrator.sample(n_collocation)
        tmu["inner"] = (t, mu)
        # sample initial points, if weak IC; only the parameters are kept,
        # the time is set to 0 during assembly
        if self.ic_type == "weak":
            n_ic_collocation = kwargs.get("n_ic_collocation", 1000)
            _, muic = self.space.integrator.sample(n_ic_collocation)
            tmu["ic"] = (muic,)
        # return all sampled points
        return tmu

    def assembly_post_sampling(self, tmu: dict, **kwargs) -> tuple:
        """Assembles the system of equations post-sampling.

        Args:
            tmu: dictionary of sampled tensors.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple ``(Lo, f)`` of two tuples: the left-hand sides (time operator
            terms, then the solution at t=0 for a weak initial condition, then
            the solution at each data loss's arguments) and the matching
            right-hand sides (ODE rhs, initial values, data values).
        """
        # inner points: ode residual and rhs
        t, mu = tmu["inner"]
        w = self.space.evaluate(t, mu)
        L_time = self.ode.time_operator(w, t, mu)  # tuple or tensor
        if isinstance(L_time, tuple):
            L = tuple(L_time)
        else:
            assert isinstance(L_time, torch.Tensor), (
                "time operator must retrieve a tensor"
            )
            L = L_time
        f = self.ode.rhs(w, t, mu)  # tuple
        Lo = self.make_tuple(L)
        f = self.make_tuple(f)

        if self.ic_type == "weak":
            # ic points: initial condition, evaluated at t = 0
            (muic,) = tmu["ic"]
            # NOTE(review): tic is created with torch's default dtype/device
            # rather than muic's — confirm this matches the sampler's device.
            tic = LabelTensor(torch.zeros((muic.shape[0], 1)))
            w = self.space.evaluate(tic, muic)
            fic = self.ode.init(muic)  # tuple
            Lic = self.make_tuple(w.w)
            fic = self.make_tuple(fic)
            Lo = Lo + Lic
            f = f + fic

        # data losses: network output at the data points vs. the data values
        for dl in self.data_losses:
            Lo += (self.space.evaluate(*(dl.args)).w,)
            f += (dl.vals,)
        return Lo, f

    def assembly(self, **kwargs: Any) -> tuple:
        """Assembles the system of equations for the ODE.

        Args:
            **kwargs: Additional keyword arguments (forwarded to the sampler
                and to the post-sampling assembly).

        Returns:
            Tuple containing the assembled operator and right-hand side.
        """
        xmu = self.sample_all_vars(**kwargs)
        return self.assembly_post_sampling(xmu, **kwargs)
class PreconditionedODEPinns(ABC):
    """Mixin holding the optimizer and linesearch configuration of
    preconditioned ODE PINN solvers.

    It does not derive from ODEPinns itself; it is meant to be combined with
    it through multiple inheritance.

    Args:
        **kwargs: Additional keyword arguments for customization.

    Keyword Args:
        `default_lr` (:code:`float`): The default learning rate used when
            linesearch fails. Default : 1e-2.
        `type_linesearch` (:code:`str`): The linesearch algorithm:
            either "armijo" or "logarithmic_grid". Default: "armijo"
        `data_linesearch` (:code:`dict`): optional parameters for the linesearch.
            For logarithmic grid: "M" (nb nodes in the grid, default 10),
            "interval" (min max values of the grid, default [0.0, 2.0]),
            "log_basis" the logarithmic basis (default 2.0).
            For armijo: "n_step_max" (the max number of steps, default 10),
            "alpha" and "beta" (the alpha and beta parameters, defaults
            0.01 and 0.5).
            The dict given by the caller is copied, not modified; ``None`` is
            treated as an empty dict.
    """

    def __init__(self, **kwargs: Any):
        self.default_lr: float = kwargs.get("default_lr", 1e-2)
        opt_1 = {
            "name": "sgd",
            "optimizer_args": {"lr": self.default_lr},
        }
        self.optimizer = OptimizerData(opt_1)

        self.bool_linesearch = True
        self.type_linesearch = kwargs.get("type_linesearch", "armijo")
        # copy so that filling in the defaults below never mutates a dict
        # owned by the caller (which may be shared between solvers)
        self.data_linesearch = dict(kwargs.get("data_linesearch") or {})
        self.data_linesearch.setdefault("M", 10)
        self.data_linesearch.setdefault("interval", [0.0, 2.0])
        self.data_linesearch.setdefault("log_basis", 2.0)
        self.data_linesearch.setdefault("n_step_max", 10)
        self.data_linesearch.setdefault("alpha", 0.01)
        self.data_linesearch.setdefault("beta", 0.5)

        self.bool_preconditioner = True
        self.nb_epoch_preconditioner_computing = 1
        self.projection_data = {"nonlinear": True, "linear": False, "nb_step": 1}
class NaturalGradientODEPinns(ODEPinns, PreconditionedODEPinns):
    """ODE PINN solver whose optimization is preconditioned by a natural gradient.

    Args:
        ode: The ODE to be solved.
        ic_type: Type of initial condition ("strong" or "weak").
            Defaults to "strong".
        **kwargs: Additional keyword arguments, forwarded to the ODEPinns and
            PreconditionedODEPinns initializers and to the preconditioner.

    Keyword Args:
        `ng_algo` (:code:`str`): The algorithm for computing the natural gradient
            preconditioning matrix. Default : "ENG". Only the energy natural
            gradient family is currently available for ODEs.

    Raises:
        ValueError: value for ng_algo keyword argument is not correct.
        NotImplementedError: value for ng_algo keyword argument is not implemented.
    """

    def __init__(self, ode: AbstractODE, ic_type: str = "strong", **kwargs):
        # ODEPinns part: losses, weights, data losses
        super().__init__(ode, ic_type, **kwargs)
        # PreconditionedODEPinns part: resolved through the MRO by starting the
        # lookup after AbstractNonlinearProjector
        super(AbstractNonlinearProjector, self).__init__(**kwargs)

        algo = kwargs.get(NG_ALGO_NAME, "ENG")

        if algo not in ENG_VALUES:
            # recognized algorithms that have no ODE implementation yet,
            # checked in this order
            not_yet_available = (
                (ANA_VALUES, "AnagramPreconditioner"),
                (SNG_VALUES, "SketchyNaturalGradientPreconditioner"),
                (NNG_VALUES, "NystromNaturalGradientPreconditioner"),
            )
            for accepted_names, preconditioner_name in not_yet_available:
                if algo in accepted_names:
                    raise NotImplementedError(
                        f"{preconditioner_name} is not implemented yet for ODEs"
                    )
            template = (
                'value "%s" for optional argument "%s" is not accepted; '
                "possible values are: "
                '"ENG" or "EnergyNaturalGradient", '
                '"ANaGRAM", '
                '"NNG" or "NyströmNaturalGradient", '
                '"SNG" or "SketchyNaturalGradient".'
            )
            raise ValueError(template % (algo, NG_ALGO_NAME))

        # energy natural gradient: no boundary conditions for an ODE, initial
        # condition only when it is imposed weakly
        self.preconditioner: MatrixPreconditionerPinn = (
            EnergyNaturalGradientPreconditioner(
                ode.space,
                ode,
                has_bc=False,
                has_ic=ic_type == "weak",
                args_for_dl=[dl.args for dl in self.data_losses],
                **kwargs,
            )
        )