r"""Solve the 1D isentropic Euler equations with a discrete PINN.

In the conserved variables :math:`W = (\rho, q)`, where :math:`q = \rho u`,
the system reads

.. math::

    \partial_t \rho + \partial_x q & = 0 \\
    \partial_t q + \partial_x \!\left(\frac{q^2}{\rho} + \rho^\gamma\right) & = 0

with :math:`\gamma = 2` and Dirichlet boundary conditions on :math:`[0, 1]`.
The initial data is a density jump at :math:`x = 0.5` (from :math:`\rho = 2`
on the left to :math:`\rho = 1` on the right) with zero momentum, smoothed by
a ``tanh`` profile by default.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.linear_approximation.time_integrators.butcher_tableau import (
    build_implicit_euler_tableau,
)
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.discrete_pinns import (
    DiscretePINN,
)
from scimba_jax.physical_models.temporal_pde.isentropic_euler import (
    SemiDiscreteIsentropicEuler,
)
from scimba_jax.plots.plots_nd import plot_abstract_approx_spaces

# ── problem parameters ───────────────────────────────────────────────────────

GAMMA = 2.0  # adiabatic exponent
# Density on each side of the initial jump. These are also the Dirichlet
# values imposed by `post_processing` at x = 0 and x = 1, respectively.
RHO_LEFT = 2.0
RHO_RIGHT = 1.0

# ── numerical parameters ─────────────────────────────────────────────────────

N_COLLOC = 1200  # collocation points per training step
N_EPOCHS_INIT = 150  # epochs used to fit the initial condition
N_EPOCHS = 5  # epochs per time step
NT = 200  # number of time steps

DOM_X = Segment1D((0.0, 1.0), is_main_domain=True)
DOM_T = (0.0, 0.1)

SAMPLER = TensorizedSampler([DomainSampler(DOM_X)], model_type="x")

# ── initial condition and pre-processing ─────────────────────────────────────


def f_init(x: jnp.ndarray, regularization: bool = True) -> jnp.ndarray:
    """Initial condition ``W(x, 0) = (rho_0(x), 0)``.

    Args:
        x: Spatial points. Assumed to carry a trailing axis of size 1
            (TODO confirm against the sampler), since ``rho`` and ``q`` are
            concatenated along the last axis.
        regularization: If True, replace the density jump at ``x = 0.5`` by a
            smooth ``tanh`` profile (sharpness 50). Otherwise use the exact
            step.

    Returns:
        ``rho`` and ``q`` concatenated along the last axis; ``q`` is zero.
    """
    # Midpoint and half-amplitude of the jump (1.5 and 0.5 for 2 -> 1).
    mean = 0.5 * (RHO_LEFT + RHO_RIGHT)
    half_jump = 0.5 * (RHO_LEFT - RHO_RIGHT)
    if regularization:
        rho = mean - half_jump * jnp.tanh(50 * (x - 0.5))
    else:
        rho = jnp.where(x < 0.5, RHO_LEFT, RHO_RIGHT)
    q = jnp.zeros_like(x)
    return jnp.concatenate([rho, q], axis=-1)


def post_processing(approx: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Impose the Dirichlet boundary conditions strongly on the network output.

    The density is ``rho_net * x * (1 - x)`` plus the linear lift from
    ``RHO_LEFT`` (at x = 0) to ``RHO_RIGHT`` (at x = 1). The momentum is
    ``q_net * x * (1 - x)``, so it vanishes at both ends.

    Args:
        approx: Raw network output; channel 0 is ``rho``, channel 1 is ``q``.
        x: Spatial points matching ``approx`` along the leading axes.

    Returns:
        The corrected ``(rho, q)`` concatenated along the last axis.
    """
    rho_net = approx[..., 0:1]
    q_net = approx[..., 1:2]
    # x * (1 - x) vanishes at both boundaries, so only the lift survives there.
    rho = rho_net * x * (1.0 - x) + (RHO_LEFT + (RHO_RIGHT - RHO_LEFT) * x)
    q = q_net * x * (1.0 - x)
    return jnp.concatenate([rho, q], axis=-1)


# ── PDE and Runge-Kutta schemes ──────────────────────────────────────────────

pde = SemiDiscreteIsentropicEuler(
    main_domain=DOM_X, time_domain=DOM_T, gamma=GAMMA, bc="strong"
)

# Method name -> Butcher tableau and the PDE terms passed to the explicit /
# implicit parts of the scheme.
params = {
    "Implicit Euler": {
        "tableau": build_implicit_euler_tableau(),
        "explicit_pde": None,
        "implicit_pde": pde,
    },
}

# ── main loop ────────────────────────────────────────────────────────────────

in_size = 1  # spatial input x
out_size = 2  # (ρ, q)

space = None
spaces = []
errors_over_time = []

for method, param in params.items():
    butcher_tableau = param["tableau"]
    explicit_pde = param["explicit_pde"]
    implicit_pde = param["implicit_pde"]

    # Every method starts from the same seed.
    key = jax.random.PRNGKey(0)

    discrete_pinn = DiscretePINN(
        DOM_X,
        DOM_T,
        SAMPLER,
        out_size,
        NT,
        butcher_tableau,
        explicit_pde,
        implicit_pde,
    )

    # The network is built and fitted to the initial condition only once.
    # Later methods reuse that fitted space (not a solved one), because
    # `solve` returns a new space that is stored separately.
    if space is None:
        nn = MLP(
            in_size=in_size,
            out_size=out_size,
            hidden_sizes=[12, 12],
            key=key,
        )
        space = ApproximationSpace(
            {"x": 1}, [(nn, "vec", 2)], model_type="x", post_processing=post_processing
        )

        print("Initializing the discrete PINN...")
        key, space = discrete_pinn.initialize(
            key, space, f_init, N_EPOCHS_INIT, N_COLLOC
        )
        print("Initializing the discrete PINN... Done\n")

        plot_abstract_approx_spaces(
            [space], DOM_X, solution=f_init, error=f_init, title="initial condition"
        )

    print(f"\nSolving with the {method} method...")
    key, space_ = discrete_pinn.solve(key, space, N_EPOCHS, N_COLLOC)
    spaces.append(space_)
    errors_over_time.append(discrete_pinn.errors_over_time)
    print(f"\nSolving with the {method} method... Done\n")

# ── solution plots ───────────────────────────────────────────────────────────
# NOTE: `errors_over_time` is collected above but not plotted here.

plot_abstract_approx_spaces(
    spaces,
    DOM_X,
    title=f"isentropic Euler solution at t = {DOM_T[-1]}",
    titles=list(params),
)

plt.show()