r"""Solve the 1D full Euler equations with a discrete PINN on Sod's test case.

The conserved variables are :math:`W = (\rho, m, E)`, and the pressure is

.. math::

    p = (\gamma - 1)\left(E - \frac{1}{2}\frac{m^2}{\rho}\right).

The PDE is

.. math::

    \partial_t W + \partial_x F(W) = 0,

where

.. math::

    F(W) = \left[m,\; \frac{m^2}{\rho} + p,\; \frac{m}{\rho}(E + p)\right].

Sod's Riemann problem is used, with the initial discontinuity at
:math:`x = 0.5` smoothed by a ``tanh`` profile.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.linear_approximation.time_integrators.butcher_tableau import (
    build_implicit_euler_tableau,
)
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.discrete_pinns import (
    DiscretePINN,
)
from scimba_jax.physical_models.temporal_pde.euler import SemiDiscreteEuler
from scimba_jax.plots.plots_nd import plot_abstract_approx_spaces

# Problem parameters (Sod tube)
GAMMA = 1.4
X_INTERFACE = 0.5
SMOOTHNESS = 80.0  # steepness of the tanh used to smooth the initial jump

# Alternative test case with diverging velocities (double rarefaction):
# RHO_L, U_L, P_L = 1.0, -1.0, 1.0
# RHO_R, U_R, P_R = 1.0, 1.0, 1.0
RHO_L, U_L, P_L = 1.0, 0.0, 1.0
RHO_R, U_R, P_R = 0.125, 0.0, 0.1

# Left/right states in conservative variables: momentum q = rho * u and total
# energy E = p / (gamma - 1) + rho * u**2 / 2. Used as the boundary values
# imposed strongly in `post_processing`.
Q_L, E_L = RHO_L * U_L, P_L / (GAMMA - 1) + 0.5 * U_L**2 * RHO_L
Q_R, E_R = RHO_R * U_R, P_R / (GAMMA - 1) + 0.5 * U_R**2 * RHO_R

# Numerical parameters
N_COLLOC = 1200
N_EPOCHS_INIT = 200
N_EPOCHS = 10
NT = 100

DOM_X = Segment1D((0.0, 1.0), is_main_domain=True)
DOM_T = (0.0, 0.2)
SAMPLER = TensorizedSampler([DomainSampler(DOM_X)], model_type="x")


def smooth_riemann(left: float, right: float, x: jnp.ndarray) -> jnp.ndarray:
    """Smooth approximation of a left/right Riemann state.

    The result tends to ``left`` for ``x`` well below ``X_INTERFACE`` and to
    ``right`` well above it. The two states are blended by
    ``tanh(SMOOTHNESS * (x - X_INTERFACE))``.
    """
    tanh = jnp.tanh(SMOOTHNESS * (x - X_INTERFACE))
    return 0.5 * (left + right) - 0.5 * (left - right) * tanh


def f_init(x: jnp.ndarray) -> jnp.ndarray:
    """Sod initial condition in conservative variables.

    Builds the smoothed primitive variables (rho, u, p) and converts them to
    the conservative variables (rho, q, E). These are concatenated along the
    last axis.

    NOTE(review): the concatenation assumes ``x`` carries a trailing axis of
    size 1, so that the output has 3 components on that axis; confirm against
    the sampler.
    """
    rho = smooth_riemann(RHO_L, RHO_R, x)
    u = smooth_riemann(U_L, U_R, x)
    p = smooth_riemann(P_L, P_R, x)
    q = rho * u
    # Kinetic energy written with q: rho * u**2 / 2 == q**2 / (2 * rho).
    e = p / (GAMMA - 1) + 0.5 * q**2 / rho
    return jnp.concatenate([rho, q, e], axis=-1)


def post_processing(approx: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Impose left/right boundary states strongly through the network output.

    This uses a linear interpolation between boundary states plus a trainable
    correction that vanishes at x=0 and x=1.
    """
    # phi(0) = phi(1) = 0, so the network correction cannot alter the
    # boundary values.
    phi = x * (1.0 - x)
    rho_base = (1.0 - x) * RHO_L + x * RHO_R
    q_base = (1.0 - x) * Q_L + x * Q_R
    e_base = (1.0 - x) * E_L + x * E_R
    rho = rho_base + phi * approx[..., 0:1]
    q = q_base + phi * approx[..., 1:2]
    e = e_base + phi * approx[..., 2:3]
    return jnp.concatenate([rho, q, e], axis=-1)


pde = SemiDiscreteEuler(main_domain=DOM_X, time_domain=DOM_T, gamma=GAMMA, bc="strong")

params = {
    "Implicit Euler": {
        "tableau": build_implicit_euler_tableau(),
        "explicit_pde": None,
        "implicit_pde": pde,
    },
}

in_size = 1
out_size = 3  # (rho, q, e)

# The approximation space is built and fitted to the initial condition only
# once; every time-integration method then starts from that same space.
space = None
spaces = []

for method, param in params.items():
    butcher_tableau = param["tableau"]
    explicit_pde = param["explicit_pde"]
    implicit_pde = param["implicit_pde"]

    key = jax.random.PRNGKey(0)

    discrete_pinn = DiscretePINN(
        DOM_X,
        DOM_T,
        SAMPLER,
        out_size,
        NT,
        butcher_tableau,
        explicit_pde,
        implicit_pde,
    )

    if space is None:
        # Split the key so network initialisation and the initial-condition
        # fit draw from independent random streams (JAX keys must not be
        # reused).
        key, net_key = jax.random.split(key)
        nn = MLP(
            in_size=in_size,
            out_size=out_size,
            hidden_sizes=[12] * 3,
            key=net_key,
        )
        space = ApproximationSpace(
            {"x": 1},
            [(nn, "vec", out_size)],
            model_type="x",
            post_processing=post_processing,
        )

        print("Initializing the discrete PINN...")
        key, space = discrete_pinn.initialize(
            key, space, f_init, N_EPOCHS_INIT, N_COLLOC
        )
        print("Initializing the discrete PINN... Done\n")

        plot_abstract_approx_spaces(
            [space],
            DOM_X,
            components=[{"density": 0}, {"momentum": 1}, {"energy": 2}],
            solution=f_init,
            error=f_init,
            title="Sod initial condition",
            titles=["rho, q, e at t=0"],
        )

    print(f"\nSolving with the {method} method...")
    key, space_ = discrete_pinn.solve(key, space, N_EPOCHS, N_COLLOC)
    spaces.append(space_)
    print(f"\nSolving with the {method} method... Done\n")

# %%


def construct_velocity(*args):
    """Velocity u = q / rho, computed from the first approximation space.

    A small 1e-8 is added to the density to avoid division by zero. The first
    positional argument is used through ``.component(i)``, with component 0
    taken as rho and component 1 as q.
    """
    w = args[0]
    rho, q = w.component(0), w.component(1)
    return q / (rho + 1e-8)


def construct_pressure(*args):
    """Pressure p = (gamma - 1) * (E - rho * u**2 / 2), from the first space.

    Components 0, 1 and 2 of the first positional argument are taken as rho,
    q and E. The velocity uses the same 1e-8 density regularisation as
    ``construct_velocity``.
    """
    w = args[0]
    rho, q, e = w.component(0), w.component(1), w.component(2)
    u = q / (rho + 1e-8)
    return (GAMMA - 1) * (e - 0.5 * rho * u**2)


additional_scalar_functions = {
    "$u$": construct_velocity,
    "$p$": construct_pressure,
}

# One figure per method, each titled with that method's own name.
# `spaces` was filled in the same order as `params` was iterated.
for method_name, solved_space in zip(params, spaces):
    plot_abstract_approx_spaces(
        [solved_space],
        DOM_X,
        components=[{"density": 0}, {"momentum": 1}, {"energy": 2}],
        title=f"Euler (Sod) solution at t = {DOM_T[-1]}",
        titles=[method_name],
        additional_scalar_functions=additional_scalar_functions,
    )

plt.show()

# %%