r"""Grey-Scott reaction-diffusion system in 2D with periodic BCs — time-marching.

.. math::

    \partial_t u &= \varepsilon_1 \Delta u + b_1(1 - u) - c_1 u v^2, \\
    \partial_t v &= \varepsilon_2 \Delta v - b_2 v + c_2 u v^2,

with :math:`t \in (0, T)` (:math:`T` is the end time of the last window),
:math:`(x, y) \in (-1, 1)^2`, periodic boundary conditions, and initial
conditions

.. math::

    u_0(x, y) &= 1 - \exp\!\bigl(-10\,((x+0.05)^2 + (y+0.02)^2)\bigr), \\
    v_0(x, y) &= \exp\!\bigl(-10\,((x-0.05)^2 + (y-0.02)^2)\bigr).

:math:`u_0` is a uniform background with a Gaussian hole, and :math:`v_0` is a
Gaussian spot slightly offset from that hole.

The time domain is split into ``N_WINDOWS`` equal intervals. Each interval has
its own PINN; the IC of window k+1 is the solution of window k at its final
time.

Periodic BCs are enforced **by construction** via a Fourier embedding on both
spatial axes: the MLP receives
:math:`[t, \cos(\pi x), \sin(\pi x), \cos(\pi y), \sin(\pi y)]` as input.

For comparison, a finite-difference reference solution (explicit RK4 in time,
5-point periodic Laplacian in space) is computed and plotted against the PINN.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.multi_time_projector import (
    MultiTimeProjector,
)
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.temporal_pde.grey_scott import GreyScottND

jax.config.update("jax_enable_x64", True)

# ── Problem parameters ────────────────────────────────────────────────────────
EPS1 = 0.2  # epsilon_1 = D_u, diffusion coefficient of u
EPS2 = 0.1  # epsilon_2 = D_v, diffusion coefficient of v
B1 = 40.0  # b_1: feed rate (= F in standard Grey-Scott)
B2 = 100.0  # b_2: kill + feed (= F + k)
C1 = 1000.0  # c_1: reaction coefficient for u
C2 = 1000.0  # c_2: reaction coefficient for v
# NOTE: N_COLLOC is not referenced below; interior collocation counts are
# given per window in the ``n_colloc`` argument of MultiTimeProjector.
N_COLLOC = 6000
N_IC_COLLOC = 4000  # IC collocation points, passed as ``n_ic_colloc``
# NOTE: N_EPOCHS is not referenced below; epochs are given per window in
# N_EPOCHS_PER_WINDOW.
N_EPOCHS = 800

# ── Spatial domain ────────────────────────────────────────────────────────────
domain_x = Square2D([(-1.0, 1.0), (-1.0, 1.0)], is_main_domain=True)


# ── Initial condition (window 0) ──────────────────────────────────────────────
def f_ic_analytical(xy: jnp.ndarray) -> jnp.ndarray:
    """Return the initial state ``[u0, v0]`` evaluated at ``xy``.

    ``xy[0]`` is x and ``xy[1]`` is y. They may be scalars or arrays of the
    same shape; the result stacks u0 and v0 along a new leading axis of size 2.

    u0 = 1 - exp(-10 ((x + 0.05)^2 + (y + 0.02)^2))   (background with a hole)
    v0 =     exp(-10 ((x - 0.05)^2 + (y - 0.02)^2))   (Gaussian spot)
    """
    x, y = xy[0], xy[1]
    u0 = 1.0 - jnp.exp(-10.0 * ((x + 0.05) ** 2 + (y + 0.02) ** 2))
    v0 = jnp.exp(-10.0 * ((x - 0.05) ** 2 + (y - 0.02) ** 2))
    return jnp.stack([u0, v0])


# ── Model factory (PDE params captured via closure) ───────────────────────────
def model_factory(t_start: float, t_end: float, f_ic_rhs):
    """Build the Grey-Scott model for the time window ``[t_start, t_end]``.

    The parameters are mapped onto the standard (F, k) form: F = b_1 and
    k = b_2 - b_1, so that F + k = b_2. No boundary condition is imposed
    (``bc=None``) because periodicity comes from the network embedding. The
    initial condition ``f_ic_rhs`` is imposed weakly (``ic="weak"``).
    """
    return GreyScottND(
        main_domain=domain_x,
        time_domain=(t_start, t_end),
        D_u=EPS1,
        D_v=EPS2,
        F=B1,
        k=B2 - B1,
        c_u=C1,
        c_v=C2,
        mode="single",
        bc=None,
        ic="weak",
        f_ic_rhs=f_ic_rhs,
    )


# ── Time windows and per-window training budgets ──────────────────────────────
t_windows = [
    (0.0, 0.1),
    (0.1, 0.2),
    (0.2, 0.3),
    (0.3, 0.4),
    (0.4, 0.5),
    (0.5, 0.6),
    (0.6, 0.7),
]
# Derived rather than hard-coded, so the window count can never disagree
# with the window list.
N_WINDOWS = len(t_windows)

N_EPOCHS_PER_WINDOW = [1000, 800, 700, 700, 700, 700, 700]
N_COLLOC_PER_WINDOW = [5000, 5000, 6000, 6000, 6000, 6000, 6000]
if len(N_EPOCHS_PER_WINDOW) != N_WINDOWS or len(N_COLLOC_PER_WINDOW) != N_WINDOWS:
    raise ValueError(
        "N_EPOCHS_PER_WINDOW and N_COLLOC_PER_WINDOW need one entry per time window"
    )

# ── Spaces and sampler (built once; sampler time domain updated per window) ───
key = jax.random.PRNGKey(42)
spaces = []
for _ in range(N_WINDOWS):
    key, subkey = jax.random.split(key)
    nn = MLP(
        in_size=3,
        out_size=2,
        hidden_sizes=[12] * 5,
        activation="silu",
        key=subkey,
        embedding="periodic",
        periods=(2.0, 2.0),  # the domain has length 2 in x and y -> frequency pi
        embedding_axes=(1, 2),  # embed the spatial inputs; input 0 (t) is left as-is
    )
    spaces.append(
        ApproximationSpace({"t": 1, "x": 2}, [(nn, "field", 2)], model_type="t_x")
    )

sampler = TensorizedSampler(
    [
        # Bounds are replaced for each window by MultiTimeProjector.
        UniformTimeSampler(t_windows[0]),
        DomainSampler(domain_x),
    ],
    model_type="t_x",
    bc=False,
    ic=True,
)


# ── Projector factory (optimizer settings and loss weights are defined here) ──
def projector_factory(model, space, sampler):
    """Build the Projector that trains ``space`` on ``model`` for one window.

    The PDE residual terms ("interior") are down-weighted. The IC terms
    ("ic interior") are up-weighted, with the weight on v twice that on u.
    """
    return Projector(
        model,
        space,
        sampler,
        weights={"interior": [0.01, 0.01], "ic interior": [10.0, 20.0]},
        matrix_regularization=1.0e-6,
    )


# ── Time-marching training ────────────────────────────────────────────────────
mtp = MultiTimeProjector(
    windows=t_windows,
    model_factory=model_factory,
    projector_factory=projector_factory,
    spaces=spaces,
    sampler=sampler,
    f_ic_0=f_ic_analytical,
    key=key,
    n_epochs=N_EPOCHS_PER_WINDOW,  # epochs per window (int or list)
    n_colloc=N_COLLOC_PER_WINDOW,  # collocation points per window (int or list)
    n_ic_colloc=N_IC_COLLOC,
)

# ── Plotting helpers ──────────────────────────────────────────────────────────
N_PLOT = 200
x_vals = jnp.linspace(-1.0, 1.0, N_PLOT)
y_vals = jnp.linspace(-1.0, 1.0, N_PLOT)
xx, yy = jnp.meshgrid(x_vals, y_vals)
xy_flat = jnp.stack([xx.reshape(-1), yy.reshape(-1)], axis=-1)
X = np.array(x_vals)
Y = np.array(y_vals)


def eval_at(space, t_val: float):
    """Evaluate ``space`` at time ``t_val`` on the N_PLOT x N_PLOT plot grid.

    Returns ``(u, v)`` as NumPy arrays of shape (N_PLOT, N_PLOT). Rows follow
    y and columns follow x, matching the default ``meshgrid`` ordering.
    The first entry returned by ``create_variables()`` is assumed to be the
    2-component (u, v) field; TODO confirm against scimba_jax.
    """
    uv_fn = space.create_variables()[0]
    # Batch over points only; the space (network parameters) is shared.
    uv_batched = jax.vmap(uv_fn, in_axes=(None, 0, 0))
    t_flat = jnp.full((N_PLOT * N_PLOT, 1), t_val)
    uv_flat = uv_batched(space, t_flat, xy_flat)
    u = np.array(uv_flat[:, 0]).reshape(N_PLOT, N_PLOT)
    v = np.array(uv_flat[:, 1]).reshape(N_PLOT, N_PLOT)
    return u, v


# ── Reference: finite-difference solver (RK4 + periodic Laplacian) ────────────
N_FD = 240
dx_fd = 2.0 / N_FD
# Cell-centred nodes: the point x = 1 is the periodic image of x = -1 and is excluded.
x_fd = jnp.linspace(-1.0 + dx_fd / 2, 1.0 - dx_fd / 2, N_FD)
xx_fd, yy_fd = jnp.meshgrid(x_fd, x_fd)
# Reuse the analytical IC so that the PINN and the FD reference start from the
# same formula.
u_fd0, v_fd0 = f_ic_analytical(jnp.stack([xx_fd, yy_fd]))


def lap(f):
    """5-point Laplacian with periodic wrap-around (via ``jnp.roll``)."""
    return (
        jnp.roll(f, 1, 0)
        + jnp.roll(f, -1, 0)
        + jnp.roll(f, 1, 1)
        + jnp.roll(f, -1, 1)
        - 4.0 * f
    ) / dx_fd**2


def gs_rhs_fd(u, v):
    """Right-hand side ``(du/dt, dv/dt)`` of the Grey-Scott system on the FD grid."""
    uv2 = u * v**2
    return (
        EPS1 * lap(u) + B1 * (1.0 - u) - C1 * uv2,
        EPS2 * lap(v) - B2 * v + C2 * uv2,
    )


@jax.jit
def rk4_step_fd(u, v, dt):
    """Advance ``(u, v)`` by one classical RK4 step of size ``dt``."""
    k1u, k1v = gs_rhs_fd(u, v)
    k2u, k2v = gs_rhs_fd(u + 0.5 * dt * k1u, v + 0.5 * dt * k1v)
    k3u, k3v = gs_rhs_fd(u + 0.5 * dt * k2u, v + 0.5 * dt * k2v)
    k4u, k4v = gs_rhs_fd(u + dt * k3u, v + dt * k3v)
    return (
        u + dt / 6.0 * (k1u + 2 * k2u + 2 * k3u + k4u),
        v + dt / 6.0 * (k1v + 2 * k2v + 2 * k3v + k4v),
    )


# NOTE: explicit scheme. The diffusive stability number is
# DT_FD * 8 * EPS1 / dx_fd**2 ≈ 2.3, just inside RK4's real-axis limit (≈ 2.78).
# Increasing EPS1 or N_FD may require a smaller DT_FD.
DT_FD = 1e-4
N_TIMES = 5
# Snapshot times: N_TIMES + 1 evenly spaced instants from 0 to the end of the last window.
t_snapshots = np.linspace(0.0, t_windows[-1][1], N_TIMES + 1)

print("Solving Grey-Scott with finite differences …")
u_fd, v_fd = u_fd0, v_fd0
fd_snaps = {}  # round(t, 8) -> (u, v) as NumPy arrays
fd_snaps[0.0] = (np.array(u_fd), np.array(v_fd))
t_now = 0.0
for t_target in t_snapshots[1:]:
    n_steps = int(round((t_target - t_now) / DT_FD))
    for _ in range(n_steps):
        u_fd, v_fd = rk4_step_fd(u_fd, v_fd, DT_FD)
    t_now = t_target
    fd_snaps[round(t_target, 8)] = (np.array(u_fd), np.array(v_fd))
print("FD done.")
X_fd = np.array(x_fd)

# ── PINN evaluation on the FD grid (used for pointwise errors) ────────────────
xy_flat_fd = jnp.stack([xx_fd.reshape(-1), yy_fd.reshape(-1)], axis=-1)


def eval_at_fd_grid(space, t_val: float):
    """Evaluate ``space`` at time ``t_val`` on the N_FD x N_FD finite-difference grid.

    Returns ``(u, v)`` as NumPy arrays of shape (N_FD, N_FD), laid out like
    the FD arrays, so that they can be subtracted from the reference directly.
    """
    uv_fn = space.create_variables()[0]
    uv_batched = jax.vmap(uv_fn, in_axes=(None, 0, 0))
    t_flat = jnp.full((N_FD * N_FD, 1), t_val)
    uv_flat = uv_batched(space, t_flat, xy_flat_fd)
    u = np.array(uv_flat[:, 0]).reshape(N_FD, N_FD)
    v = np.array(uv_flat[:, 1]).reshape(N_FD, N_FD)
    return u, v


# ── Plot PINN vs FD: 6 rows × (N_TIMES+1) columns ─────────────────────────────
n_cols = N_TIMES + 1
fig, axes = plt.subplots(6, n_cols, figsize=(4 * n_cols, 20))
row_labels = [
    r"$u$ PINN",
    r"$v$ PINN",
    r"$u$ FD",
    r"$v$ FD",
    r"$|u_\mathrm{PINN} - u_\mathrm{FD}|$",
    r"$|v_\mathrm{PINN} - v_\mathrm{FD}|$",
]
for i, t_val in enumerate(t_snapshots):
    t = float(t_val)
    space = mtp.which_space(t)  # look up once; used on both grids
    u_pinn, v_pinn = eval_at(space, t)
    u_ref, v_ref = fd_snaps[round(t_val, 8)]
    u_pinn_fd, v_pinn_fd = eval_at_fd_grid(space, t)
    err_u = np.abs(u_pinn_fd - u_ref)
    err_v = np.abs(v_pinn_fd - v_ref)
    for row, (grid, gx, cmap) in enumerate(
        [
            (u_pinn, X, "turbo"),
            (v_pinn, X, "turbo"),
            (u_ref, X_fd, "turbo"),
            (v_ref, X_fd, "turbo"),
            (err_u, X_fd, "turbo"),
            (err_v, X_fd, "turbo"),
        ]
    ):
        ax = axes[row, i]
        # The domain is square with identical x/y nodes, so gx serves both axes.
        im = ax.pcolormesh(gx, gx, grid, cmap=cmap, shading="auto")
        if row < 4:  # contour overlays only on solution fields, not on errors
            ax.contour(
                gx, gx, grid, levels=8, colors="white", linewidths=0.4, alpha=0.5
            )
        fig.colorbar(im, ax=ax, pad=0.02)
        if row == 0:
            ax.set_title(rf"$t={t_val:.3f}$", fontsize=11)
        ax.set_xlabel(r"$x$")
        ax.set_aspect("equal")
for row, label in enumerate(row_labels):
    axes[row, 0].set_ylabel(label)

# Only claim c_1 = c_2 in the title when it is actually true.
c_label = rf"$c_1=c_2={C1:.0f}$" if C1 == C2 else rf"$c_1={C1:.0f}$, $c_2={C2:.0f}$"
fig.suptitle(
    rf"Grey-Scott 2D — PINN ({N_WINDOWS} fenêtres) vs FD (RK4, N={N_FD}) — "
    rf"$\varepsilon_1={EPS1}$, $\varepsilon_2={EPS2}$, "
    rf"$b_1={B1}$, $b_2={B2}$, " + c_label,
    fontsize=10,
)
plt.tight_layout()
plt.show()

# ── Plot loss history per window ──────────────────────────────────────────────
fig_loss, axes_loss = plt.subplots(
    1, N_WINDOWS, figsize=(5 * N_WINDOWS, 4), sharey=True, squeeze=False
)
axes_loss = axes_loss[0]
for w, pinn in enumerate(mtp.projectors):
    ax = axes_loss[w]
    history = pinn.losses.losses_history
    for loss_name, values in history.items():
        arr = np.array(values)
        if arr.ndim == 1 or arr.shape[1] == 1:
            lw = 2 if loss_name == "total" else 1
            ax.semilogy(arr, label=loss_name, linewidth=lw)
        else:
            # Multi-component loss: one curve per component.
            for j in range(arr.shape[1]):
                ax.semilogy(arr[:, j], label=f"{loss_name}_{j}", linewidth=1)
    ws, we = mtp.windows[w]
    ax.set_title(rf"Window {w + 1} $t \in [{ws:.2f}, {we:.2f}]$")
    ax.set_xlabel("Epoch")
    if w == 0:
        ax.set_ylabel("Loss")
    ax.legend(fontsize=8)
    ax.grid(True, which="both", alpha=0.3)
plt.tight_layout()
plt.show()

# ── Continuity at window junctions: t_w⁻ vs t_w⁺ ──────────────────────────────
if N_WINDOWS > 1:
    n_junctions = N_WINDOWS - 1
    fig_junc, axes_junc = plt.subplots(
        4, n_junctions, figsize=(5 * n_junctions, 16), squeeze=False
    )
    for j in range(n_junctions):
        t_junc = mtp.windows[j][1]  # end of window j = start of window j+1
        # Both windows contain t_junc, so both PINNs are evaluated exactly
        # there. Shifting them to t_junc -/+ eps would add about
        # 2 * eps * |du/dt| to the measured gap, which is not a discontinuity.
        u_m, v_m = eval_at(mtp.spaces[j], t_junc)
        u_p, v_p = eval_at(mtp.spaces[j + 1], t_junc)
        for row, (grid, label) in enumerate(
            [
                (u_m, rf"$u(t={t_junc:.3f}^-)$"),
                (u_p, rf"$u(t={t_junc:.3f}^+)$"),
                (v_m, rf"$v(t={t_junc:.3f}^-)$"),
                (v_p, rf"$v(t={t_junc:.3f}^+)$"),
            ]
        ):
            ax = axes_junc[row, j]
            im = ax.pcolormesh(X, Y, grid, cmap="turbo", shading="auto")
            ax.contour(X, Y, grid, levels=6, colors="white", linewidths=0.4, alpha=0.5)
            fig_junc.colorbar(im, ax=ax, pad=0.02)
            ax.set_title(label, fontsize=11)
            ax.set_xlabel(r"$x$")
            ax.set_ylabel(r"$y$" if j == 0 else "")
            ax.set_aspect("equal")
        diff_u = np.abs(u_p - u_m)
        diff_v = np.abs(v_p - v_m)
        print(
            f"Jonction t={t_junc:.3f} — |u⁺−u⁻| max={diff_u.max():.2e} "
            f"|v⁺−v⁻| max={diff_v.max():.2e}"
        )
    fig_junc.suptitle("Continuité aux jonctions (t⁻ vs t⁺)", fontsize=12)
    plt.tight_layout()
    plt.show()