r"""Solves the 2D compressible Euler equations with a discrete PINN.

Test case: Shu's isentropic vortex (steady exact solution).

The conserved variables are W = (rho, m_x, m_y, E), and the PDE is

.. math::

    \partial_t W + \nabla \cdot F(W) = 0,

with flux

.. math::

    F^j(W) = \left[ m_j,\; m_x m_j / \rho + p\,\delta_{xj},\;
    m_y m_j / \rho + p\,\delta_{yj},\; m_j / \rho\,(E + p) \right]

and :math:`p = (\gamma-1)(E - |m|^2 / (2\rho))`, :math:`\gamma = 1.4`.

**Isentropic vortex (Shu 1998).** With background velocity
:math:`(u_\infty, v_\infty) = (0, 0)`, the perturbation fields centered at
:math:`(x_0, y_0) = (5, 5)` on :math:`[0,10]^2` are

.. math::

    \delta T &= -\frac{(\gamma-1)\varepsilon^2}{8\gamma\pi^2} e^{1-r^2}, \\
    (\delta u, \delta v) &= \frac{\varepsilon}{2\pi} e^{(1-r^2)/2}
    \bigl(-\eta,\; \xi\bigr),

with :math:`\xi = x - x_0`, :math:`\eta = y - y_0`, :math:`r^2 = \xi^2+\eta^2`.
Setting :math:`T = 1 + \delta T`, :math:`\rho = T^{1/(\gamma-1)}` and using
:math:`p = \rho^\gamma` gives an exact **steady** solution of the full 2D
compressible Euler equations (centripetal pressure balance, divergence-free
velocity).

The exact solution is time-independent and approximately periodic on the box:
the perturbation decays like a Gaussian away from the vortex center, so the
mismatch between opposite boundaries is small (of order :math:`10^{-5}` in the
velocity for :math:`\varepsilon = 5` on :math:`[0,10]^2`) but not exactly zero.
The error at any time :math:`t > 0` therefore measures how well the discrete
PINN preserves the steady state.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.linear_approximation.time_integrators.butcher_tableau import (
    build_explicit_euler_tableau,
    build_implicit_euler_tableau,
)
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.discrete_pinns import (
    DiscretePINN,
)
from scimba_jax.physical_models.temporal_pde.euler import SemiDiscreteEuler
from scimba_jax.plots.plots_nd import plot_abstract_approx_spaces

# ── problem parameters ────────────────────────────────────────────────────────
GAMMA = 1.4
EPSILON = 5.0  # vortex strength
X0, Y0 = 5.0, 5.0  # vortex center
L = 10.0  # domain side length

# ── numerical parameters ─────────────────────────────────────────────────────
N_COLLOC = 2000
N_EPOCHS_INIT = 200
N_EPOCHS = 15
NT = 20

DOM_X = Square2D([(0.0, L), (0.0, L)], is_main_domain=True)
DOM_T = (0.0, 1.0)
SAMPLER = TensorizedSampler([DomainSampler(DOM_X)], model_type="x")


# ── exact steady solution ─────────────────────────────────────────────────────
def exact_sol(t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Shu's stationary isentropic vortex — exact for all t (steady state).

    Args:
        t: time, shape (..., 1). Not used (solution is steady).
        x: spatial coordinates, shape (..., 2).

    Returns:
        W = (rho, m_x, m_y, E) of shape (..., 4).
    """
    xi = x[..., 0:1] - X0  # x - x_0
    eta = x[..., 1:2] - Y0  # y - y_0
    r2 = xi**2 + eta**2

    # Temperature perturbation (isentropic). For EPSILON = 5 its minimum is
    # about -0.25, so T = 1 + dT stays positive and T**(1/(gamma-1)) is real.
    dT = -(GAMMA - 1.0) * EPSILON**2 / (8.0 * GAMMA * jnp.pi**2) * jnp.exp(1.0 - r2)

    # Azimuthal velocity perturbation (divergence-free, u_inf = v_inf = 0)
    factor = EPSILON / (2.0 * jnp.pi) * jnp.exp(0.5 * (1.0 - r2))
    u = -factor * eta
    v = +factor * xi

    T = 1.0 + dT
    rho = T ** (1.0 / (GAMMA - 1.0))  # from p = rho^gamma and T = rho^(gamma-1)
    p = rho**GAMMA

    m_x = rho * u
    m_y = rho * v
    E = p / (GAMMA - 1.0) + 0.5 * rho * (u**2 + v**2)

    return jnp.concatenate([rho, m_x, m_y, E], axis=-1)


def f_init(x: jnp.ndarray) -> jnp.ndarray:
    """Initial condition W(t=0, x)."""
    return exact_sol(jnp.zeros_like(x[..., :1]), x)


# ── PDE ───────────────────────────────────────────────────────────────────────
pde = SemiDiscreteEuler(main_domain=DOM_X, time_domain=DOM_T, gamma=GAMMA, bc="strong")

params = {
    "explicit_euler": {
        "tableau": build_explicit_euler_tableau(),
        "explicit_pde": pde,
        "implicit_pde": None,
    },
    "implicit_euler": {
        "tableau": build_implicit_euler_tableau(),
        "explicit_pde": None,
        "implicit_pde": pde,
    },
}

# ── main loop ─────────────────────────────────────────────────────────────────
in_size = 2  # (x, y)
out_size = 4  # (rho, m_x, m_y, E)

# One root key, split once into independent streams: JAX keys are not
# stateful, so handing the same key to two consumers yields correlated draws.
#  - net_key:   MLP weight initialization
#  - fit_key:   fitting the network to the initial condition
#  - solve_key: time stepping; shared by every method for a like-for-like
#               comparison of the time integrators.
root_key = jax.random.PRNGKey(0)
net_key, fit_key, solve_key = jax.random.split(root_key, 3)

space = None
spaces = []
errors_over_time = []

for method, param in params.items():
    butcher_tableau = param["tableau"]
    explicit_pde = param["explicit_pde"]
    implicit_pde = param["implicit_pde"]

    discrete_pinn = DiscretePINN(
        DOM_X,
        DOM_T,
        SAMPLER,
        out_size,
        NT,
        butcher_tableau,
        explicit_pde,
        implicit_pde,
        exact_solution=exact_sol,
    )

    # The network is built and fitted to the initial condition only on the
    # first pass; every method then time-steps from that same fitted space.
    if space is None:
        nn = MLP(
            in_size=in_size,
            out_size=out_size,
            hidden_sizes=[12] * 3,
            key=net_key,
            embedding="periodic",
            periods=[L, L],
        )
        space = ApproximationSpace(
            {"x": in_size}, [(nn, "vec", out_size)], model_type="x"
        )

        print("Initializing the discrete PINN...")
        _, space = discrete_pinn.initialize(
            fit_key, space, f_init, N_EPOCHS_INIT, N_COLLOC
        )
        print("Initializing the discrete PINN... Done\n")

        plot_abstract_approx_spaces(
            [space],
            DOM_X,
            solution=f_init,
            error=f_init,
            title="Isentropic vortex — initial condition",
        )
        plt.show()

    print(f"\nSolving with the {method} method...")
    _, space_ = discrete_pinn.solve(solve_key, space, N_EPOCHS, N_COLLOC)
    spaces.append(space_)
    errors_over_time.append(discrete_pinn.errors_over_time)
    print(f"\nSolving with the {method} method... Done\n")

# ── error plots ───────────────────────────────────────────────────────────────
# Column i of each method's error history is plotted against the NT + 1
# time levels (assumes errors_over_time has shape (NT + 1, 2) — L2, Linf).
time_grid = jnp.linspace(DOM_T[0], DOM_T[1], NT + 1)
error_types = ["L2", "Linf"]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for i, (axis, error_type) in enumerate(zip(axes, error_types)):
    for method_name, errors in zip(params, errors_over_time):
        axis.semilogy(time_grid, errors[:, i], label=method_name)
    axis.set_xlabel("time")
    axis.set_ylabel(f"relative {error_type} error")
    axis.set_title(f"Relative {error_type} error vs time")
    axis.legend()
    axis.grid()

t_final = DOM_T[-1]


def f_final(x: jnp.ndarray) -> jnp.ndarray:
    """Exact solution W(t_final, x); equal to the initial condition (steady)."""
    return exact_sol(jnp.full_like(x[..., :1], t_final), x)


plot_abstract_approx_spaces(
    spaces,
    DOM_X,
    solution=f_final,
    error=f_final,
    title=f"Euler isentropic vortex at t = {t_final}",
    titles=list(params),
)

plt.show()