r"""Poisson equation −Δu = 1 with homogeneous Dirichlet BCs on 6 fancy domains.

Solved with a PINN (weak BC formulation) on each of the level-set domains
defined in :mod:`scimba_jax.domains.meshless_domains.examples_domains`:
Flower, Heart, Squircle, Star, Annulus, Batman.

The final figure shows all 6 solutions side by side.
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scimba_jax.domains.meshless_domains.examples_domains import (
    Annulus2D,
    Batman2D,
    Flower2D,
    Heart2D,
    Squircle2D,
    Star2D,
)
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.elliptic_pde.laplacians import LaplacianDirichletND

# ── Configuration ──────────────────────────────────────────────────────────────
# N_COLLOC and N_BC_COLLOC are forwarded to ``Projector.project``; presumably the
# numbers of interior and boundary collocation points — confirm against the library.
N_COLLOC = 5000
N_BC_COLLOC = 4000
HIDDEN = [24, 24]  # passed as ``hidden_sizes`` to the MLP
W_BC = 100.0  # weight of the "boundary" entry in the loss weights

# Epochs per domain (same order as DOMAINS); None → N_EPOCHS_DEFAULT.
N_EPOCHS_DEFAULT = 500  # used wherever N_EPOCHS_LIST holds None
N_EPOCHS_LIST = [50, 80, 200, 500, 600, 1000]


def f_rhs(xy: jnp.ndarray) -> jnp.ndarray:
    """Return the uniform source term f = 1 of the Poisson problem.

    Args:
        xy: Evaluation point; unused because the source is constant.

    Returns:
        The one-element array ``[1.0]``.
    """
    return jnp.array([1.0])


# ── Domains to solve on ────────────────────────────────────────────────────────
DOMAINS = [
    ("Squircle", Squircle2D(a=1.0, b=1.0, p=4.0, is_main_domain=True)),
    ("Annulus", Annulus2D(r_inner=0.4, r_outer=1.0, is_main_domain=True)),
    ("Flower", Flower2D(R=1.0, a=0.3, n=3, is_main_domain=True)),
    ("Heart", Heart2D(R=1.0, p=1.0, is_main_domain=True)),
    ("Star", Star2D(R=1.0, a=0.5, n=3, is_main_domain=True)),
    ("Batman", Batman2D(is_main_domain=True)),
]

# ── Training ───────────────────────────────────────────────────────────────────
key = jax.random.PRNGKey(0)
results = []  # list of (name, domain, trained_pinn)

# Raise explicitly rather than `assert`: assertions are stripped under
# `python -O`, after which `zip` below would silently drop the unmatched entries.
if len(N_EPOCHS_LIST) != len(DOMAINS):
    raise ValueError("N_EPOCHS_LIST doit avoir autant d'entrées que DOMAINS")

for (name, domain), n_ep in zip(DOMAINS, N_EPOCHS_LIST):
    n_ep = n_ep if n_ep is not None else N_EPOCHS_DEFAULT
    print(f"\n── {name} ({n_ep} epochs) ──────────────────────────────")

    # Poisson model with the Dirichlet condition imposed weakly (bc="weak").
    model = LaplacianDirichletND(
        domain,
        f_rhs=f_rhs,
        bc="weak",
        model_type="x",
    )
    weights = {"interior": [1.0], "boundary": [W_BC]}
    sampler = TensorizedSampler([DomainSampler(domain)], bc=True)

    # Fresh subkey per domain so every network gets its own initialisation.
    key, subkey = jax.random.split(key)
    nn = MLP(in_size=2, out_size=1, hidden_sizes=HIDDEN, key=subkey)
    space = ApproximationSpace({"x": 2}, [(nn, "scalar", None)], model_type="x")
    pinn = Projector(model, space, sampler, weights=weights, matrix_regularization=5e-7)

    # `project` returns the advanced PRNG key together with the trained projector.
    t0 = timeit.default_timer()
    key, pinn = pinn.project(key, space, n_ep, N_COLLOC, N_BC_COLLOC)
    dt = timeit.default_timer() - t0
    print(f" best loss = {pinn.best_loss['total']:.3e} | {dt:.1f}s")

    results.append((name, domain, pinn))

# ── Plots ──────────────────────────────────────────────────────────────────────
N_GRID = 200
n_domains = len(results)
ncols = 3
# Ceiling division: enough rows for every result, and at least one row so the
# figure can still be created when there is nothing to plot.
nrows = max(1, -(-n_domains // ncols))
# squeeze=False keeps `axes` a 2-D array whatever the grid shape.
fig, axes = plt.subplots(
    nrows, ncols, figsize=(5 * ncols, 4.5 * nrows), squeeze=False
)
axes_flat = axes.ravel()

for ax, (name, domain, pinn) in zip(axes_flat, results):
    # Regular N_GRID × N_GRID grid over the domain's bounding box.
    (x0, x1), (y0, y1) = domain.bounds[0], domain.bounds[1]
    gx = np.linspace(float(x0), float(x1), N_GRID)
    gy = np.linspace(float(y0), float(y1), N_GRID)
    GX, GY = np.meshgrid(gx, gy)
    pts = jnp.array(np.stack([GX.ravel(), GY.ravel()], axis=1))

    # Evaluate the trained approximation u on every grid point (converted to a
    # NumPy array once, then reused for both the colour map and the contours).
    u_var = pinn.space.create_variables()[0]
    u_vals = np.asarray(
        jax.device_get(u_var.vmap_on_physical_variables()(pinn.space, pts))
    ).reshape(N_GRID, N_GRID)

    # Mask: grid points outside the domain become NaN.
    inside = np.asarray(jax.device_get(domain.is_inside(pts))).reshape(
        N_GRID, N_GRID
    )
    u_masked = np.where(inside, u_vals, np.nan)

    # Colour scale pinned to [0, max u]. When that range is unusable (no grid
    # point inside the domain, or a negative maximum) fall back to matplotlib's
    # autoscaling, because Normalize rejects vmin > vmax at draw time.
    peak = float(np.nanmax(u_masked)) if inside.any() else float("nan")
    if np.isfinite(peak) and peak >= 0.0:
        vmin, vmax = 0.0, peak
    else:
        vmin = vmax = None

    im = ax.pcolormesh(
        GX, GY, u_masked, cmap="inferno", shading="gouraud", vmin=vmin, vmax=vmax
    )
    # Contours are drawn on the field padded with 0 (the Dirichlet value)
    # outside the domain rather than on the NaN-masked field.
    ax.contour(
        GX,
        GY,
        np.where(inside, u_vals, 0.0),
        levels=10,
        colors="w",
        linewidths=0.5,
        alpha=0.6,
    )
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    ax.set_title(f"{name}\nloss={pinn.best_loss['total']:.2e}", fontsize=10)
    ax.set_aspect("equal")
    ax.axis("off")

# Hide the axes left empty when the grid has more cells than results.
for ax in axes_flat[n_domains:]:
    ax.set_visible(False)

# The domain count comes from the results so the title cannot drift from the data.
fig.suptitle(
    rf"$-\Delta u = 1$, Dirichlet homogène faible — {n_domains} domaines",
    fontsize=13,
)
plt.tight_layout()

OUTPUT_PNG = "laplacian_example_domains.png"
plt.savefig(OUTPUT_PNG, dpi=150, bbox_inches="tight")
plt.show()
print(f"Saved: {OUTPUT_PNG}")