r"""Visualisation du sampling de domaines tokamak (2D et 3D). Ce script illustre les trois samplers de :mod:`scimba_jax.domains.tokamak.tokamak_samplers`: * :class:`~scimba_jax.domains.tokamak.TokamakSampler2D` — coupe poloidale (R, Z) * :class:`~scimba_jax.domains.tokamak.TokamakSampler3DBox` — volume toroïdal (R, Z, φ) par rejet dans une boîte 3D * :class:`~scimba_jax.domains.tokamak.TokamakSampler3DFrom2D` — volume toroïdal (R, Z, φ) avec jacobien toroïdal correct Les données EQDSK réelles pour MAST et JET peuvent être fournies via ``EQDSK_FILES`` ci-dessous. Si un fichier est manquant le script bascule sur une géométrie synthétique (ellipse) afin de pouvoir tourner sans données. Fichiers attendus (voir ``src/scimba_jax/domains/tokamak/data/README.md``):: data/g_p49320_t0.60000 ← MAST data/eqdsk_mast_4.dat ← MAST (alternate) data/eqdsk_jet_compare.dat ← JET Usage:: python examples/examples_jax/geometry/tokamak_sampling_plot.py """ from __future__ import annotations from pathlib import Path import jax import matplotlib.pyplot as plt import numpy as np jax.config.update("jax_enable_x64", True) from scimba_jax.domains.tokamak import ( # noqa: E402 TokamakSampler2D, TokamakSampler3DBox, TokamakSampler3DFrom2D, read_eqdsk_wall, ) # ── Configuration ────────────────────────────────────────────────────────────── N_INT = 5_000 # interior points per plot N_BC = 800 # boundary points per plot KEY = jax.random.PRNGKey(0) _DATA = ( Path(__file__).parent.parent.parent.parent / "src/scimba_jax/domains/tokamak/data" ) # All available EQDSK files. Adjust names/paths as needed. 
EQDSK_FILES: dict[str, Path] = {
    "MAST (g_p49320, t=0.60s)": _DATA / "g_p49320_t0.60000",
    "MAST (g_p44849, t=0.20s)": _DATA / "g_p44849_t0.20000",
    "MAST (eqdsk_mast_4)": _DATA / "eqdsk_mast_4.dat",
    "MAST-U (eqdsk_mast-u)": _DATA / "eqdsk_mast-u.dat",
    "JET (eqdsk_jet_compare)": _DATA / "eqdsk_jet_compare.dat",
}


# ── Synthetic wall fallback ────────────────────────────────────────────────────
def _synthetic_wall(
    R0: float = 3.0,
    a: float = 0.9,
    kappa: float = 1.7,
    delta: float = 0.3,
    n: int = 300,
) -> tuple[np.ndarray, np.ndarray]:
    """Return a D-shaped tokamak cross-section (MHD parametrisation).

    The contour is ``R = R0 + a cos(t + delta sin t)``,
    ``Z = kappa a sin t`` for ``n`` angles ``t`` evenly spaced on [0, 2π).

    Args:
        R0: Offset added to every R coordinate.
        a: Amplitude of the cosine term (and, times ``kappa``, of the sine).
        kappa: Vertical stretch factor applied to Z.
        delta: Weight of the ``sin t`` perturbation inside the cosine.
        n: Number of contour points.

    Returns:
        The ``(R, Z)`` coordinate arrays, each of length ``n``.
    """
    # endpoint=False: the first point is not repeated at t = 2π.
    theta = np.linspace(0, 2 * np.pi, n, endpoint=False)
    R = R0 + a * np.cos(theta + delta * np.sin(theta))
    Z = kappa * a * np.sin(theta)
    return R, Z


# ── Load walls ─────────────────────────────────────────────────────────────────
def _load_walls(
    files: dict[str, Path],
) -> dict[str, tuple[np.ndarray, np.ndarray]]:
    """Read every existing EQDSK file in *files* and report the outcome.

    Missing files are skipped and unreadable files only produce a warning, so
    one bad file never prevents the script from running.

    Args:
        files: Mapping from display name to EQDSK file path.

    Returns:
        Mapping from display name to the ``(R, Z)`` pair returned by
        ``read_eqdsk_wall`` for each file that could be read.
    """
    loaded: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    for name, path in files.items():
        if not path.exists():
            print(f"[SKIP] {path.name} not found")
            continue
        try:
            R_w, Z_w = read_eqdsk_wall(str(path))
        except Exception as exc:
            # Deliberately broad: this is a best-effort load at the script
            # boundary and the failure is reported, not hidden.
            print(f"[WARN] Could not read {path.name}: {exc}")
        else:
            loaded[name] = (R_w, Z_w)
            print(f"[OK] Loaded {name} from {path.name}")
    return loaded


walls: dict[str, tuple[np.ndarray, np.ndarray]] = _load_walls(EQDSK_FILES)

# Always add synthetic fallbacks so the script is always runnable
walls["D-shape (synthetic)"] = _synthetic_wall(R0=3.0, a=0.9, kappa=1.7, delta=0.3)
walls["Circular (synthetic)"] = _synthetic_wall(R0=2.0, a=0.5, kappa=1.0, delta=0.0)

print(f"\nWalls loaded: {list(walls)}\n")


def _closed_contour(
    R_wall: np.ndarray, Z_wall: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Return the wall polyline with its first point repeated at the end."""
    return np.append(R_wall, R_wall[0]), np.append(Z_wall, Z_wall[0])


# ── 2D plots ───────────────────────────────────────────────────────────────────
def plot_2d(walls: dict, key: jax.Array) -> jax.Array:
    """Plot 2D poloidal cross-section sampling for every wall.

    One subplot per wall (at most three per row) showing the interior
    samples, the boundary samples, the wall outline and a subsample of the
    boundary normals.

    Args:
        walls: Mapping from display name to ``(R_wall, Z_wall)`` arrays.
        key: PRNG key handed to the samplers.

    Returns:
        The key returned by the last sampler call, or *key* unchanged when
        *walls* is empty (no figure is created in that case).
    """
    if not walls:
        # Nothing to draw; also avoids a division by zero on ncols below.
        return key

    n = len(walls)
    ncols = min(3, n)
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)
    # squeeze=False always yields a 2D array of axes, whatever the grid size.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False
    )
    fig.suptitle("Sampling 2D — coupe poloidale (R, Z)", fontsize=13)
    axes_flat = axes.ravel()

    for idx, (ax, (name, (R_wall, Z_wall))) in enumerate(
        zip(axes_flat, walls.items())
    ):
        sampler = TokamakSampler2D(R_wall, Z_wall, oversample=5)
        key, sample = sampler.sample(key, n=N_INT, n_bc=N_BC)
        x_int = np.asarray(sample["interior"][0])
        x_bc, nrm = (np.asarray(arr) for arr in sample["boundary"])

        ax.scatter(
            x_int[:, 0],
            x_int[:, 1],
            s=0.5,
            c="steelblue",
            alpha=0.4,
            rasterized=True,
            label="interior",
        )
        ax.scatter(
            x_bc[:, 0],
            x_bc[:, 1],
            s=3,
            c="crimson",
            zorder=3,
            label="boundary",
        )

        # wall contour
        R_c, Z_c = _closed_contour(R_wall, Z_wall)
        ax.plot(R_c, Z_c, "k-", lw=1.0)

        # normals (subsample): keep roughly 30 arrows per plot
        sk = max(1, N_BC // 30)
        ax.quiver(
            x_bc[::sk, 0],
            x_bc[::sk, 1],
            nrm[::sk, 0],
            nrm[::sk, 1],
            scale=20,
            width=0.004,
            color="orange",
            alpha=0.8,
        )

        ax.set_title(name, fontsize=9)
        ax.set_xlabel("R [m]")
        ax.set_ylabel("Z [m]")
        ax.set_aspect("equal")
        if idx == 0:
            # A single legend on the first subplot is enough.
            ax.legend(markerscale=5, fontsize=7)

    # Hide the unused cells of the last grid row.
    for ax in axes_flat[n:]:
        ax.set_visible(False)

    fig.tight_layout()
    return key


# ── 3D plots ───────────────────────────────────────────────────────────────────
def _to_xyz(
    R: np.ndarray, Z: np.ndarray, phi: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Convert cylindrical ``(R, Z, phi)`` to Cartesian ``(x, y, z)``."""
    return R * np.cos(phi), R * np.sin(phi), Z


def plot_3d_one(
    name: str,
    R_wall: np.ndarray,
    Z_wall: np.ndarray,
    key: jax.Array,
    fig: plt.Figure,
    axes_box: plt.Axes,
    axes_from2d: plt.Axes,
    axes_rz_box: plt.Axes,
    axes_rz_from2d: plt.Axes,
) -> jax.Array:
    """Plot 3D toroidal sampling for a single wall geometry.

    Args:
        name: Display name used in the subplot titles.
        R_wall: R coordinates of the wall contour.
        Z_wall: Z coordinates of the wall contour.
        key: PRNG key handed to the samplers.
        fig: Owning figure; not used here, kept for interface compatibility.
        axes_box: 3D axes for the ``TokamakSampler3DBox`` scatter.
        axes_from2d: 3D axes for the ``TokamakSampler3DFrom2D`` scatter.
        axes_rz_box: 2D axes for the (R, Z) projection of the 3DBox samples.
        axes_rz_from2d: 2D axes for the (R, Z) projection of the 3DFrom2D
            samples.

    Returns:
        The key returned by the last sampler call.
    """
    # 3D box sampler — sample columns are used as (R, Z, phi)
    s3b = TokamakSampler3DBox(R_wall, Z_wall, oversample=5)
    key, d3b = s3b.sample(key, n=N_INT)
    x3b = np.asarray(d3b["interior"][0])
    xb, yb, zb = _to_xyz(x3b[:, 0], x3b[:, 1], x3b[:, 2])

    # 3D from-2D sampler
    s3f = TokamakSampler3DFrom2D(R_wall, Z_wall, oversample=5)
    key, d3f = s3f.sample(key, n=N_INT)
    x3f = np.asarray(d3f["interior"][0])
    xf, yf, zf = _to_xyz(x3f[:, 0], x3f[:, 1], x3f[:, 2])

    # subsample for 3D scatter (every 8th point keeps the 3D view readable)
    sk = 8
    for ax, x, y, z, label in [
        (axes_box, xb[::sk], yb[::sk], zb[::sk], "3DBox"),
        (axes_from2d, xf[::sk], yf[::sk], zf[::sk], "3DFrom2D"),
    ]:
        ax.scatter(x, y, z, s=0.3, alpha=0.3, rasterized=True)
        ax.set_title(f"{name}\n{label}", fontsize=8)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("Z")

    # (R, Z) projections — the outline is the same for both, so build it once
    R_c, Z_c = _closed_contour(R_wall, Z_wall)
    for ax, pts, label in [
        (axes_rz_box, x3b, "3DBox"),
        (axes_rz_from2d, x3f, "3DFrom2D"),
    ]:
        ax.scatter(pts[:, 0], pts[:, 1], s=0.5, alpha=0.3, rasterized=True)
        ax.plot(R_c, Z_c, "k-", lw=1.0)
        ax.set_title(f"{name}\n{label} — (R,Z)", fontsize=8)
        ax.set_xlabel("R")
        ax.set_ylabel("Z")
        ax.set_aspect("equal")

    return key


def plot_3d(walls: dict, key: jax.Array) -> jax.Array:
    """Plot 3D toroidal sampling — one figure per wall.

    Each figure has two 3D scatters on the top row (3DBox, 3DFrom2D) and the
    matching (R, Z) projections on the bottom row.

    Args:
        walls: Mapping from display name to ``(R_wall, Z_wall)`` arrays.
        key: PRNG key handed to the samplers.

    Returns:
        The key returned by the last sampler call (*key* itself when *walls*
        is empty).
    """
    for name, (R_wall, Z_wall) in walls.items():
        fig = plt.figure(figsize=(16, 8))
        fig.suptitle(f"Sampling 3D — {name}", fontsize=12)

        ax_box = fig.add_subplot(2, 2, 1, projection="3d")
        ax_from2d = fig.add_subplot(2, 2, 2, projection="3d")
        ax_rz_box = fig.add_subplot(2, 2, 3)
        ax_rz_from2d = fig.add_subplot(2, 2, 4)

        key = plot_3d_one(
            name,
            R_wall,
            Z_wall,
            key,
            fig,
            ax_box,
            ax_from2d,
            ax_rz_box,
            ax_rz_from2d,
        )
        fig.tight_layout()

    return key


# ── Comparison: 3DBox vs 3DFrom2D radial distribution ─────────────────────────
def plot_radial_comparison(walls: dict, key: jax.Array) -> jax.Array:
    """Compare radial (R) distribution between 3DBox and 3DFrom2D.

    For every wall, the R coordinates produced by both 3D samplers are
    histogrammed on 49 common bins spanning the wall's R range, together with
    a reference line proportional to R.

    Args:
        walls: Mapping from display name to ``(R_wall, Z_wall)`` arrays.
        key: PRNG key handed to the samplers.

    Returns:
        The key returned by the last sampler call, or *key* unchanged when
        *walls* is empty (no figure is created in that case).
    """
    if not walls:
        # Avoid asking matplotlib for a grid with zero columns.
        return key

    n = len(walls)
    # squeeze=False keeps a 2D axes array even for a single wall.
    fig, axes_grid = plt.subplots(1, n, figsize=(5 * n, 4), squeeze=False)
    axes = axes_grid[0]
    fig.suptitle("Distribution radiale R : 3DBox vs 3DFrom2D", fontsize=12)

    for idx, (ax, (name, (R_wall, Z_wall))) in enumerate(zip(axes, walls.items())):
        s3b = TokamakSampler3DBox(R_wall, Z_wall, oversample=5)
        key, d3b = s3b.sample(key, n=N_INT)

        s3f = TokamakSampler3DFrom2D(R_wall, Z_wall, oversample=5)
        key, d3f = s3f.sample(key, n=N_INT)

        R_box = np.asarray(d3b["interior"][0][:, 0])
        R_from2d = np.asarray(d3f["interior"][0][:, 0])

        bins = np.linspace(R_wall.min(), R_wall.max(), 50)
        ax.hist(R_box, bins=bins, alpha=0.6, label="3DBox (∝ dR dZ dφ)")
        ax.hist(R_from2d, bins=bins, alpha=0.6, label="3DFrom2D (∝ R dR dZ dφ)")

        # theoretical toroidal density ∝ R, scaled so the bin values sum to
        # N_INT.
        # NOTE(review): this is the bare ∝ R weight.  If the samplers follow
        # the densities named in the legend, the R-marginal also carries the
        # vertical extent of the cross-section at each R, so the histograms
        # are not expected to lie on this line — confirm the intent.
        R_mid = 0.5 * (bins[:-1] + bins[1:])
        density = R_mid / R_mid.sum()
        ax.plot(R_mid, density * N_INT, "k--", lw=1.5, label="théorique ∝ R")

        ax.set_title(name, fontsize=9)
        ax.set_xlabel("R [m]")
        ax.set_ylabel("nombre de points")
        if idx == 0:
            ax.legend(fontsize=8)

    fig.tight_layout()
    return key


# ── Main ───────────────────────────────────────────────────────────────────────
def main() -> None:
    """Draw every figure in turn, threading the PRNG key, then show them."""
    key = KEY

    print("=== Plot 2D ===")
    key = plot_2d(walls, key)

    print("=== Plot 3D ===")
    key = plot_3d(walls, key)

    print("=== Comparaison distribution radiale ===")
    key = plot_radial_comparison(walls, key)

    plt.show()


if __name__ == "__main__":
    main()