"""Collocation-based projectors for approximation spaces."""
import warnings
from typing import TYPE_CHECKING, Any, Callable
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.integration.monte_carlo_parameters import ParametricSampler
from scimba_torch.numerical_solvers.abstract_projector import (
RHS_FUNC_TYPE,
AbstractNonlinearProjector,
)
from scimba_torch.numerical_solvers.preconditioner_projector import (
AnagramPreconditionerProjector,
EnergyNaturalGradientPreconditionerProjector,
MatrixPreconditionerProjector,
)
from scimba_torch.optimizers.optimizers_data import OptimizerData
from scimba_torch.utils.scimba_tensors import LabelTensor
NG_ALGO_NAME: str = "ng_algo"
ENG_VALUES: list[str] = ["ENG", "energy", "EnergyNaturalGradient"]
ANA_VALUES: list[str] = ["ANaGRAM", "anagram", "ANAGRAM"]
class CollocationProjector(AbstractNonlinearProjector):
"""A collocation-based nonlinear projection method.
This subclass implements methods to assemble the input and output tensors
for a specific nonlinear projection problem using collocation points. It computes
the approximation of a nonlinear problem by sampling collocation points and
evaluating the corresponding function values.
Args:
space: The approximation space where the projection will take place.
rhs: The function representing the right-hand side of the problem.
**kwargs: Additional parameters for the projection, including collocation
points and losses.
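    Example:
        A minimal usage sketch (illustrative only; it assumes an already built
        approximation space ``space`` on a purely spatial domain and a right-hand
        side with the call signature expected by :meth:`assembly_post_sampling`)::

            import torch

            def rhs(x, mu):
                # constant right-hand side, returned as a plain tensor
                return torch.ones(x.shape[0], 1)

            projector = CollocationProjector(space, rhs=rhs)
            left, right = projector.assembly(n_collocation=2000)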
"""
def __init__(
self,
space: AbstractApproxSpace,
rhs: RHS_FUNC_TYPE | None = None,
**kwargs,
):
self.rhs: RHS_FUNC_TYPE | None = rhs
super().__init__(space, rhs, **kwargs)
        # TODO(Rémi): should rhs be stored here as self.rhs rather than on self.space.rhs?
def set_rhs(self, rhs: RHS_FUNC_TYPE):
"""Sets the right-hand side function for the projection.
Args:
rhs: The function representing the right-hand side of the problem.
"""
self.rhs = rhs
def get_dof(self, flag_scope: str = "all", flag_format: str = "list"):
"""Retrieves the degrees of freedom (DoF) of the approximation space.
Args:
            flag_scope: Specifies the scope of the parameters to return
                (e.g. ``"all"`` or ``"except_last_layer"``). Defaults to ``"all"``.
            flag_format: The format in which the parameters are returned.
                Defaults to ``"list"``.
Returns:
The degrees of freedom in the specified format.
"""
return self.space.get_dof(flag_scope, flag_format)
def metric_matrix(
self,
x: LabelTensor,
mu: LabelTensor,
t: LabelTensor | None = None,
v: LabelTensor | None = None,
**kwargs,
) -> torch.Tensor:
"""Computes the metric matrix for the given tensors.
Args:
x: Input tensor from the spatial domain.
mu: Input tensor from the parameter domain.
t: Input tensor from the time domain (optional).
v: Input tensor from the velocity domain (optional).
**kwargs: Additional arguments.
Returns:
The computed metric matrix.
Raises:
NotImplementedError: If the metric matrix is not defined for the current
space type.
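        Note:
            With a Jacobian of shape ``(N, ·, ·)``, the result is the averaged Gram
            matrix ``G[j, l] = (1 / N) * sum_{i, k} J[i, j, k] * J[i, l, k]``,
            which is exactly the ``einsum`` contraction performed in the body.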
"""
if not (
self.space.type_space == "space" or self.space.type_space == "phase_space"
):
raise NotImplementedError(
"The metric matrix is not yet defined for time-dependent problems."
)
N = x.shape[0]
if self.space.type_space == "space":
jacobian = self.space.jacobian(x, mu)
elif self.space.type_space == "time_space":
if TYPE_CHECKING: # pragma: no cover
assert isinstance(t, torch.Tensor)
jacobian = self.space.jacobian(t, x, mu)
else:
if TYPE_CHECKING: # pragma: no cover
assert isinstance(v, torch.Tensor)
jacobian = self.space.jacobian(x, v, mu)
return torch.einsum("ijk,ilk->jl", jacobian, jacobian) / N
def sample_all_vars(self, **kwargs: Any) -> tuple[LabelTensor, ...]:
"""Samples values in the domains of the arguments of the function to project.
Args:
            **kwargs: Additional arguments for sampling; ``n_collocation`` sets the
                number of collocation points (default: 1000).
Returns:
A tuple containing the sampled tensors.
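        Note:
            The content of the tuple follows the space type: ``(x, mu)`` for a
            purely spatial space, ``(x, v, mu)`` for a phase space, and
            ``(t, x, mu)`` for a time-dependent space.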
"""
n_collocation = kwargs.get("n_collocation", 1000)
return tuple(self.space.integrator.sample(n_collocation))
def assembly_post_sampling(
self, data: tuple[LabelTensor, ...], **kwargs
) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
"""Assemble the I/O tensors for the nonlinear projection problem after sampling.
Args:
data: The sampled data.
            **kwargs: Additional arguments for assembly, including ``flag_scope``
                (``"all"`` or ``"except_last_layer"``).
Returns:
A tuple of tuples containing the assembled left and right-hand sides.
"""
flag_scope = kwargs.get("flag_scope", "all")
with_last_layer = True
if flag_scope == "except_last_layer":
with_last_layer = False
if self.space.type_space == "space":
args = [data[0], data[1]]
elif self.space.type_space == "phase_space":
args = [data[0], data[1], data[2]]
else: # time_space
args = [data[0], data[1], data[2]]
        u = self.space.evaluate(
            *args, with_last_layer=with_last_layer
        )  # u is a MultiLabelTensor
        assert self.rhs is not None, "a right-hand side must be set before assembly"
        f = self.rhs(*args)  # f is a plain Tensor
        # left: raw features u.w when the last layer is excluded, otherwise the
        # evaluated components; right: the sampled right-hand side values
        left = (u.w,) if (not with_last_layer) else self.make_tuple(u.get_components())
        right = self.make_tuple(f)
return left, right
def assembly(
self, **kwargs: Any
) -> tuple[tuple[torch.Tensor, ...], tuple[torch.Tensor, ...]]:
"""Assembles the system of equations for the projection problem.
Args:
**kwargs: Additional arguments for assembly, including the number of
collocation points.
Returns:
A tuple of tuples containing the assembled left and right-hand sides.
"""
data = self.sample_all_vars(**kwargs)
return self.assembly_post_sampling(data, **kwargs)
def update_parameter_bounds(self, new_bounds: list[tuple[float, float]]) -> None:
"""Updates the bounds of the parameters in the approximation space.
Args:
new_bounds: A list of tuples containing the new bounds for each parameter.
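        Example:
            Illustrative call for a space with two scalar parameters (the number of
            ``(min, max)`` pairs must match the dimension of the parametric
            sampler)::

                projector.update_parameter_bounds([(0.5, 1.5), (0.0, 2.0)])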
"""
parameter_sampler = self.space.integrator.list_sampler[-1]
msg = "The last sampler must be a ParametricSampler."
assert isinstance(parameter_sampler, ParametricSampler), msg
msg = (
"The number of new bounds must match the number of parameters. \n"
+ f"You passed the new bounds {new_bounds}, "
+ f"expecting to replace the old bounds {parameter_sampler.bounds}."
)
assert len(new_bounds) == parameter_sampler.dim, msg
parameter_sampler.set_new_bounds(new_bounds)
class NaturalGradientProjector(CollocationProjector):
"""Subclass of CollocationProjector using natural gradient optimization.
This class extends the CollocationProjector to use natural gradient optimization
for solving the projection problem.
Args:
space: The approximation space where the projection will take place.
rhs: The function representing the right-hand side of the problem.
**kwargs: Additional parameters for the projection, including collocation points
and losses.
Keyword Args:
        `ng_algo` (:code:`str`): The algorithm used to build the natural gradient
            preconditioning matrix: :code:`"ENG"` (energy natural gradient) or
            :code:`"ANaGRAM"` (case variants are accepted). Default: :code:`"ENG"`.
        `default_lr` (:code:`float`): The default learning rate used when the
            linesearch fails. Default: 1e-2.
        `type_linesearch` (:code:`str`): The linesearch algorithm, either
            :code:`"armijo"` or :code:`"logarithmic_grid"`. Default: :code:`"armijo"`.
        `data_linesearch` (:code:`dict`): Optional parameters for the linesearch.
            For the logarithmic grid: :code:`"M"` (number of nodes in the grid),
            :code:`"interval"` (the min and max values of the grid) and
            :code:`"log_basis"` (the logarithmic basis).
            For Armijo: :code:`"n_step_max"` (the maximum number of steps),
            :code:`"alpha"` and :code:`"beta"` (the Armijo parameters).
Raises:
        ValueError: If the value of the :code:`ng_algo` keyword argument is not
            accepted.
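    Example:
        An illustrative construction (``space`` and ``rhs`` are assumed to be built
        elsewhere; the keyword names are the ones read in ``__init__``)::

            projector = NaturalGradientProjector(
                space,
                rhs,
                ng_algo="ANaGRAM",
                default_lr=1e-2,
                type_linesearch="logarithmic_grid",
                data_linesearch={"M": 20, "interval": [0.0, 2.0], "log_basis": 2.0},
            )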
"""
def __init__(
self,
space: AbstractApproxSpace,
rhs: RHS_FUNC_TYPE | None = None,
**kwargs,
):
# first initialize the CollocationProjector part
super().__init__(space, rhs, **kwargs)
self.default_lr: float = kwargs.get("default_lr", 1e-2)
opt_1 = {
"name": "sgd",
"optimizer_args": {"lr": self.default_lr},
}
self.optimizer = OptimizerData(opt_1)
self.bool_linesearch = True
self.type_linesearch = kwargs.get("type_linesearch", "armijo")
self.data_linesearch = kwargs.get("data_linesearch", {})
self.data_linesearch.setdefault("M", 10)
self.data_linesearch.setdefault("interval", [0.0, 2.0])
self.data_linesearch.setdefault("log_basis", 2.0)
self.data_linesearch.setdefault("n_step_max", 10)
self.data_linesearch.setdefault("alpha", 0.01)
self.data_linesearch.setdefault("beta", 0.5)
self.bool_preconditioner = True
self.nb_epoch_preconditioner_computing = 1
self.projection_data = {"nonlinear": True, "linear": False, "nb_step": 1}
default_algo = "ENG"
algo = kwargs.get(NG_ALGO_NAME, default_algo)
# finally initialize the preconditioner
def preconditioner_factory(classname: type):
return classname(
space,
has_bc=False,
**kwargs,
)
if algo in ENG_VALUES:
self.preconditioner: MatrixPreconditionerProjector = preconditioner_factory(
EnergyNaturalGradientPreconditionerProjector
)
elif algo in ANA_VALUES:
self.preconditioner = preconditioner_factory(AnagramPreconditionerProjector)
else:
            raise ValueError(
                f'value "{algo}" for optional argument "{NG_ALGO_NAME}" is not '
                f"accepted; accepted values are {ENG_VALUES + ANA_VALUES}"
            )
class AnagramProjector(CollocationProjector):
"""Subclass of CollocationProjector using anagram-based optimization.
This class extends the CollocationProjector to use anagram-based optimization
for solving the projection problem.
Args:
space: The approximation space where the projection will take place.
rhs: The function representing the right-hand side of the problem.
**kwargs: Additional parameters for the projection, including collocation points
and losses.
"""
def __init__(
self,
space: AbstractApproxSpace,
rhs: RHS_FUNC_TYPE | None = None,
**kwargs,
):
        warnings.warn(
            "class %s will be deprecated in future versions; "
            "please use class NaturalGradientProjector "
            'with keyword argument %s = "ANaGRAM" '
            "instead" % (self.__class__.__name__, NG_ALGO_NAME),
            FutureWarning,
        )
super().__init__(space, rhs, **kwargs)
self.default_lr: float = kwargs.get("default_lr", 1e-2)
opt_1 = {
"name": "sgd",
"optimizer_args": {"lr": self.default_lr},
}
self.optimizer = OptimizerData(opt_1)
self.bool_linesearch = True
self.bool_preconditioner = True
self.nb_epoch_preconditioner_computing = 1
self.type_linesearch = kwargs.get("type_linesearch", "armijo")
self.projection_data = {"nonlinear": True, "linear": False, "nb_step": 1}
self.preconditioner = AnagramPreconditionerProjector(
space, has_bc=False, **kwargs
)
self.data_linesearch = kwargs.get("data_linesearch", {})
self.data_linesearch.setdefault("M", 10)
self.data_linesearch.setdefault("interval", [0.0, 2.0])
self.data_linesearch.setdefault("log_basis", 2.0)
self.data_linesearch.setdefault("n_step_max", 10)
self.data_linesearch.setdefault("alpha", 0.01)
self.data_linesearch.setdefault("beta", 0.5)
class LinearProjector(CollocationProjector):
"""Subclass of CollocationProjector for linear projection problems.
This class extends the CollocationProjector to handle linear projection problems.
Args:
space: The approximation space where the projection will take place.
rhs: The function representing the right-hand side of the problem.
**kwargs: Additional parameters for the projection, including collocation points
and losses.
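    Example:
        Illustrative construction (``space`` and ``rhs`` are assumed to be defined
        elsewhere, as for the parent class)::

            projector = LinearProjector(space, rhs)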
"""
def __init__(
self,
space: AbstractApproxSpace,
rhs: Callable,
**kwargs,
):
super().__init__(space, rhs, **kwargs)
self.projection_data = {"nonlinear": False, "linear": True, "nb_step": 1}
def plot(f: Callable, sampler: Callable, n_visu: int = 500): # pragma: no cover
"""Plot the function f over a 2D domain using sampled points.
Args:
f: The function to plot, which takes in sampled points and parameters.
sampler: A callable that samples points from the domain.
n_visu: The number of points along each axis for visualization.
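    Example:
        An illustrative call for a purely spatial space (assumes ``space.evaluate``
        accepts ``(x, mu)`` and that the integrator samples ``(x, mu)`` pairs)::

            plot(lambda x, mu: space.evaluate(x, mu).w, space.integrator.sample)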
"""
import matplotlib.pyplot as plt
    _, mu = sampler(n_visu**2)  # only the sampled parameters are used below
    x1 = torch.linspace(0.0, 1.0, n_visu)
    x2 = torch.linspace(0.0, 1.0, n_visu)
x1, x2 = torch.meshgrid(x1, x2, indexing="ij")
x = LabelTensor(torch.stack((x1.flatten(), x2.flatten()), dim=1))
x1, x2 = x.get_components()
x1, x2 = x1.detach().cpu(), x2.detach().cpu()
predictions = f(x, mu).detach().cpu()
fig, ax = plt.subplots(1, 1, figsize=(9, 3), constrained_layout=True)
x1 = x1.reshape(n_visu, n_visu)
x2 = x2.reshape(n_visu, n_visu)
predictions = predictions.reshape(n_visu, n_visu)
contour = ax.contourf(x1, x2, predictions, levels=256, cmap="turbo")
fig.colorbar(contour, ax=ax, fraction=0.046, pad=0.04)
ax.contour(x1, x2, predictions, levels=8, colors="white", linewidths=0.5)
ax.set_title("Predictions")
plt.show()