r"""Learn the flow of a pendulum ODE with multi-step rollout training (JAX version).

Dynamics are given by the Hamiltonian

    H(q, p; μ) = p²/2 + μ q²/2 + μ * 0.012 * q³/3,

and the training data are generated with the explicit Störmer–Verlet scheme
(kick-drift-kick, time step ``dt``).

One call of a learned flow advances the state by ``L_STEPS`` physical steps,
i.e. by ``L_STEPS * dt``. In the training loss ``N_ROLLOUT`` model calls are
composed, so each training pair is ``(x_t, x_{t + L_STEPS * N_ROLLOUT * dt})``
rather than ``(x_t, x_{t + dt})``. ``L_STEPS`` is shared by all models, while
``N_ROLLOUT`` is chosen per model.

The following flows are compared:

1. A raw MLP flow.
2. An ExplicitEulerFlow (MLP inside).
3. A GSymplecticNet (structure-preserving).
"""

# %%
import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.flow_approximation_spaces import (
    FlowsApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.data_sampler import DataSampler
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
    UniformParametricSampler,
)
from scimba_jax.nonlinear_approximation.networks.discrete_ode_nets import (
    ExplicitEulerFlow,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.networks.structure_preserving_nets.symplectic_nets import (
    GSymplecticNet,
)
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.abstract_physical_model import (
    PHYSICAL_RESIDUALS_TYPE,
    AbstractPhysicalModel,
)
from scimba_jax.physical_models.abstract_residuals import PARAM_FUNC_TYPE, DataResidual

# Enable float64 before any array is created so all data and models use it.
jax.config.update("jax_enable_x64", True)

# %% Hyperparameters
key = jax.random.PRNGKey(0)
dt = 0.02  # physical time step of the Verlet integrator
L_STEPS = 5  # physical dt steps per model call (same for all)
N_ROLLOUT_MLP = 1  # model calls composed in training loss — MLP
N_ROLLOUT_EULER = 10  # model calls composed in training loss — ExplicitEuler
N_ROLLOUT_SYMP = 1  # model calls composed in training loss — GSymplecticNet
# Training data stride (in physical dt steps): L_STEPS * N_ROLLOUT_x
# (same L for all models, different N).
N_simu = 500  # number of simulated training trajectories
Nt_train = 500  # number of training start times taken from each trajectory
N_COLLOC = 1200
N_EPOCHS = 300
N_EPOCHS_SYM = 200
ENG_REGULARIZATION = 1e-5


# %% Pendulum dynamics (explicit Verlet scheme)
def dh_dq(q, mu):
    """Return ∂H/∂q for H = p²/2 + μq²/2 + μ*0.012*q³/3."""
    return mu * q + mu * 0.012 * q**2


def dh_dp(p, _mu):
    """Return ∂H/∂p = p (independent of μ)."""
    return p


def verlet_step(q, p, mu):
    """Advance ``(q, p)`` by one step of size ``dt`` (Störmer–Verlet, kick-drift-kick)."""
    p_half = p - dt / 2 * dh_dq(q, mu)
    q_new = q + dt * dh_dp(p_half, mu)
    p_new = p_half - dt / 2 * dh_dq(q_new, mu)
    return q_new, p_new


# %% Generate training data
key, k1, k2, k3 = jax.random.split(key, 4)
q0 = 1.0 + jax.random.uniform(k1, (N_simu,)) * 2.0  # q0 ∈ [1, 3]
p0 = 0.2 * jax.random.uniform(k2, (N_simu,)) * 2.0  # p0 ∈ [0, 0.4]
mu_simu = 0.8 + jax.random.uniform(k3, (N_simu,))  # mu ∈ [0.8, 1.8]


def make_training_data(n_rollout):
    """Generate pairs ``(x_t, x_{t + n_rollout*dt})`` for all simulations.

    ``n_rollout`` counts *physical* Verlet steps (callers pass the stride
    ``L_STEPS * N_ROLLOUT_x``), not model calls.

    Returns:
        x: ``(N_simu * Nt_train, 2)`` start states ``[q, p]``.
        mu: ``(N_simu * Nt_train, 1)`` parameter of each pair.
        y: ``(N_simu * Nt_train, 2)`` states ``n_rollout`` steps later.
    """

    def simulate_one(q0_i, p0_i, mu_i):
        def step_fn(carry, _):
            q, p = carry
            q_new, p_new = verlet_step(q, p, mu_i)
            # Emit the pre-step state so that q_all[k] is the state at step k.
            return (q_new, p_new), (q, p)

        _, (q_all, p_all) = jax.lax.scan(
            step_fn, (q0_i, p0_i), None, length=Nt_train + n_rollout
        )
        q_start = q_all[:Nt_train]
        p_start = p_all[:Nt_train]
        q_target = q_all[n_rollout : n_rollout + Nt_train]
        p_target = p_all[n_rollout : n_rollout + Nt_train]
        return q_start, p_start, q_target, p_target

    q_t, p_t, q_next, p_next = jax.vmap(simulate_one)(q0, p0, mu_simu)
    # Row-major ravel is simulation-major, matching jnp.repeat of mu below.
    x = jnp.stack([q_t.ravel(), p_t.ravel()], axis=-1)
    mu = jnp.repeat(mu_simu, Nt_train)[:, None]
    y = jnp.stack([q_next.ravel(), p_next.ravel()], axis=-1)
    return x, mu, y


stride_mlp = L_STEPS * N_ROLLOUT_MLP
stride_euler = L_STEPS * N_ROLLOUT_EULER
stride_symp = L_STEPS * N_ROLLOUT_SYMP

x_data_mlp, mu_data_mlp, y_data_mlp = make_training_data(stride_mlp)
x_data_euler, mu_data_euler, y_data_euler = make_training_data(stride_euler)
x_data_symp, mu_data_symp, y_data_symp = make_training_data(stride_symp)

for _name, _x_data, _n_rollout, _stride in (
    ("MLP", x_data_mlp, N_ROLLOUT_MLP, stride_mlp),
    ("Euler", x_data_euler, N_ROLLOUT_EULER, stride_euler),
    ("Symp", x_data_symp, N_ROLLOUT_SYMP, stride_symp),
):
    print(
        f"Training data {_name}: {_x_data.shape[0]} pairs "
        f"(L={L_STEPS}, N={_n_rollout}, stride={_stride} = {_stride * dt:.3f}s)"
    )

# %% Reference trajectory for visualisation
T_REF = 500.0  # final time of the reference trajectory
# round() rather than int(): the float quotient may land just below an integer.
Nt_ref = round(T_REF / dt)
q0ref, p0ref, mu_ref = 1.4, 0.12, 0.7
# NOTE(review): mu_ref = 0.7 is outside the training range mu ∈ [0.8, 1.8]
# (and outside domain_mu below), so this is a parameter-extrapolation test —
# confirm that is intended.


def simulate_ref(n_steps):
    """Integrate the reference initial condition for ``n_steps`` Verlet steps.

    Returns ``(q_ref, p_ref)`` of length ``n_steps + 1`` *including* the
    initial state, so ``q_ref[k]`` is the state at time ``k * dt``.
    """

    def step_fn(carry, _):
        q, p = carry
        q_new, p_new = verlet_step(q, p, mu_ref)
        # Emit the pre-step state so the initial condition is part of the output.
        return (q_new, p_new), (q, p)

    # One extra iteration so the emitted states cover steps 0..n_steps.
    _, (q_ref, p_ref) = jax.lax.scan(
        step_fn, (q0ref, p0ref), None, length=n_steps + 1
    )
    return q_ref, p_ref


q_ref, p_ref = simulate_ref(Nt_ref)

# inference: 1 model call = L_STEPS*dt, so n_steps = Nt_ref // L_STEPS to cover Tf
Nt_infer = Nt_ref // L_STEPS  # same for all (L is identical)


# %% Infrastructure: domain, sampler, model, residual
class ModelEmpty(AbstractPhysicalModel):
    """Physical model with no physics residuals; only data residuals are added."""

    def __init__(self, main_domain):
        super().__init__(main_domain=main_domain)
        self.physical_residuals: PHYSICAL_RESIDUALS_TYPE = {}


class FlowDataResidual(DataResidual):
    """Size-2 data residual whose value is the first variable it receives."""

    def __init__(self, model_type: str = "x_mu"):
        super().__init__(size=2, model_type=model_type)

    def construct_residual(self, *vars: PARAM_FUNC_TYPE) -> PARAM_FUNC_TYPE:
        # presumably vars[0] is the flow prediction compared against the data
        # targets — TODO confirm against DataResidual.
        return vars[0]  # forward flow variable


domain_phase_space = [(-4.0, 4.0), (-4.0, 4.0)]
domain_mu = [(0.8, 1.8)]
dx = Square2D(domain_phase_space, is_main_domain=True)


def make_sampler(x, mu, y):
    """Build a sampler over (phase space × mu) carrying ``(x, mu, y)`` as ``"data"``."""
    return TensorizedSampler(
        [DomainSampler(dx), UniformParametricSampler(domain_mu)],
        bc=False,
        data_samplers={"data": DataSampler((x, mu, y))},
    )


sampler_mlp = make_sampler(x_data_mlp, mu_data_mlp, y_data_mlp)
sampler_euler = make_sampler(x_data_euler, mu_data_euler, y_data_euler)
sampler_symp = make_sampler(x_data_symp, mu_data_symp, y_data_symp)

pde_model = ModelEmpty(main_domain=dx)
pde_model.add_data_residual("data", FlowDataResidual(model_type="x_mu"))

x0_ref = jnp.array([q0ref, p0ref])
mu0_ref = jnp.array([mu_ref])
# %% Training helper


def train_flow(model, rollout, sampler, n_epochs, key):
    """Train one flow model on its data and roll it out from the reference IC.

    Args:
        model: network used as the single-call flow map inside
            ``FlowsApproximationSpace``.
        rollout: number of model calls composed in the training loss.
        sampler: sampler whose ``"data"`` entry holds this model's pairs.
        n_epochs: number of epochs passed to ``Projector.project``.
        key: JAX PRNG key handed to the projector.

    Returns:
        ``(key, projector, space, traj)``: the key returned by ``project``,
        the trained projector, its trained space, and the trajectory from
        ``rollout_trajectory`` over ``Nt_infer`` model calls starting at
        ``(x0_ref, mu0_ref)``.
    """
    space = FlowsApproximationSpace(
        dim=2, params_dim=1, model=model, model_type="x_mu", rollout=rollout
    )
    print(f" ndof: {space.compute_ndof()}")
    projector = Projector(pde_model, space, sampler)
    start = timeit.default_timer()
    key, projector = projector.project(
        key, space, n_epochs, N_COLLOC, matrix_regularization=ENG_REGULARIZATION
    )
    elapsed = timeit.default_timer() - start
    # The projector tracks the best loss seen, which is what is reported here.
    print(f" best loss: {projector.best_loss['total']:.4e} ({elapsed:.1f}s)")
    trained_space = projector.space
    # rollout_trajectory is called with the space passed explicitly.
    traj = trained_space.rollout_trajectory(
        trained_space, x0_ref, mu0_ref, Nt_infer
    )
    return key, projector, trained_space, traj


# %% Case 1: raw MLP
print("\n#1: raw MLP via FlowsApproximationSpace")
key, k1 = jax.random.split(key)
nn1 = MLP(in_size=3, out_size=2, hidden_sizes=[21, 21], key=k1)
key, pinn1, space1, traj1 = train_flow(
    nn1, N_ROLLOUT_MLP, sampler_mlp, N_EPOCHS, key
)

# %% Case 2: ExplicitEulerFlow
print("\n#2: ExplicitEulerFlow (MLP inside)")
key, k2 = jax.random.split(key)
inner_net2 = MLP(in_size=3, out_size=2, hidden_sizes=[21, 21], key=k2)
# NOTE(review): one model call spans L_STEPS * dt of physical time, yet the
# Euler flow is built with dt=dt. If ExplicitEulerFlow computes
# x + dt * f(x, mu) (assumed, not visible here), the inner net must absorb
# a factor L_STEPS — confirm whether dt=L_STEPS * dt was intended.
euler_model = ExplicitEulerFlow(dim=2, flownet=inner_net2, dt=dt, time_dependent=False)
key, pinn2, space2, traj2 = train_flow(
    euler_model, N_ROLLOUT_EULER, sampler_euler, N_EPOCHS, key
)

# %% Case 3: GSymplecticNet
print("\n#3: GSymplecticNet via FlowsApproximationSpace")
key, k3 = jax.random.split(key)
symp_net = GSymplecticNet(size=2, conditional_size=1, width=15, nb_layers=5, key=k3)
key, pinn3, space3, traj3 = train_flow(
    symp_net, N_ROLLOUT_SYMP, sampler_symp, N_EPOCHS_SYM, key
)

# %% Plots
lh1 = pinn1.losses.losses_history
lh2 = pinn2.losses.losses_history
lh3 = pinn3.losses.losses_history

# (label, trajectory, loss history, N_ROLLOUT) for each compared model.
results = [
    ("MLP", traj1, lh1, N_ROLLOUT_MLP),
    ("ExplicitEuler", traj2, lh2, N_ROLLOUT_EULER),
    ("GSymplecticNet", traj3, lh3, N_ROLLOUT_SYMP),
]

t_axis_ref = jnp.linspace(0, Nt_ref * dt, Nt_ref + 1)
# One point per model call; assumes rollout_trajectory returns x0 first
# (i.e. traj[0] is at t = 0) — TODO confirm.
t_axis = jnp.linspace(0, Nt_infer * L_STEPS * dt, Nt_infer + 1)
# The reference ends at t = Nt_ref * dt, so align it to the END of the axis;
# this is correct whether or not q_ref contains the initial state.
t_ref_plot = t_axis_ref[-len(q_ref) :]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
ax.plot(
    q_ref,
    p_ref,
    "k-",
    label="Verlet ref",
    linewidth=2,
    marker="x",
    markevery=len(q_ref) // 50,
    markersize=7,
    markeredgewidth=1.5,
)
for label, traj, _, _ in results:
    ax.plot(traj[:, 0], traj[:, 1], "--", label=label)
ax.set_xlabel("q")
ax.set_ylabel("p")
ax.set_title(f"Pendulum phase portrait (L={L_STEPS})")
ax.legend()

ax = axes[1]
for label, _, history, _ in results:
    ax.semilogy(history["total"], label=label)
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Training loss")
ax.legend()

plt.tight_layout()
plt.show()

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (label, traj, _, ro) in zip(axes, results):
    ax.plot(t_ref_plot, q_ref, "k-", label="ref")
    ax.plot(t_axis[: len(traj)], traj[:, 0], "--o", markersize=3, label=label)
    # Nominal end of the training time window (training starts are t < Nt_train*dt).
    ax.axvline(Nt_train * dt, color="gray", linestyle=":", label="train/extrap")
    ax.set_xlabel("t")
    ax.set_ylabel("q(t)")
    ax.set_title(f"q(t) — {label} (L={L_STEPS}, N={ro})")
    ax.legend()

plt.tight_layout()
plt.show()

# %%