"""Implementation of a rigid body ODE."""
from typing import Callable
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.physical_models.ode.abstract_ode import AbstractODE
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable
class RigidBody(AbstractODE):
    r"""Implementation of a rigid body ODE.

    The unknown :math:`w = (x, y, z)` satisfies

    .. math::

        i_x \frac{dx}{dt} - (i_z - i_y) \, y z = f_x(t), \\
        i_y \frac{dy}{dt} - (i_x - i_z) \, x z = f_y(t), \\
        i_z \frac{dz}{dt} - (i_y - i_x) \, x y = f_z(t),

    where :math:`(i_x, i_y, i_z)` are the three components of the parameter
    tensor ``mu``. The left-hand sides are assembled by :meth:`compute_eqs`.
    The initial condition is given by the function ``init``. The right-hand
    side is given by the function ``rhs_func``, which provides the three
    components :math:`f_x`, :math:`f_y`, and :math:`f_z`.

    Args:
        space: The approximation space for the problem.
        init: Callable for the initial condition.
        rhs_func: Callable for the right-hand side of the ODE. ``None`` is
            forwarded unchanged to :class:`AbstractODE`. Presumably the base
            class then substitutes a zero right-hand side (confirm there).
        **kwargs: Additional keyword arguments forwarded to
            :class:`AbstractODE`.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        init: Callable,
        rhs_func: Callable | None = None,
        **kwargs,
    ):
        # The system has three scalar equations, one per component x, y, z.
        super().__init__(
            space, init=init, rhs_func=rhs_func, n_equations=3, **kwargs
        )

    @staticmethod
    def compute_eqs(
        dx_dt: torch.Tensor,
        dy_dt: torch.Tensor,
        dz_dt: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        i_x: torch.Tensor,
        i_y: torch.Tensor,
        i_z: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the left-hand sides of the rigid body ODE.

        Args:
            dx_dt: Time derivative of x.
            dy_dt: Time derivative of y.
            dz_dt: Time derivative of z.
            x: First solution component.
            y: Second solution component.
            z: Third solution component.
            i_x: First component of the parameter ``mu``.
            i_y: Second component of the parameter ``mu``.
            i_z: Third component of the parameter ``mu``.

        Returns:
            The tuple ``(eq_x, eq_y, eq_z)``, where
            ``eq_x = i_x * dx_dt - (i_z - i_y) * y * z``. The other two
            entries are the cyclic permutations x -> y -> z -> x.
        """
        # NOTE(review): if (i_x, i_y, i_z) are principal moments of inertia,
        # the classical Euler equations read i_x dx/dt = (i_y - i_z) y z.
        # The coupling terms below carry the opposite sign. Confirm the
        # intended convention before relying on a physical interpretation.
        eq_x = i_x * dx_dt - (i_z - i_y) * y * z
        eq_y = i_y * dy_dt - (i_x - i_z) * x * z
        eq_z = i_z * dz_dt - (i_y - i_x) * x * y
        return eq_x, eq_y, eq_z

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the ODE operator to a labelled solution.

        Args:
            w: Solution tensor. Its three components are read as ``x, y, z``.
            t: Temporal coordinate tensor, used as the differentiation
                variable.
            mu: Parameter tensor. Its three components are read as
                ``i_x, i_y, i_z``.

        Returns:
            The tuple ``(eq_x, eq_y, eq_z)`` produced by :meth:`compute_eqs`.
        """
        x, y, z = w.get_components()
        i_x, i_y, i_z = mu.get_components()
        # Derivatives of each component w.r.t. t, computed with the
        # ``grad`` helper (presumably inherited from AbstractODE).
        dx_dt, dy_dt, dz_dt = self.grad(x, t), self.grad(y, t), self.grad(z, t)
        return self.compute_eqs(dx_dt, dy_dt, dz_dt, x, y, z, i_x, i_y, i_z)

    def functional_operator(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the ODE residual through a functional evaluation of ``func``.

        Args:
            func: Callable ``func(t, mu, theta)`` whose output is unpacked
                into the three solution components ``x, y, z``.
            t: Temporal coordinate tensor. The Jacobian is taken w.r.t. it.
            mu: Parameter tensor. Its first three entries are read as
                ``i_x, i_y, i_z``.
            theta: Additional parameters tensor, passed through to ``func``.

        Returns:
            The three residuals from :meth:`compute_eqs`, stacked along the
            last dimension.
        """
        x, y, z = func(t, mu, theta)
        i_x, i_y, i_z = mu[0], mu[1], mu[2]
        # Jacobian of ``func`` w.r.t. its first argument ``t``. The [k, 0]
        # indexing assumes func returns 3 entries and t holds a single time
        # value, i.e. a single (presumably vmapped) sample. TODO confirm.
        time_op = torch.func.jacrev(func, 0)(t, mu, theta)
        dx_dt, dy_dt, dz_dt = time_op[0, 0], time_op[1, 0], time_op[2, 0]
        eqs = self.compute_eqs(dx_dt, dy_dt, dz_dt, x, y, z, i_x, i_y, i_z)
        return torch.stack(eqs, dim=-1)