Source code for scimba_torch.physical_models.ode.simple_ode

"""Implementation of a simple 1D ODE."""

from typing import Callable, cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.physical_models.ode.abstract_ode import AbstractODE
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable


class SimpleODE(AbstractODE):
    r"""Implementation of a 1D simple ODE.

    .. math::

        \frac{du}{dt} = \mu u, \quad t \in (0, T),

    with an initial condition given by the function `init`.

    Args:
        space: The approximation space for the problem
        init: Callable for the initial condition
        **kwargs: Additional keyword arguments, forwarded unchanged to the
            parent constructor
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        init: Callable,
        **kwargs,
    ):
        # A single scalar equation, hence n_equations=1.
        super().__init__(space, init=init, n_equations=1, **kwargs)

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        r"""Apply the ODE operator :math:`\frac{du}{dt} - \mu u`.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            mu: Parameter tensor

        Returns:
            The residual ``du/dt - mu * u``. This is a single tensor when
            ``w.get_components()`` yields one tensor. Otherwise it is a tuple
            holding one residual per component.
        """
        u = w.get_components()
        alpha = mu.get_components()
        if isinstance(u, torch.Tensor):
            return self.grad(u, t) - alpha * u
        # Several components: build the residual for each one independently.
        return tuple(
            cast(torch.Tensor, self.grad(ui, t) - alpha * ui) for ui in u
        )

    def functional_operator(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute the functional operator :math:`\frac{du}{dt} - \mu u`.

        Args:
            func: Callable representing the function, called as
                ``func(t, mu, theta)``
            t: Temporal coordinate tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional operator tensor
        """
        # Jacobian of func with respect to its first argument (t).
        time_op = torch.func.jacrev(func, 0)(t, mu, theta)
        # NOTE(review): indexing [0, 0] assumes func returns a one-element
        # output and t holds a single time point, so the Jacobian is (1, 1).
        # Confirm this against the callers.
        return time_op[0, 0] - mu[0] * func(t, mu, theta)