Source code for scimba_torch.plots.plot_time_discrete_scheme

"""Plotting functions (generic in geometric dimension) for time-discrete schemes."""

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.domain.mesh_based_domain.cuboid import Cuboid
from scimba_torch.domain.meshless_domain.base import SurfacicDomain, VolumetricDomain
from scimba_torch.integration.monte_carlo import TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.numerical_solvers.temporal_pde.time_discrete import (
    ExplicitTimeDiscreteScheme,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces


[docs] def plot_time_discrete_scheme( scheme: ExplicitTimeDiscreteScheme, **kwargs, ) -> None: """Plot an approximation space obtained with a time-discrete solver. Args: scheme: the time-discrete solver **kwargs: arbitrary keyword arguments Keyword Args: parameters_values: a (list of) point(s) in the parameters domain, or "mean" or "random", defaults to "mean", velocity_values: a list of point(s) in the velocity domain, derivatives: a list of strings representing the derivatives to be plot, for instance "uxx"; defaults to [], solution: a callable depending on the same args as space to be plot, error: plot the absolute error with respect to the given solution, title: a str ...: see examples Implemented only for 1 and 2 dimensional spaces Raises: AttributeError: some arguments are missing necessary attributes NotImplementedError: some option combinations are not implemented yet Examples: >>> import matplotlib.pyplot as plt >>> from scimba_torch.plots.plot_time_discrete_scheme import plot_time_discrete_scheme >>> ... >>> def exact(t: LabelTensor, x: LabelTensor, mu: LabelTensor): ... >>> plot_time_discrete_scheme( scheme, solution=exact, error=exact, derivatives=["ux"], ) >>> plt.show() """ spaces = ( (scheme.initial_space,) + tuple(space for space in scheme.saved_spaces) + (scheme.projector.space,) ) # if not hasattr(scheme.pde, "space"): raise AttributeError("scheme.pde must have an attribute space") assert hasattr(scheme.pde, "space") assert isinstance(scheme.pde.space, AbstractApproxSpace) lspace: AbstractApproxSpace = scheme.pde.space if not hasattr(lspace, "spatial_domain"): raise AttributeError("scheme.pde.space must have an attribute spatial_domain") assert hasattr(lspace, "spatial_domain") assert isinstance(lspace.spatial_domain, VolumetricDomain) or isinstance( lspace.spatial_domain, Cuboid ) if not hasattr(lspace, "type_space"): raise AttributeError("scheme.pde.space must have an attribute type_space") assert hasattr(lspace, "type_space") if lspace.type_space not in ["space", "phase_space"]: raise NotImplementedError( "plot_time_discrete_scheme not implemented for type_space %s" % lspace.type_space ) spatial_domain: VolumetricDomain | Cuboid = lspace.spatial_domain velocity_domain: SurfacicDomain | VolumetricDomain | None = None if lspace.type_space == "phase_space": if not hasattr(lspace, "velocity_domain"): raise AttributeError( "scheme.pde.space must have an attribute velocity_domain" ) assert hasattr(lspace, "velocity_domain") assert isinstance(lspace.velocity_domain, SurfacicDomain) or isinstance( lspace.velocity_domain, VolumetricDomain ) velocity_domain = (lspace.velocity_domain,) if not hasattr(lspace, "integrator"): raise AttributeError("scheme.pde.space must have an attribute integrator") assert hasattr(lspace, "integrator") assert isinstance(lspace.integrator, TensorizedSampler) parameters_domain: list[tuple[float, float]] = [] for sam in lspace.integrator.list_sampler: if isinstance(sam, UniformParametricSampler): parameters_domain = sam.bounds # print("spaces: ", spaces) # print("spatial_domain: ", spatial_domain) time_domain = [float(scheme.initial_time), float(scheme.final_time)] time_values = ( (float(scheme.initial_time),) + tuple(tval for tval in scheme.saved_times) + (float(scheme.final_time),) ) # print("time_domain: ", time_domain) # print("time_values: ", time_values) # kwargs.pop("time_values", None) losses = ( (scheme.projector.losses,) + tuple(None for _ in scheme.saved_spaces) + (None,) ) kwargs.pop("loss", None) # print("kwargs: ", kwargs) plot_abstract_approx_spaces( spaces, spatial_domain, parameters_domain, time_domain, velocity_domain, loss=losses, time_values=time_values, time_discrete=True, **kwargs, )