"""Create solutions using various ODE integrators."""
from typing import Callable
import torch
from scimba_torch.flows.integrators_ode import (
euler_midpoint,
euler_symplectic,
rk4,
rk4_symplectic,
verlet_explicit,
verlet_implicit,
)
[docs]
def create_solution(
    x0: torch.Tensor,
    mu: torch.Tensor,
    m: int,
    dt: float,
    f: list[Callable],
    solver: str = "RK4",
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
    """Integrate an ODE with the chosen solver and build (inputs, targets) pairs.

    The time grid is ``t = linspace(0, (m - 1) * dt, m)``. The trajectory
    ``res`` returned by the integrator is then split along its second-to-last
    axis into "all entries but the last" (inputs) and "all entries but the
    first" (targets), i.e. each input entry is paired with the next one.

    Args:
        x0: initial condition,
        mu: parameter of the equation,
        m: number of time steps,
        dt: time step,
        f: list of functions defining the ODE. For ``"RK4"`` the whole list is
            handed to the integrator; for every other solver ``f[0]`` and
            ``f[1]`` are passed separately,
        solver: the solver to use (default: "RK4"). One of ``"RK4"``,
            ``"Verlet_explicit"``, ``"Verlet_implicit"``,
            ``"Euler_symplectic"``, ``"Euler_midpoint"``, ``"RK4_symplectic"``.

    Returns:
        A tuple containing:
            - A tuple of tensors (t, mu, res) where:
                - t: time grid broadcast over trajectories,
                - mu: parameters repeated for each time step,
                - res: solution,
              each with its last entry along the second-to-last axis removed.
            - The solution tensor res with its first entry along the
              second-to-last axis removed.

    Raises:
        ValueError: if ``solver`` is not one of the supported names.
    """
    split_solver_names = (
        "Verlet_explicit",
        "Verlet_implicit",
        "Euler_symplectic",
        "Euler_midpoint",
        "RK4_symplectic",
    )
    # Fail fast on an unknown solver: previously this case only printed a
    # message and then crashed on the unassigned ``res`` at the return.
    if solver != "RK4" and solver not in split_solver_names:
        supported = ", ".join(("RK4", *split_solver_names))
        raise ValueError(f"Unknown solver {solver!r}; expected one of: {supported}")

    # res shape : (nbr_traj, time, pos/mom)
    # x0 shape  : (nbr_traj, 2)
    # mu shape  : (nbr_traj, nbr_params)
    # nbr_traj = number of initial conditions * number of mu
    t = torch.linspace(0, (m - 1) * dt, m)

    if solver == "RK4":
        res = rk4(f, mu, x0, t)
        mu_ = torch.repeat_interleave(mu, t.shape[0], dim=1)[..., None]
        t_ = torch.repeat_interleave(t[None, :], x0.shape[0], dim=0)[..., None]
    else:
        mu_ = torch.repeat_interleave(mu, t.shape[0], dim=1)[..., None]
        t_ = torch.repeat_interleave(t[None, :], x0.shape[0], dim=0)[..., None]
        # Insert an extra axis after the trajectory axis and repeat along it.
        # NOTE(review): mu_ is repeated along dim=1 but t_ along dim=2 --
        # confirm this asymmetry is intended.
        mu_ = torch.repeat_interleave(mu_[:, None, ...], x0.shape[1], dim=1)
        t_ = torch.repeat_interleave(t_[:, None, ...], x0.shape[1], dim=2)

        # f[0]=DpH, f[1]=DqH
        split_integrators = {
            "Verlet_explicit": verlet_explicit,
            "Verlet_implicit": verlet_implicit,
            "Euler_symplectic": euler_symplectic,
            "Euler_midpoint": euler_midpoint,
            "RK4_symplectic": rk4_symplectic,
        }
        res = split_integrators[solver](f[0], f[1], mu, x0, t)

    # Inputs drop the last entry, targets drop the first, so input k is
    # paired with target k + 1 along the second-to-last axis.
    return (t_[..., :-1, :], mu_[..., :-1, :], res[..., :-1, :]), res[..., 1:, :]