r"""Simplified Grad-Shafranov PINN on a D-shaped tokamak cross-section.

Equation:

.. math::

    \Delta^* \psi = -(k_1^2 R^2 + k_2^2)(1 + \psi_n) \quad \text{in } \Omega

    \psi = 0 \quad \text{on } \partial\Omega

where :math:`\Delta^* \psi = \partial_{RR}\psi + \partial_{ZZ}\psi - \tfrac{1}{R}\partial_R\psi`
and :math:`\psi_n = (\psi - \psi_\mathrm{axis}) / (0 - \psi_\mathrm{axis})`.

The domain is a :class:`~scimba_jax.domains.meshless_domains.domains_2d.Disk2D`
with a D-shaped tokamak mapping (ported from Nicolas Pailliez's
``new_scimba_nicolas`` branch).

Training strategy (two phases):

1. **Adam** on a linearised problem (ψ_n = 0, i.e. the pure Grad-Shafranov
   operator with a ψ-independent source), to get a first approximation.
2. **Adam** with the full nonlinear source, where ψ_axis is recomputed at each
   step via :meth:`~GradShafranov2D.pre_computation_without_diff` and frozen
   (``stop_gradient``).
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.domain_mapping import DomainMapping
from scimba_jax.domains.meshless_domains.domains_2d import Disk2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.elliptic_pde.grad_shafranov import GradShafranov2D
from scimba_jax.plots.plots_nd import plot_abstract_approx_spaces

# ── Parameters ─────────────────────────────────────────────────────────────────
K1 = 1.05
K2 = 1.0
W_BC = 100.0  # weight of the "boundary" loss term relative to "interior"
N_COLLOC = 4_000
N_BC_COLLOC = 2_000
N_EPOCHS_1 = 200  # phase 1: ψ-independent source
N_EPOCHS_2 = 200  # phase 2: nonlinear source with frozen ψ_axis
LR_1 = 2e-2  # Adam learning rate for phase 1
HIDDEN = [20, 20]
_FIG_PATH = "grad_shafranov_2d.png"

# ── D-shape tokamak mapping ────────────────────────────────────────────────────
_shape = {
    "amin": 0.6,  # radius of the reference Disk2D below
    "R_axis": 0.93,
    "Z_axis": 0.0,
    "GS_shift": 0.02,  # the reference disk is centred at R_axis - GS_shift
    "ellip": 1.8,  # vertical stretch factor applied to Z
    "rect_u": -0.2,  # upper half-plane shape coefficients
    "tria_u": 0.5,
    "rect_l": -0.2,  # lower half-plane shape coefficients
    "tria_l": 0.5,
}


def _mapping_tokamak(x: jnp.ndarray) -> jnp.ndarray:
    """D-shaped tokamak mapping: reference disk → (R, Z) cross-section.

    The reference disk is centred at ``(R_axis - GS_shift, Z_axis)`` (its
    radius ``amin`` is set on the :class:`Disk2D` below).  Writing a point in
    polar form ``(r, theta)`` about that centre, the image is::

        R = r * cos(theta + tria * sin(theta) - rect * sin(2 * theta)) + R_ax
        Z = ellip * r * sin(theta) + Z_axis

    using ``(tria_u, rect_u)`` in the upper half-plane (``theta >= 0``) and
    ``(tria_l, rect_l)`` in the lower half-plane.

    Called pointwise (x has shape (2,)) by DomainMapping's vmap.

    Args:
        x: A single point ``(x1, x2)`` of the reference disk.

    Returns:
        The mapped point ``[R, Z]`` as a length-2 array.
    """
    r_ax = _shape["R_axis"] - _shape["GS_shift"]
    z0 = _shape["Z_axis"]
    x1, x2 = x[0], x[1]  # single point, NOT x[:, 0]

    theta = jnp.arctan2(x2 - z0, x1 - r_ax)  # in [-pi, pi]
    r = jnp.sqrt((x1 - r_ax) ** 2 + (x2 - z0) ** 2)
    # NOTE(review): sqrt/arctan2 have singular derivatives at r = 0, so the
    # Jacobian of this map is undefined exactly at the disk centre.

    def _shaped_R(tria: float, rect: float) -> jnp.ndarray:
        """R coordinate for one set of (triangularity, rectangularity) terms."""
        angle = theta + tria * jnp.sin(theta) - rect * jnp.sin(2 * theta)
        return r * jnp.cos(angle) + r_ax

    R_upper = _shaped_R(_shape["tria_u"], _shape["rect_u"])
    R_lower = _shaped_R(_shape["tria_l"], _shape["rect_l"])
    # arctan2 returns theta in [-pi, pi], so the upper half-plane is theta >= 0
    # (a ``theta < pi`` test would never select the lower-half coefficients).
    # At theta = 0 and theta = +-pi both branches coincide, so the tie is safe.
    R = jnp.where(theta >= 0.0, R_upper, R_lower)
    Z = _shape["ellip"] * r * jnp.sin(theta) + z0
    return jnp.stack([R, Z])


mapping = DomainMapping(2, 2, _mapping_tokamak)

# ── Domain ─────────────────────────────────────────────────────────────────────
key = jax.random.PRNGKey(0)

_disk_center = [_shape["R_axis"] - _shape["GS_shift"], _shape["Z_axis"]]
domain = Disk2D(
    center=_disk_center,
    radius=_shape["amin"],
    is_main_domain=True,
)

# Bounding box of the mapped domain, derived from ``_shape`` so it always
# contains it: with r <= amin and |cos|, |sin| <= 1 we get
# |R - R_ax| <= amin and |Z - Z_axis| <= ellip * amin.  A small pad keeps the
# extreme points (e.g. the outer midplane at R_ax + amin) strictly inside.
_BOUNDS_PAD = 0.05
_half_width_R = _shape["amin"] + _BOUNDS_PAD
_half_height_Z = _shape["ellip"] * _shape["amin"] + _BOUNDS_PAD
_bounds_postmap = [
    (_disk_center[0] - _half_width_R, _disk_center[0] + _half_width_R),
    (_disk_center[1] - _half_height_Z, _disk_center[1] + _half_height_Z),
]
domain.set_mapping(mapping, bounds_postmap=_bounds_postmap)

sampler = TensorizedSampler([DomainSampler(domain)], bc=True)

# ── Network & space ────────────────────────────────────────────────────────────
key, subkey = jax.random.split(key)
nn = MLP(in_size=2, out_size=1, hidden_sizes=HIDDEN, key=subkey)
space = ApproximationSpace({"x": 2}, [(nn, "scalar", None)], model_type="x")

# ── Phase 1: ψ-independent source (ψ_n = 0) ───────────────────────────────────
print("\n── Phase 1 : Adam, source constante (ψ_n = 0) ──────────────")

model_phase1 = GradShafranov2D(
    domain, k1=K1, k2=K2, nonlinear=False, bc="weak", model_type="x"
)
# Phase 1: ψ_axis is not yet available → pre_computation_without_diff returns {}
# → the residual uses source = (k1²R² + k2²)·1 (no ψ_n term).

pinn1 = Projector(
    model_phase1,
    space,
    sampler,
    weights={"interior": [1.0], "boundary": [W_BC]},
    optimizer="Adam",
    learning_rate=LR_1,
)

key, sample_dict = sampler.sample(key, N_COLLOC, N_BC_COLLOC)
print(f"  initial loss: {pinn1.evaluate_loss(space, sample_dict):.4e}")

t0 = timeit.default_timer()
key, pinn1 = pinn1.project(key, space, N_EPOCHS_1, N_COLLOC, N_BC_COLLOC)
print(
    f"  best loss: {pinn1.best_loss['total']:.4e}  |  {timeit.default_timer() - t0:.1f}s"
)

space_phase1 = pinn1.space

# ── Phase 2: nonlinear source with frozen ψ_axis ──────────────────────────────
print("\n── Phase 2 : Adam, source nonlinéaire (ψ_axis gelé) ──────────")

model_phase2 = GradShafranov2D(
    domain, k1=K1, k2=K2, nonlinear=True, bc="weak", model_type="x"
)

# The optimizer is passed explicitly so phase 2 really is Adam, as the module
# docstring and the banner above state; the learning rate is Projector's
# default (none is passed here).
pinn2 = Projector(
    model_phase2,
    space_phase1,
    sampler,
    weights={"interior": [1.0], "boundary": [W_BC]},
    optimizer="Adam",
)

key, sample_dict = sampler.sample(key, N_COLLOC, N_BC_COLLOC)
print(f"  initial loss: {pinn2.evaluate_loss(space_phase1, sample_dict):.4e}")

t0 = timeit.default_timer()
key, pinn2 = pinn2.project(key, space_phase1, N_EPOCHS_2, N_COLLOC, N_BC_COLLOC)
print(
    f"  best loss: {pinn2.best_loss['total']:.4e}  |  {timeit.default_timer() - t0:.1f}s"
)

# ── Plots via scimba ───────────────────────────────────────────────────────────
plot_abstract_approx_spaces(
    (pinn1.space, pinn2.space),
    domain,
    draw_contours=True,
    n_drawn_contours=20,
    loss=(pinn1.losses, pinn2.losses),
)
plt.savefig(_FIG_PATH, dpi=150, bbox_inches="tight")
# Report the saved file before plt.show(), which may block until the window
# is closed.
print(f"Saved: {_FIG_PATH}")
plt.show()