# Source code for scimba_torch.physical_models.ode.damped_harmonic_oscillator

"""Implementation of a damped harmonic oscillator ODE."""

from typing import Callable

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.physical_models.ode.abstract_ode import AbstractODE
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable


class DampedHarmonicOscillator(AbstractODE):
    r"""Damped harmonic oscillator written as a first-order system of two ODEs.

    The unknowns :math:`(x, y)` satisfy

    .. math::

        \frac{dx}{dt} = - \mu x - y, \\
        \frac{dy}{dt} = x - \mu y,

    together with an initial condition provided by the callable ``init``.

    Args:
        space: The approximation space for the problem.
        init: Callable for the initial condition.
        **kwargs: Additional keyword arguments, forwarded unchanged to the
            parent constructor.
    """

    def __init__(self, space: AbstractApproxSpace, init: Callable, **kwargs):
        # Two unknowns (x and y), hence two equations.
        super().__init__(space, init=init, n_equations=2, **kwargs)

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Evaluate the left-hand-side form of both equations of the system.

        Every term of each equation is moved to the left-hand side, so the two
        returned expressions are ``dx/dt + mu * x + y`` and
        ``dy/dt - x + mu * y``; both are zero for an exact solution.

        Args:
            w: Solution tensor; its components are unpacked as ``(x, y)``.
            t: Temporal coordinate tensor.
            mu: Parameter tensor holding the damping coefficient.

        Returns:
            A pair ``(first equation, second equation)`` of operator values.
        """
        x, y = w.get_components()
        # Kept under a separate name so the ``mu`` argument is not rebound.
        damping = mu.get_components()

        # Both time derivatives are taken before any arithmetic is done.
        x_dot = self.grad(x, t)
        y_dot = self.grad(y, t)

        first_equation = x_dot + damping * x + y
        second_equation = y_dot - x + damping * y
        return first_equation, second_equation

    def functional_operator(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate the same operator on a plain callable.

        The time derivative is obtained by reverse-mode differentiation of
        ``func`` with respect to its first argument instead of ``self.grad``.

        Args:
            func: Callable evaluated as ``func(t, mu, theta)``; its result is
                unpacked into the two components ``(x, y)``.
            t: Temporal coordinate tensor.
            mu: Parameter tensor; only ``mu[0]`` is used, as the damping
                coefficient.
            theta: Additional parameters tensor, passed through to ``func``.

        Returns:
            The two operator values stacked along the last dimension.
        """
        # Jacobian of func with respect to t, read as [component, 0].
        # NOTE(review): presumably t holds a single entry per sample - confirm.
        jacobian_wrt_t = torch.func.jacrev(func, argnums=0)(t, mu, theta)
        x_dot = jacobian_wrt_t[0, 0]
        y_dot = jacobian_wrt_t[1, 0]

        x, y = func(t, mu, theta)
        damping = mu[0]

        equations = [
            x_dot + damping * x + y,
            y_dot - x + damping * y,
        ]
        return torch.stack(equations, dim=-1)