r"""Learns the flow of a pendulum ODE using discrete flow networks (JAX version).

Dynamics given by the Hamiltonian:
    H = p²/2 + μq²/2 + μ*0.012*q³/3

Training data generated with the explicit Verlet scheme.

The following flows are compared:
    1. A raw MLP flow.
    2. An ExplicitEulerFlow (MLP inside).
    3. A GSymplecticNet (structure-preserving).
"""

# %%
import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.flow_approximation_spaces import (
    FlowsApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.data_sampler import DataSampler
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
    UniformParametricSampler,
)
from scimba_jax.nonlinear_approximation.networks.discrete_ode_nets import (
    ExplicitEulerFlow,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.networks.structure_preserving_nets.symplectic_nets import (
    GSymplecticNet,
)
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.abstract_physical_model import (
    PHYSICAL_RESIDUALS_TYPE,
    AbstractPhysicalModel,
)
from scimba_jax.physical_models.abstract_residuals import PARAM_FUNC_TYPE, DataResidual

# Run everything in double precision.
jax.config.update("jax_enable_x64", True)

# %% Hyperparameters
key = jax.random.PRNGKey(0)
dt = 0.02  # time step of the Verlet integrator (also handed to ExplicitEulerFlow)
N_simu = 500  # number of simulated training trajectories
Nt_train = 500  # number of Verlet steps per training trajectory
N_COLLOC = 1000  # 4th positional argument of Projector.project
N_EPOCHS = 150  # 3rd positional argument of Projector.project


# %% Pendulum dynamics (explicit Verlet scheme)
def dh_dq(q, mu):
    """Return ∂H/∂q = μ q + 0.012 μ q² for position ``q`` and parameter ``mu``."""
    return mu * q + mu * 0.012 * q**2


def dh_dp(p, _mu):
    """Return ∂H/∂p = p; the parameter is accepted only for a uniform signature."""
    return p


def verlet_step(q, p, mu):
    """Advance ``(q, p)`` by one step of size ``dt``.

    The update is: half step on ``p``, full step on ``q`` using the
    half-updated momentum, then a second half step on ``p`` evaluated at the
    new position.

    Returns:
        The pair ``(q_new, p_new)``.
    """
    p_half = p - dt / 2 * dh_dq(q, mu)
    q_new = q + dt * dh_dp(p_half, mu)
    p_new = p_half - dt / 2 * dh_dq(q_new, mu)
    return q_new, p_new


# %% Generate training data
key, k1, k2, k3 = jax.random.split(key, 4)
q0 = 1.0 + jax.random.uniform(k1, (N_simu,)) * 2.0  # q0 ∈ [1, 3]
p0 = 0.2 * jax.random.uniform(k2, (N_simu,)) * 2.0  # p0 ∈ [0, 0.4]
mu_simu = 0.8 + jax.random.uniform(k3, (N_simu,))  # mu ∈ [0.8, 1.8]


def simulate_one(q0_i, p0_i, mu_i):
    """Integrate one trajectory for ``Nt_train`` steps and return one-step pairs.

    Returns:
        ``(q_t, p_t, q_next, p_next)``: for every step, the state before the
        step and the state after it, stacked along the leading (time) axis.
    """

    def step_fn(carry, _):
        q, p = carry
        q_new, p_new = verlet_step(q, p, mu_i)
        # Emit both the current and the next state so that each scan
        # iteration yields one (input, target) training pair.
        return (q_new, p_new), (q, p, q_new, p_new)

    _, (q_t, p_t, q_next, p_next) = jax.lax.scan(
        step_fn, (q0_i, p0_i), None, length=Nt_train
    )
    return q_t, p_t, q_next, p_next


q_t, p_t, q_next, p_next = jax.vmap(simulate_one)(q0, p0, mu_simu)
# shapes: (N_simu, Nt_train)

# ravel() is row-major (trajectory after trajectory), which matches
# jnp.repeat(mu_simu, Nt_train): each mu value is repeated once per time step.
x_data = jnp.stack([q_t.ravel(), p_t.ravel()], axis=-1)  # (N, 2)
mu_data = jnp.repeat(mu_simu, Nt_train)[:, None]  # (N, 1)
y_data = jnp.stack([q_next.ravel(), p_next.ravel()], axis=-1)  # (N, 2)

print(f"Training data: {x_data.shape[0]} pairs")

# %% Reference trajectory for visualisation
Nt_ref = int(400 / dt)
# NOTE(review): mu_ref = 0.7 lies outside the training range [0.8, 1.8] sampled
# above (and outside domain_mu below) — confirm that evaluating the learned
# flows at an out-of-range parameter is intended.
q0ref, p0ref, mu_ref = 1.4, 0.12, 0.7


def simulate_ref(n_steps):
    """Integrate the reference trajectory for ``n_steps`` Verlet steps.

    Returns:
        ``(q_ref, p_ref)``, each of length ``n_steps``. Entry ``k`` is the
        state after ``k + 1`` steps; the initial state ``(q0ref, p0ref)`` is
        not included.
    """

    def step_fn(carry, _):
        q, p = carry
        q_new, p_new = verlet_step(q, p, mu_ref)
        return (q_new, p_new), (q_new, p_new)

    _, (q_ref, p_ref) = jax.lax.scan(step_fn, (q0ref, p0ref), None, length=n_steps)
    return q_ref, p_ref


q_ref, p_ref = simulate_ref(Nt_ref)


# %% Infrastructure: domain, sampler, model, residual
class ModelEmpty(AbstractPhysicalModel):
    """Physical model with no physical residuals (training is data-driven only)."""

    def __init__(self, main_domain):
        super().__init__(main_domain=main_domain)
        self.physical_residuals: PHYSICAL_RESIDUALS_TYPE = {}


class FlowDataResidual(DataResidual):
    """Data residual of size 2 whose residual function is its first variable."""

    def __init__(self, model_type: str = "x_mu"):
        super().__init__(size=2, model_type=model_type)

    def construct_residual(self, *vars: PARAM_FUNC_TYPE) -> PARAM_FUNC_TYPE:
        return vars[0]  # forward flow variable


domain_phase_space = [(-4.0, 4.0), (-4.0, 4.0)]
domain_mu = [(0.8, 1.8)]

dx = Square2D(domain_phase_space, is_main_domain=True)
data_sampler = DataSampler((x_data, mu_data, y_data))
sampler = TensorizedSampler(
    [DomainSampler(dx), UniformParametricSampler(domain_mu)],
    bc=False,
    data_samplers={"data": data_sampler},
)

pde_model = ModelEmpty(main_domain=dx)
pde_model.add_data_residual("data", FlowDataResidual(model_type="x_mu"))

x0_ref = jnp.array([q0ref, p0ref])
mu0_ref = jnp.array([mu_ref])


# %% Shared training / rollout routine
def train_and_rollout(key, model):
    """Fit ``model`` as a one-step flow and roll it out from the reference state.

    Wraps ``model`` in a ``FlowsApproximationSpace``, trains it with a
    ``Projector`` built on the module-level ``pde_model`` and ``sampler``,
    prints the number of degrees of freedom, the best total loss and the
    wall-clock training time, then rolls the trained space out for ``Nt_ref``
    steps from ``(x0_ref, mu0_ref)``.

    Args:
        key: PRNG key handed to ``Projector.project``.
        model: network used as the flow model.

    Returns:
        ``(key, projector, space, traj)``: the key and projector returned by
        ``project``, the space stored on that projector, and the rolled-out
        trajectory.
    """
    space = FlowsApproximationSpace(
        dim=2, params_dim=1, model=model, model_type="x_mu", rollout=1
    )
    print(f" ndof: {space.compute_ndof()}")

    projector = Projector(pde_model, space, sampler)
    start = timeit.default_timer()
    key, projector = projector.project(key, space, N_EPOCHS, N_COLLOC)
    print(
        f" final loss: {projector.best_loss['total']:.4e} ({timeit.default_timer() - start:.1f}s)"
    )

    # Use the space held by the returned projector, not the one built above.
    trained_space = projector.space
    traj = trained_space.rollout_trajectory(trained_space, x0_ref, mu0_ref, Nt_ref)
    return key, projector, trained_space, traj


# %% Case 1: raw MLP
print("\n#1: raw MLP via FlowsApproximationSpace")
key, k1 = jax.random.split(key)
nn1 = MLP(in_size=3, out_size=2, hidden_sizes=[21, 21], key=k1)
key, pinn1, space1, traj1 = train_and_rollout(key, nn1)

# %% Case 2: ExplicitEulerFlow
print("\n#2: ExplicitEulerFlow (MLP inside)")
key, k2 = jax.random.split(key)
inner_net2 = MLP(in_size=3, out_size=2, hidden_sizes=[21, 21], key=k2)
euler_model = ExplicitEulerFlow(dim=2, flownet=inner_net2, dt=dt, time_dependent=False)
key, pinn2, space2, traj2 = train_and_rollout(key, euler_model)

# %% Case 3: GSymplecticNet
print("\n#3: GSymplecticNet via FlowsApproximationSpace")
key, k3 = jax.random.split(key)
symp_net = GSymplecticNet(size=2, conditional_size=1, width=16, nb_layers=6, key=k3)
key, pinn3, space3, traj3 = train_and_rollout(key, symp_net)

# %% Plots
t_axis = jnp.linspace(0, Nt_ref * dt, Nt_ref + 1)
# simulate_ref does not emit the initial state, so q_ref[k] is the state at
# t = (k + 1) * dt: the reference samples start at t_axis[1], not t_axis[0].
t_ref = t_axis[1 : len(q_ref) + 1]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

lh1 = pinn1.losses.losses_history
lh2 = pinn2.losses.losses_history
lh3 = pinn3.losses.losses_history

ax = axes[0]
ax.plot(
    q_ref,
    p_ref,
    "k-",
    label="Verlet ref",
    linewidth=2,
    marker="x",
    markevery=len(q_ref) // 50,
    markersize=7,
    markeredgewidth=1.5,
)
ax.plot(traj1[:, 0], traj1[:, 1], "--", label="MLP")
ax.plot(traj2[:, 0], traj2[:, 1], "--", label="ExplicitEuler")
ax.plot(traj3[:, 0], traj3[:, 1], "--", label="GSymplecticNet")
ax.set_xlabel("q")
ax.set_ylabel("p")
ax.set_title("Pendulum phase portrait")
ax.legend()

ax = axes[1]
ax.semilogy(lh1["total"], label="MLP")
ax.semilogy(lh2["total"], label="ExplicitEuler")
ax.semilogy(lh3["total"], label="GSymplecticNet")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Training loss")
ax.legend()

plt.tight_layout()
plt.show()

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
labels = ["MLP", "ExplicitEuler", "GSymplecticNet"]
for ax, traj, label in zip(axes, [traj1, traj2, traj3], labels):
    ax.plot(t_ref, q_ref, "k-", label="ref")
    # NOTE(review): assumes the rollout's first entry is the initial state
    # (t = 0) — confirm against rollout_trajectory's return convention.
    ax.plot(t_axis[: len(traj)], traj[:, 0], "--", label=label)
    # Time horizon covered by each training trajectory.
    ax.axvline(Nt_train * dt, color="gray", linestyle=":", label="train/extrap")
    ax.set_xlabel("t")
    ax.set_ylabel("q(t)")
    ax.set_title(f"q(t) — {label}")
    ax.legend()

plt.tight_layout()
plt.show()

# %%