r"""Solve the 1D Allen–Cahn equation with a PINN.

.. math::

    \partial_t u - 0.0001\,\partial_{xx} u + 5u^3 - 5u &= 0
        \quad t \in (0,1),\; x \in (-1,1), \\
    u(0, x) &= x^2 \cos(\pi x), \\
    u(t,-1) &= u(t,1), \\
    \partial_x u(t,-1) &= \partial_x u(t,1).

Both periodic boundary conditions are imposed **weakly**: the physical model
is built with ``bc="periodic"``, the sampler draws boundary points
(``bc=True``) and the ``"boundary"`` loss term is weighted by 5. No periodic
input embedding is configured in this script — the network is a plain MLP on
:math:`[t, x]` — so :math:`u` and :math:`\partial_x u` are only approximately
2-periodic, to the extent the boundary residual is driven down in training.
The initial condition is likewise imposed weakly (``ic="weak"``, weight 5).
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.temporal_pde.allen_cahn import AllenCahnND
from scimba_jax.plots.plots_nd import plot_abstract_approx_space

jax.config.update("jax_enable_x64", True)

# ── Problem parameters ────────────────────────────────────────────────────────
EPSILON_SQ = 0.0001  # diffusion coefficient ε²
ALPHA = 5.0  # nonlinear coefficient (5u³ − 5u ≡ α(u³ − u))

domain_t = (0.0, 1.0)
domain_x = [(-1.0, 1.0)]

N_COLLOC = 8000
N_IC_COLLOC = 6000
N_EPOCHS = 400


# ── Initial condition u(0,x) = x² cos(πx) ───────────────────────────────────
def f_ic(x: jnp.ndarray) -> jnp.ndarray:
    """Return the initial profile u(0, x) = x² cos(πx), evaluated elementwise."""
    return x**2 * jnp.cos(jnp.pi * x)


# ── Domain & sampler ──────────────────────────────────────────────────────────
dx = Segment1D(domain_x[0], is_main_domain=True)
sampler = TensorizedSampler(
    [
        UniformTimeSampler(domain_t),
        DomainSampler(dx),
    ],
    model_type="t_x",
    bc=True,  # also draw boundary points: periodicity is penalised, not built in
    ic=True,
)

# ── Physical model ────────────────────────────────────────────────────────────
model = AllenCahnND(
    main_domain=dx,
    time_domain=domain_t,
    epsilon_sq=EPSILON_SQ,
    alpha=ALPHA,
    bc="periodic",
    ic="weak",
    f_ic_rhs=f_ic,
)

# ── Network & approximation space ─────────────────────────────────────────────
key = jax.random.PRNGKey(42)

# Plain MLP on the raw inputs [t, x] (in_size=2); there is no periodic
# embedding, so u(t,±1) and ∂_x u(t,±1) match only through the weak
# "boundary" loss term configured below.
nn = MLP(in_size=2, out_size=1, hidden_sizes=[18] * 3, key=key)
space = ApproximationSpace(
    {"x": 1},
    [(nn, "scalar", None)],
    model_type="t_x",
)

# ── Training ──────────────────────────────────────────────────────────────────
print("Training Allen–Cahn PINN with ENG …")
pinn = Projector(
    model,
    space,
    sampler,
    weights={"interior": [1.0], "boundary": [5.0], "ic interior": [5.0]},
)

start = timeit.default_timer()
# NOTE(review): assumes the positional counts are (epochs, n_interior,
# n_boundary, n_ic) — confirm against Projector.project.
key, pinn = pinn.project(key, space, N_EPOCHS, N_COLLOC, N_COLLOC, N_IC_COLLOC)
elapsed = timeit.default_timer() - start

print(f"Best loss : {pinn.best_loss['total']:.4e}")
print(f"Time      : {elapsed:.1f}s  ({N_EPOCHS} epochs)")

# ── Slice plots ───────────────────────────────────────────────────────────────
T0, T1 = domain_t
# Evenly spaced snapshots across the *whole* time interval, valid for any T0.
slice_times = [T0 + frac * (T1 - T0) for frac in (0.0, 0.25, 0.5, 0.75, 1.0)]
plot_abstract_approx_space(
    pinn.space,
    dx,
    time_domain=domain_t,
    time_values=slice_times,
    loss=pinn.losses,
    residual=pinn.model,
)

# ── Space–time heatmap ────────────────────────────────────────────────────────
N_PLOT = 300
x_min, x_max = domain_x[0]
x_vals = jnp.linspace(x_min, x_max, N_PLOT)
t_vals = jnp.linspace(domain_t[0], domain_t[1], N_PLOT)
# Default "xy" indexing: xx[i, j] = x_vals[j], tt[i, j] = t_vals[i],
# so row index follows t and column index follows x.
xx, tt = jnp.meshgrid(x_vals, t_vals)
x_flat = xx.reshape(-1, 1)
t_flat = tt.reshape(-1, 1)

vars_ = pinn.space.create_variables()
u_fn = vars_[0]
u_batched = jax.vmap(u_fn, in_axes=(None, 0, 0))
u_flat = u_batched(pinn.space, t_flat, x_flat)

u_grid = np.array(u_flat).reshape(N_PLOT, N_PLOT)  # (t, x)
u_plot = u_grid.T  # (x, t) — t on x-axis
T = np.array(t_vals)
X = np.array(x_vals)
vmin, vmax = u_plot.min(), u_plot.max()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# ── Filled contour + iso-contours ────────────────────────────────────────────
cf = axes[0].contourf(T, X, u_plot, levels=64, cmap="turbo", vmin=vmin, vmax=vmax)
axes[0].contour(T, X, u_plot, levels=12, colors="w", linewidths=0.5, alpha=0.5)
cbar = fig.colorbar(cf, ax=axes[0], pad=0.02)
cbar.set_label(r"$u(t, x)$", fontsize=11)
axes[0].set_xlabel(r"$t$", fontsize=12)
axes[0].set_ylabel(r"$x$", fontsize=12)
axes[0].set_title(
    rf"Allen–Cahn — $\varepsilon^2={EPSILON_SQ},\;\alpha={ALPHA}$",
    fontsize=12,
)

# ── Loss history ──────────────────────────────────────────────────────────────
colors = plt.cm.tab10.colors
# Flatten every history into (label, 1-D curve) pairs first, then colour the
# curves with one running index so curves of different labels don't reuse a
# colour (the palette only cycles once it is exhausted).
curves = []
for label, values in pinn.losses.losses_history.items():
    history = np.atleast_1d(np.asarray(values))
    if history.size == 0:
        continue  # nothing recorded for this term
    history = history.reshape(history.shape[0], -1)  # (epochs, n_components)
    if history.shape[1] == 1:
        curves.append((label, history[:, 0]))
    else:
        curves.extend(
            (f"{label}_{i}", history[:, i]) for i in range(history.shape[1])
        )

for idx, (label, curve) in enumerate(curves):
    axes[1].semilogy(
        curve,
        label=label,
        color=colors[idx % len(colors)],
        linewidth=1.5,
    )

axes[1].set_xlabel("epoch", fontsize=12)
axes[1].set_ylabel("loss", fontsize=12)
axes[1].set_title("Training loss", fontsize=12)
axes[1].legend(fontsize=9, framealpha=0.8)
axes[1].grid(True, which="both", alpha=0.3, linestyle="--")

plt.tight_layout()
plt.show()