# Source code for scimba_torch.physical_models.ode.rigid_body

"""Implementation of a rigid body ODE."""

from typing import Callable

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.physical_models.ode.abstract_ode import AbstractODE
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable


class RigidBody(AbstractODE):
    r"""Implementation of a parametric rigid body ODE.

    The residuals built by :meth:`compute_eqs` correspond to the system

    .. math::

        i_x \frac{dx}{dt} = (i_z - i_y) \, y z + f_x(t), \\
        i_y \frac{dy}{dt} = (i_x - i_z) \, x z + f_y(t), \\
        i_z \frac{dz}{dt} = (i_y - i_x) \, x y + f_z(t),

    where :math:`(x, y, z)` are the three components of the solution and
    :math:`(i_x, i_y, i_z)` are the three components of the parameter
    tensor ``mu``. The initial condition is given by the function ``init``.
    The right-hand side components :math:`f_x`, :math:`f_y` and :math:`f_z`
    are given by ``rhs_func``.

    Note:
        :meth:`time_operator` and :meth:`functional_operator` return only the
        left-hand side minus the quadratic coupling terms. The right-hand
        side is presumably applied by :class:`AbstractODE` (not visible in
        this module).

    Args:
        space: The approximation space for the problem.
        init: Callable for the initial condition.
        rhs_func: Callable for the right-hand side of the ODE. If ``None``,
            the base class is expected to fall back to a zero right-hand
            side (``zeros_rhs``).
        **kwargs: Additional keyword arguments forwarded to
            :class:`AbstractODE`.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        init: Callable,
        rhs_func: Callable | None = None,
        **kwargs,
    ):
        # The rigid body system always has exactly three scalar equations.
        super().__init__(space, init=init, rhs_func=rhs_func, n_equations=3, **kwargs)

    @staticmethod
    def compute_eqs(
        dx_dt: torch.Tensor,
        dy_dt: torch.Tensor,
        dz_dt: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        i_x: torch.Tensor,
        i_y: torch.Tensor,
        i_z: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Compute the three residuals of the rigid body ODE.

        The residuals are

        .. math::

            r_x = i_x \dot x - (i_z - i_y) \, y z, \quad
            r_y = i_y \dot y - (i_x - i_z) \, x z, \quad
            r_z = i_z \dot z - (i_y - i_x) \, x y.

        All operations are elementwise, so the inputs only need to be
        mutually broadcastable.

        Args:
            dx_dt: Time derivative of x.
            dy_dt: Time derivative of y.
            dz_dt: Time derivative of z.
            x: Variable x.
            y: Variable y.
            z: Variable z.
            i_x: Parameter i_x (first component of ``mu``).
            i_y: Parameter i_y (second component of ``mu``).
            i_z: Parameter i_z (third component of ``mu``).

        Returns:
            A tuple ``(eq_x, eq_y, eq_z)`` containing the residuals for
            x, y and z.
        """
        # NOTE(review): the classical Euler equations read
        # I_1 w1' = (I_2 - I_3) w2 w3 (and cyclic permutations). The coupling
        # terms below have the opposite sign, which amounts to w -> -w.
        # Confirm this convention is intended before relying on it.
        eq_x = i_x * dx_dt - (i_z - i_y) * y * z
        eq_y = i_y * dy_dt - (i_x - i_z) * x * z
        eq_z = i_z * dz_dt - (i_y - i_x) * x * y
        return eq_x, eq_y, eq_z

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the ODE operator to a (batched) solution.

        The three components of ``w`` are read as ``(x, y, z)``. The three
        components of ``mu`` are read as ``(i_x, i_y, i_z)``. Time
        derivatives are obtained with ``self.grad`` with respect to ``t``.

        Args:
            w: Solution tensor with three components ``(x, y, z)``.
            t: Temporal coordinate tensor.
            mu: Parameter tensor with three components ``(i_x, i_y, i_z)``.

        Returns:
            The tuple of three residual tensors produced by
            :meth:`compute_eqs`.
        """
        x, y, z = w.get_components()
        i_x, i_y, i_z = mu.get_components()
        dx_dt, dy_dt, dz_dt = self.grad(x, t), self.grad(y, t), self.grad(z, t)
        return self.compute_eqs(dx_dt, dy_dt, dz_dt, x, y, z, i_x, i_y, i_z)

    def functional_operator(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the ODE operator in functional (``torch.func``) form.

        ``func(t, mu, theta)`` is unpacked into ``(x, y, z)``. ``mu`` is
        indexed as ``(i_x, i_y, i_z)``. The time derivatives come from
        ``torch.func.jacrev`` with respect to ``t``.

        NOTE(review): the indexing ``time_op[k, 0]`` suggests this is meant
        for a single sample, where ``func`` returns a 1-D tensor of three
        entries and ``t`` is a 1-D tensor whose first entry is time (e.g.
        under ``vmap``). Confirm against the callers.

        Args:
            func: Callable ``func(t, mu, theta)`` returning the three
                solution components.
            t: Temporal coordinate tensor.
            mu: Parameter tensor with entries ``(i_x, i_y, i_z)``.
            theta: Additional parameters tensor passed through to ``func``.

        Returns:
            The three residuals stacked along the last dimension.
        """
        x, y, z = func(t, mu, theta)
        i_x, i_y, i_z = mu[0], mu[1], mu[2]
        # Jacobian of the output w.r.t. t: row k holds d(component k)/dt.
        time_op = torch.func.jacrev(func, 0)(t, mu, theta)
        dx_dt, dy_dt, dz_dt = time_op[0, 0], time_op[1, 0], time_op[2, 0]
        eqs = self.compute_eqs(dx_dt, dy_dt, dz_dt, x, y, z, i_x, i_y, i_z)
        return torch.stack(eqs, dim=-1)