r"""Solves the Korteweg–De Vries (KdV) equation in 1D using a PINN.

We consider the one-dimensional Korteweg–De Vries (KdV) equation, a
fundamental model for the dynamics of solitary waves (solitons):

.. math::

    \partial_t u + \eta u \partial_x u + \mu^2 \partial_{xxx} u &= 0
        \quad t \in (0, 1),\; x \in (-1, 1), \\
    u(0, x) &= \cos(\pi x), \\
    u(t, -1) &= u(t, 1),

where :math:`u: (0,1) \times (-1,1) \to \mathbb{R}` is the unknown scalar
field, :math:`\eta` is the nonlinear advection coefficient, and
:math:`\mu^2` is the dispersion coefficient.

The periodic boundary condition :math:`u(t,-1) = u(t,1)` is not imposed
through a loss term. Instead, the network uses a periodic embedding of
:math:`x` with period 2 (the domain length), so its output is intended to
be periodic in space by construction. The cosine initial condition is
imposed weakly, as a weighted penalty term.

The solution is trained with the ENG optimizer. It is visualized as time
slices, as a space–time heatmap, and through the training-loss history.
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.temporal_pde.kdv import KdV1D
from scimba_jax.plots.plots_nd import plot_abstract_approx_space

# Enable 64-bit (float64) arrays in JAX for the whole script.
jax.config.update("jax_enable_x64", True)

# ── Problem parameters ────────────────────────────────────────────────────────
ETA = 1.0  # nonlinear advection coefficient η
MU_SQ = 0.022**2.0  # dispersion coefficient µ² (µ = 0.022)

domain_t = (0.0, 1.0)  # time interval (t_min, t_max)
domain_x = [(-1.0, 1.0)]  # spatial interval(s); the 1D segment uses domain_x[0]

N_COLLOC = 6000  # collocation-point count passed to Projector.project
N_IC_COLLOC = 4000  # initial-condition point count passed to Projector.project
N_EPOCHS = 700  # number of training epochs
# ── Initial condition ─────────────────────────────────────────────────────────


def f_ic(x: jnp.ndarray) -> jnp.ndarray:
    """Return the initial profile u(0, x) = cos(πx), evaluated elementwise on x."""
    return jnp.cos(jnp.pi * x)


# ── Domain & sampler ──────────────────────────────────────────────────────────
dx = Segment1D(domain_x[0], is_main_domain=True)

sampler = TensorizedSampler(
    [
        UniformTimeSampler(domain_t),
        DomainSampler(dx),
    ],
    model_type="t_x",
    bc=False,  # no boundary samples: periodicity comes from the network embedding
    ic=True,
)

# ── Physical model ────────────────────────────────────────────────────────────
# No boundary loss is used (bc=False in the sampler, bc=None here). The
# periodic condition u(t,-1) = u(t,1) is NOT enforced by the initial condition
# or by the KdV dynamics: the PDE posed on a bounded interval needs boundary
# conditions. Periodicity is instead built into the network through its
# periodic embedding of x (embedding="periodic" on the MLP). A loss-based
# alternative would need a custom residual on paired points (t, -1) / (t, 1).
model = KdV1D(
    main_domain=dx,
    time_domain=domain_t,
    alpha=ETA,  # nonlinear advection coefficient η
    delta=MU_SQ,  # dispersion coefficient µ²
    bc=None,
    ic="weak",  # presumably a penalty term; weighted via "ic interior" in Projector
    f_ic_rhs=f_ic,
)

# ── Network & approximation space ─────────────────────────────────────────────
# PRNG key used for network initialisation and then threaded through training.
key = jax.random.PRNGKey(42)
# Periodic embedding on the spatial axis (axis 1 of the [t, x] input). The
# network receives the raw inputs (t, x), hence in_size=2; x is presumably
# mapped to Fourier-type features such as cos(πx), sin(πx) (period 2).
# NOTE(review): how n_periodic_features=4 translates into the embedding's
# output width is not visible here — confirm against the MLP implementation.
x_min, x_max = domain_x[0]
nn = MLP(
    in_size=2,
    out_size=1,
    hidden_sizes=[16, 16, 16],
    key=key,
    embedding="periodic",
    embedding_axes=(1,),  # only x is periodic, not t
    periods=(x_max - x_min,),  # period = domain length (= 2 here)
    n_periodic_features=4,
)

space = ApproximationSpace(
    {"x": 1},
    [(nn, "scalar", None)],
    model_type="t_x",
)

# ── Training ──────────────────────────────────────────────────────────────────
print("Training KdV PINN with ENG …")
pinn = Projector(
    model,
    space,
    sampler,
    optimizer="ENG",
    # The initial-condition term is weighted 5x relative to the PDE residual.
    weights={"interior": [1.0], "ic interior": [5.0]},
)

start = timeit.default_timer()
# NOTE(review): the positional sample counts are assumed to be
# (interior, boundary, initial condition) — confirm against Projector.project.
key, pinn = pinn.project(key, space, N_EPOCHS, N_COLLOC, N_COLLOC, N_IC_COLLOC)
elapsed = timeit.default_timer() - start

print(f"Best loss : {pinn.best_loss['total']:.4e}")
print(f"Time : {elapsed:.1f}s ({N_EPOCHS} epochs)")

# ── Slice plots at five evenly spaced times via the standard utility ─────────
t_min, t_max = domain_t
slice_times = [t_min + frac * (t_max - t_min) for frac in (0.0, 0.25, 0.5, 0.75, 1.0)]
plot_abstract_approx_space(
    pinn.space,
    dx,
    time_domain=domain_t,
    time_values=slice_times,
    loss=pinn.losses,
    residual=pinn.model,
)

# ── Space–time heatmap ────────────────────────────────────────────────────────
N_PLOT = 300
# Build the evaluation grid from the problem domains rather than hard-coding it.
x_vals = jnp.linspace(x_min, x_max, N_PLOT)
t_vals = jnp.linspace(t_min, t_max, N_PLOT)
xx, tt = jnp.meshgrid(x_vals, t_vals)  # rows = t, cols = x
x_flat = xx.reshape(-1, 1)
t_flat = tt.reshape(-1, 1)

vars_ = pinn.space.create_variables()
u_fn = vars_[0]
# Vectorise over the sample axis of (t, x); the space argument is broadcast.
u_batched = jax.vmap(u_fn, in_axes=(None, 0, 0))
u_flat = u_batched(pinn.space, t_flat, x_flat)
u_grid = np.array(u_flat).reshape(N_PLOT, N_PLOT)  # (t, x)

# Transpose so t runs along the horizontal axis and x along the vertical one.
u_plot = u_grid.T  # (x, t)
T = np.array(t_vals)
X = np.array(x_vals)
vmin, vmax = u_plot.min(), u_plot.max()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# ── Filled contour + iso-contour lines ───────────────────────────────────────
cf = axes[0].contourf(T, X, u_plot, levels=64, cmap="turbo", vmin=vmin, vmax=vmax)
axes[0].contour(T, X, u_plot, levels=12, colors="w", linewidths=0.5, alpha=0.5)
cbar = fig.colorbar(cf, ax=axes[0], pad=0.02)
cbar.set_label(r"$u(t, x)$", fontsize=11)
axes[0].set_xlabel(r"$t$", fontsize=12)
axes[0].set_ylabel(r"$x$", fontsize=12)
# Compact number formatting: MU_SQ = 0.022**2 may otherwise print with
# floating-point round-off digits in the title.
axes[0].set_title(rf"KdV — $\eta={ETA:g},\;\mu^2={MU_SQ:.3g}$", fontsize=12)

# ── Loss history ──────────────────────────────────────────────────────────────
colors = plt.cm.tab10.colors
# Running curve index so that every plotted curve (including each column of a
# multi-column loss) gets its own colour instead of colliding with the next label.
curve_idx = 0
for label, values in pinn.losses.losses_history.items():
    if values.ndim == 1 or values.shape[1] == 1:
        series = [(label, values.ravel())]
    else:
        series = [(f"{label}_{i}", values[:, i]) for i in range(values.shape[1])]
    for curve_label, curve_values in series:
        axes[1].semilogy(
            curve_values,
            label=curve_label,
            color=colors[curve_idx % len(colors)],
            linewidth=1.5,
        )
        curve_idx += 1

axes[1].set_xlabel("epoch", fontsize=12)
axes[1].set_ylabel("loss", fontsize=12)
axes[1].set_title("Training loss", fontsize=12)
axes[1].legend(fontsize=9, framealpha=0.8)
axes[1].grid(True, which="both", alpha=0.3, linestyle="--")

plt.tight_layout()
plt.show()