"""Time-discrete numerical solvers for temporal PDEs."""
import copy
from abc import abstractmethod
from typing import Any, Callable, cast
import numpy as np
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.collocation_projector import (
CollocationProjector,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import (
EnergyNaturalGradientPreconditioner,
)
from scimba_torch.numerical_solvers.preconditioner_projector import (
AnagramPreconditionerProjector,
EnergyNaturalGradientPreconditionerProjector,
)
from scimba_torch.optimizers.losses import GenericLosses
from scimba_torch.optimizers.optimizers_data import OptimizerData
from scimba_torch.physical_models.kinetic_pde.abstract_kinetic_pde import KineticPDE
from scimba_torch.physical_models.temporal_pde.abstract_temporal_pde import TemporalPDE
from scimba_torch.utils.scimba_tensors import LabelTensor
[docs]
class TimeDiscreteCollocationProjector(CollocationProjector):
"""Implement the Galerkin-based nonlinear projection method.
This subclass implements the `assembly` method to assemble the input and output
tensors
for a specific nonlinear projection problem using the Galerkin approach.
It computes the approximation of a nonlinear problem by sampling collocation points
and evaluating the corresponding function values.
Args:
pde: The PDE model to solve.
**kwargs: Additional parameters for the projection, including collocation points
and losses.
"""
def __init__(
self,
pde: TemporalPDE | KineticPDE,
**kwargs,
):
self.pde = pde
self.rhs: Callable = pde.rhs
self.bc_rhs: Callable = pde.bc_rhs # dummy assignment
self.bool_weak_bc = kwargs.get("bool_weak_bc", False)
self.bc_weight = kwargs.get("bc_weight", 10.0)
def identity_lhs(space):
def lhs(*args, with_last_layer: bool):
u = space.evaluate(*args, with_last_layer=with_last_layer)
return u.w
return lhs
self.lhs = identity_lhs(pde.space)
super().__init__(pde.space, **kwargs)
if self.type_projection == "L1":
default_losses = (
GenericLosses([("L1", torch.nn.L1Loss(), 1.0)])
if (not self.bool_weak_bc)
else GenericLosses(
[
("L1", torch.nn.L1Loss(), 1.0),
("bc", torch.nn.L1Loss(), self.bc_weight),
]
)
)
elif self.type_projection == "H1":
default_losses = (
GenericLosses(
[
("L2", torch.nn.MSELoss(), 1.0),
("L2 grad", torch.nn.MSELoss(), 0.1),
],
)
if (not self.bool_weak_bc)
else GenericLosses(
[
("L2", torch.nn.MSELoss(), 1.0),
("L2 grad", torch.nn.MSELoss(), 0.1),
("L2 bc", torch.nn.MSELoss(), self.bc_weight),
("L2 grad bc", torch.nn.MSELoss(), 0.1 * self.bc_weight),
]
)
)
else: # Default is L2
default_losses = (
GenericLosses([("L2", torch.nn.MSELoss(), 1.0)])
if (not self.bool_weak_bc)
else GenericLosses(
[
("L2", torch.nn.MSELoss(), 1.0),
("L2 bc", torch.nn.MSELoss(), self.bc_weight),
]
)
)
self.losses = kwargs.get("losses", default_losses)
[docs]
def sample_all_vars(self, **kwargs: Any) -> tuple[LabelTensor, ...]:
"""Samples all variables required for the projection.
Include the collocation and boundary points.
Args:
**kwargs: Additional keyword arguments.
Returns:
A tuple containing sampled collocation points and boundary data.
Raises:
ValueError: If the approximation space type is not recognized.
"""
n_collocation = kwargs.get("n_collocation", 1000)
# if self.space.type_space == "space":
# x, mu = self.space.integrator.sample(n_collocation)
# data = (x, mu)
# elif self.space.type_space == "phase_space":
# x, v, mu = self.space.integrator.sample(n_collocation)
# data = (x, v, mu)
# else: # REMI: ??? should never happen?
# t, x, mu = self.space.integrator.sample(n_collocation)
# data = (t, x, mu)
data = tuple(self.space.integrator.sample(n_collocation))
if self.bool_weak_bc:
n_bc_collocation = kwargs.get("n_bc_collocation", 1000)
if self.space.type_space == "space":
xnbc, mubc = self.space.integrator.bc_sample(
n_bc_collocation, index_bc=0
)
xbc, nbc = xnbc[0], xnbc[1]
mubc = cast(LabelTensor, mubc) # for the static typechecker...
data = data + (xbc, nbc, mubc)
elif self.space.type_space == "phase_space":
# raise NotImplementedError("phase_space")
xnbc, vbc, mubc = self.space.integrator.bc_sample(
n_bc_collocation, index_bc=0
)
xbc, nbc = xnbc[0], xnbc[1]
mubc = cast(LabelTensor, mubc) # for the static typechecker...
data = data + (xbc, vbc, nbc, mubc)
else:
raise ValueError("space should be of type space or phase_space")
# # raise NotImplementedError("time_space")
# # REMI: ??? should never happen?
# tbc, xnbc, mubc = self.space.integrator.bc_sample(
# n_bc_collocation, index_bc=1
# )
# xbc, nbc = xnbc[0], xnbc[1]
# tbc = cast(LabelTensor, tbc) # for the static typechecker...
# mubc = cast(LabelTensor, mubc) # for the static typechecker...
# data = data + (tbc, xbc, nbc, mubc)
return data
[docs]
def assembly_post_sampling(
self, data: tuple[LabelTensor, ...], **kwargs
) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
"""Assembles the I/Otensors for the nonlinear Galerkin projection problem.
This method samples collocation points and evaluates the corresponding function
values from the approximation space and the right-hand side of the problem.
It then returns the tensors representing the inputs and outputs of the
projection.
Args:
data: tuple containing sampled collocation points and boundary data.
**kwargs: Additional keyword arguments.
Returns:
A tuple of tensors representing the inputs (u) and outputs (f) of the
projection problem.
Raises:
ValueError: If the approximation space type is not recognized.
"""
flag_scope = kwargs.get("flag_scope", "all")
with_last_layer = True
if flag_scope == "expect_last_layer":
with_last_layer = False
if self.space.type_space == "space":
args = [data[0], data[1]]
else:
args = [data[0], data[1], data[2]]
# u = self.space.evaluate(
# *args, with_last_layer=with_last_layer
# ) # u is a multilabelTensor
lhs = self.lhs(*args, with_last_layer=with_last_layer)
f = self.rhs(*args) # f is a Tensor
left: tuple[torch.Tensor, ...] = (lhs,)
right: tuple[torch.Tensor, ...] = (f,)
if self.bool_weak_bc:
if self.space.type_space == "space":
# TODO give the good time vector? 0 is OK as far as operator does not
# depend on t
t_ = LabelTensor(0.0 * torch.ones((data[2].shape[0], 1)))
args_for_space_evaluate = [data[2], data[4]] # do not need the normals
args_for_bc_operator = [
t_,
data[2],
data[3],
data[4],
] # need the normals
args_for_bc_rhs_evaluate = [
data[2],
data[3],
data[4],
] # need the normals
elif self.space.type_space == "phase_space":
# raise NotImplementedError("phase_space")
# TODO give the good time vector? 0 is OK as far as operator does not
# depend on t
t_ = LabelTensor(0.0 * torch.ones((data[3].shape[0], 1)))
args_for_space_evaluate = [
data[3],
data[4],
data[6],
] # do not need the normals
args_for_bc_operator = [
t_,
data[3],
data[4],
data[5],
data[6],
] # need the normals
args_for_bc_rhs_evaluate = [
data[3],
data[4],
data[5],
data[6],
] # need the normals
else:
raise ValueError("space should be of type space or phase_space")
ub = self.space.evaluate(
*args_for_space_evaluate, with_last_layer=with_last_layer
) # u is a multilabelTensor
Lbc: torch.Tensor = self.pde.bc_operator(ub, *args_for_bc_operator)
# REMI: consider adding the following, but this puts mess in type checking
# if not with_last_layer: #Lbc is a tuple of b tensors (where b is the size
# of the last hidden layer),
# #should be a tensor
# Lbc =torch.concatenate( Lbc, dim=-1)
fb = self.bc_rhs(*args_for_bc_rhs_evaluate) # f is a Tensor
left = (
lhs,
Lbc,
)
right = (
f,
fb,
)
return left, right
[docs]
class TimeDiscreteNaturalGradientProjector(TimeDiscreteCollocationProjector):
"""A time-discrete natural gradient projector for solving temporal PDEs.
Args:
pde: The PDE model to solve.
**kwargs: Additional parameters for the projection, including collocation points
and losses.
"""
def __init__(
self,
pde: TemporalPDE | KineticPDE,
**kwargs,
):
super().__init__(pde, **kwargs)
self.default_lr: float = kwargs.get("default_lr", 1e-2)
opt_1 = {
"name": "sgd",
"optimizer_args": {"lr": self.default_lr},
}
self.optimizer = OptimizerData(opt_1)
self.bool_linesearch = True
self.bool_preconditioner = True
self.nb_epoch_preconditioner_computing = 1
self.type_linesearch = kwargs.get("type_linesearch", "armijo")
self.projection_data = {"nonlinear": True, "linear": False, "nb_step": 1}
self.preconditioner = EnergyNaturalGradientPreconditionerProjector(
pde.space, has_bc=self.bool_weak_bc, **kwargs
)
self.data_linesearch = kwargs.get("data_linesearch", {})
self.data_linesearch.setdefault("M", 10)
self.data_linesearch.setdefault("interval", [0.0, 2.0])
self.data_linesearch.setdefault("log_basis", 2.0)
self.data_linesearch.setdefault("n_step_max", 10)
self.data_linesearch.setdefault(
"alpha",
0.01,
)
self.data_linesearch.setdefault(
"beta",
0.5,
)
[docs]
class TimeDiscreteImplicitNaturalGradientProjector(TimeDiscreteCollocationProjector):
"""A time-discrete natural gradient for solving temporal PDEs.
Args:
pde: The PDE model to solve.
bc_type: The way the boundary condition is handled;
"strong" for strongly, "weak" for weakly
**kwargs: Additional parameters for the projection, including collocation points
and losses.
"""
def __init__(
self,
pde: TemporalPDE | KineticPDE,
bc_type: str = "strong",
**kwargs,
):
super().__init__(pde, **kwargs)
self.default_lr: float = kwargs.get("default_lr", 1e-2)
opt_1 = {
"name": "sgd",
"optimizer_args": {"lr": self.default_lr},
}
self.optimizer = OptimizerData(opt_1)
self.bool_linesearch = True
self.bool_preconditioner = True
self.nb_epoch_preconditioner_computing = 1
self.type_linesearch = kwargs.get("type_linesearch", "armijo")
self.projection_data = {"nonlinear": True, "linear": False, "nb_step": 1}
self.preconditioner = EnergyNaturalGradientPreconditioner(
pde.space,
pde,
is_operator_linear=pde.linear,
has_bc=(bc_type == "weak"),
**kwargs,
)
self.data_linesearch = kwargs.get("data_linesearch", {})
self.data_linesearch.setdefault("M", 10)
self.data_linesearch.setdefault("interval", [0.0, 2.0])
self.data_linesearch.setdefault("log_basis", 2.0)
self.data_linesearch.setdefault("nbMaxSteps", 10)
self.data_linesearch.setdefault(
"alpha",
0.01,
)
self.data_linesearch.setdefault(
"beta",
0.5,
)
[docs]
class TimeDiscreteAnagramProjector(TimeDiscreteCollocationProjector):
"""A time-discrete natural gradient projector for solving temporal PDEs.
Args:
pde: The PDE model to solve.
**kwargs: Additional parameters for the projection, including collocation points
and losses.
"""
def __init__(
self,
pde: TemporalPDE | KineticPDE,
**kwargs,
):
super().__init__(pde, **kwargs)
self.default_lr: float = kwargs.get("default_lr", 1e-2)
opt_1 = {
"name": "sgd",
"optimizer_args": {"lr": self.default_lr},
}
self.optimizer = OptimizerData(opt_1)
self.bool_linesearch = True
self.bool_preconditioner = True
self.nb_epoch_preconditioner_computing = 1
self.type_linesearch = kwargs.get("type_linesearch", "armijo")
self.projection_data = {"nonlinear": True, "linear": False, "nb_step": 1}
self.preconditioner = AnagramPreconditionerProjector(
pde.space, has_bc=self.bool_weak_bc, **kwargs
)
self.data_linesearch = kwargs.get("data_linesearch", {})
self.data_linesearch.setdefault("M", 10)
self.data_linesearch.setdefault("interval", [0.0, 2.0])
self.data_linesearch.setdefault("log_basis", 2.0)
self.data_linesearch.setdefault("n_step_max", 10)
self.data_linesearch.setdefault(
"alpha",
0.01,
)
self.data_linesearch.setdefault(
"beta",
0.5,
)
[docs]
class ExplicitTimeDiscreteScheme:
"""An explicit time-discrete scheme for solving a differential equation.
Use linear and/or nonlinear spaces.
The class supports initialization of the model with a target function, computation
of the right-hand side (RHS)
from the model, and stepping through time using a projector.
Args:
pde: The PDE model.
projector: The projector for training the model.
projector_init: The projector for initializing the model
(if None, uses the same as projector).
**kwargs: Additional hyperparameters for the scheme.
"""
def __init__(
self,
pde: TemporalPDE | KineticPDE,
projector: TimeDiscreteCollocationProjector,
projector_init: TimeDiscreteCollocationProjector | None = None,
**kwargs,
):
self.pde = pde
self.projector = projector
if projector_init is None:
self.projector_init = projector
else:
self.projector_init = projector_init
self.initial_time: float = kwargs.get("initial_time", 0.0)
self.initial_space: AbstractApproxSpace = copy.deepcopy(self.projector.space)
self.saved_times: list[float] = []
self.saved_spaces: list[AbstractApproxSpace] = []
self.bool_weak_bc = kwargs.get("bool_weak_bc", False)
self.final_time: float = kwargs.get("T", 0.1)
if hasattr(pde, "exact_solution"):
self.exact_solution = pde.exact_solution
else:
self.exact_solution = None
[docs]
def initialization(self, **kwargs: Any):
"""Initializes the model by projecting the initial condition onto the model.
Args:
**kwargs: Additional parameters for the initialization, such as the number
of epochs and collocation points.
"""
self.rhs = self.pde.initial_condition # rhs is the initial condition
self.projector_init.set_rhs(self.rhs)
self.projector_init.solve(**kwargs) # Use projector to fit the model
if "initial_time" in kwargs:
self.initial_time = kwargs["initial_time"]
self.initial_space = copy.deepcopy(self.projector.space)
[docs]
@abstractmethod
def update(self, t: float, dt: float, **kwargs):
"""Updates the model parameters using the time step and the chosen time scheme.
Args:
t: The current time.
dt: The time step.
**kwargs: Additional parameters for the update.
"""
[docs]
def compute_relative_error_in_time(
self, t: float, n_error: int = 5_000
) -> list[float | torch.Tensor]:
"""Computes the relative error between the current and exact solution.
Args:
t: The time at which the error is computed.
n_error: The number of points used for computing the error.
Default is 5_000.
Returns:
list: The L1, L2, and Linf errors.
"""
x, mu = self.pde.space.integrator.sample(n_error)
u = self.pde.space.evaluate(x, mu)
t_ = LabelTensor(t * torch.ones((x.shape[0], 1)))
u_exact = self.exact_solution(t_, x, mu)
error = u.w - u_exact
# if relative := torch.min(torch.abs(u_exact)) > 1e-3:
# error = error / u_exact
with torch.no_grad():
L1_error = torch.mean(torch.abs(error))
L2_error = torch.mean(error**2) ** 0.5
Linf_error = torch.max(torch.abs(error))
return [t, L1_error, L2_error, Linf_error]
[docs]
def solve(
self,
dt: float = 1e-5,
final_time: float = 0.1,
sol_exact: Callable | None = None,
**kwargs,
):
"""Solves the full time-dependent problem, using time_step.
Args:
dt: The time step.
final_time: The final time.
sol_exact: The exact solution, if available.
**kwargs: Additional parameters for the time-stepping, such as the number
of epochs, collocation points, and options for saving and plotting.
"""
self.nb_keep = kwargs.get("nb_keep", 1)
inter_times = np.linspace(self.initial_time, final_time, self.nb_keep + 2)[1:-1]
self.saved_times = []
self.saved_spaces = []
index_of_next_t_to_be_saved = 0
self.final_time = final_time
nt = 0
time = self.initial_time
save = kwargs.get("save", None)
plot = kwargs.get("plot", None)
additional_epochs_for_first_iteration = kwargs.get(
"additional_epochs_for_first_iteration", 0
)
epochs = kwargs.get("epochs", 100)
if self.exact_solution:
self.list_errors = [self.compute_relative_error_in_time(0)]
print(f"Time: {time}, L2 error: {self.list_errors[-1][2]:.3e}", flush=True)
while time < final_time:
if time + dt > final_time:
dt = final_time - time
if final_time - time - dt < 1e-16:
dt = final_time - time
if (nt == 0) and ("epochs" in kwargs):
epochs += additional_epochs_for_first_iteration
kwargs["epochs"] = epochs
self.update(time, dt, **kwargs)
if (nt == 0) and ("epochs" in kwargs):
epochs -= additional_epochs_for_first_iteration
kwargs["epochs"] = epochs
time = time + dt
nt = nt + 1
if plot:
assert hasattr(self.pde.space, "integrator")
plot(
self.pde.space.evaluate,
self.pde.space.integrator.sample,
T=time,
iter=nt,
)
if save:
self.projector.save(f"{nt}_{save}")
if self.exact_solution:
self.list_errors.append(self.compute_relative_error_in_time(time))
print(
f"Time: {time}, L2 error: {self.list_errors[-1][2]:.3e}", flush=True
)
if (
(time < final_time)
and (index_of_next_t_to_be_saved < self.nb_keep)
and (time >= inter_times[index_of_next_t_to_be_saved])
):
self.saved_times.append(time)
self.saved_spaces.append(copy.deepcopy(self.projector.space))
index_of_next_t_to_be_saved += 1
# if sol_exact is not None:
# error = self.compute_relative_error_in_time(time, sol_exact)
# self.errors.append(error)
# print(f"current iteration : {nt}, error: {error:.2e}")
# else:
# print("current iteration :", nt)
# # self.list_err.append(err_abs)
# nt = nt + 1
if plot:
assert hasattr(self.pde.space, "integrator")
plot(
self.pde.space.evaluate,
self.pde.space.integrator.sample,
T=time,
iter=nt,
)
# self.times.append(time)
if self.exact_solution:
self.errors = torch.tensor(self.list_errors)