r"""Solve the 1D linearized Euler equations with a discrete PINN.

.. math::

    \partial_t p + \partial_x u & = 0 \text{ in } \Omega \times (0, T) \\
    \partial_t u + \partial_x p & = 0 \text{ in } \Omega \times (0, T)

where :math:`(p, u): \Omega \times (0, T) \to \mathbb{R}^2` is the unknown
field, :math:`\Omega \subset \mathbb{R}` is the spatial domain and
:math:`(0, T) \subset \mathbb{R}` is the time domain.

Periodic boundary conditions are prescribed through a periodic embedding in
the neural network architecture.

The equation is discretized in time with Runge-Kutta one-step methods and
solved stage-by-stage with PINNs. Two tableaux are compared: an explicit
RK4 (3/8 rule) scheme and an implicit DIRK (3,4) scheme.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.linear_approximation.time_integrators.butcher_tableau import (
    build_dirk_3_4_tableau,
    build_rk4_38_tableau,
)
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.discrete_pinns import (
    DiscretePINN,
)
from scimba_jax.physical_models.temporal_pde.linearized_euler import (
    SemiDiscreteLinearizedEuler,
)
from scimba_jax.plots.plots_nd import plot_abstract_approx_spaces

# Training budget: collocation points per step, epochs for fitting the
# initial condition, and epochs passed to each `solve` call.
N_COLLOC = 1200
N_EPOCHS_INIT = 120
N_EPOCHS = 20

# Space segment, time interval and the sampler built on the space segment.
DOM_X = Segment1D((0.0, 2.0), is_main_domain=True)
DOM_T = (0.0, 1.75)
SAMPLER = TensorizedSampler([DomainSampler(DOM_X)], model_type="x")


def exact_sol(t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the reference pressure and velocity at time ``t`` and position ``x``.

    The combination ``p + u`` is a pulse evaluated at ``x - t`` and ``p - u``
    is the same pulse evaluated at ``x + t``; both arguments are wrapped with
    period 2 and the pulse is centred at 1.

    Args:
        t: Time values, combined elementwise with ``x``.
        x: Space values.

    Returns:
        Pressure and velocity concatenated along the last axis
        (presumably ``x`` has a trailing axis of size 1, giving two
        components per point -- confirm against the callers).
    """
    sigma = 0.02
    four_sigma = 4.0 * sigma
    amplitude = 1.0 / jnp.sqrt(4.0 * jnp.pi * sigma)

    def pulse(phase: jnp.ndarray) -> jnp.ndarray:
        # Squared distance to the pulse centre after wrapping into [0, 2).
        dist_sq = (phase % 2.0 - 1.0) ** 2
        return amplitude * jnp.exp(-dist_sq / four_sigma)

    sum_wave = pulse(x - t)  # p + u
    diff_wave = pulse(x + t)  # p - u

    pressure = 0.5 * (sum_wave + diff_wave)
    velocity = 0.5 * (sum_wave - diff_wave)
    return jnp.concatenate((pressure, velocity), axis=-1)


def f_init(x: jnp.ndarray) -> jnp.ndarray:
    """Return the reference solution at time zero for the positions ``x``."""
    return exact_sol(jnp.zeros_like(x), x)


pde = SemiDiscreteLinearizedEuler(main_domain=DOM_X, time_domain=DOM_T, bc="strong")

# One entry per time integrator; the same PDE object is handed over either as
# the explicit or as the implicit operator.
params = {
    "RK4 (3/8 rule)": {
        "tableau": build_rk4_38_tableau(),
        "explicit_pde": pde,
        "implicit_pde": None,
    },
    "DIRK (3,4)": {
        "tableau": build_dirk_3_4_tableau(),
        "explicit_pde": None,
        "implicit_pde": pde,
    },
}

nt = 20
in_size = 1
out_size = 2

space = None
spaces = []
errors_over_time = []

for method, param in params.items():
    # The PRNG key is reset for every method.
    key = jax.random.PRNGKey(0)

    discrete_pinn = DiscretePINN(
        DOM_X,
        DOM_T,
        SAMPLER,
        out_size,
        nt,
        param["tableau"],
        param["explicit_pde"],
        param["implicit_pde"],
        exact_solution=exact_sol,
    )

    if space is None:
        # The initial condition is fitted once, on the first method only; the
        # resulting `space` is then passed to `solve` for every method.
        network = MLP(
            in_size=in_size,
            out_size=out_size,
            hidden_sizes=[20, 20],
            key=key,
            embedding="periodic",
            periods=[2.0],
        )
        space = ApproximationSpace({"x": 1}, [(network, "vec", 2)], model_type="x")

        print("Initializing the discrete PINN...")
        key, space = discrete_pinn.initialize(
            key, space, f_init, N_EPOCHS_INIT, N_COLLOC
        )
        print("Initializing the discrete PINN... Done\n")

        plot_abstract_approx_spaces(
            [space],
            DOM_X,
            solution=f_init,
            error=f_init,
            title="initial condition",
        )

    print(f"\nSolving with the {method} method...")
    key, solved_space = discrete_pinn.solve(key, space, N_EPOCHS, N_COLLOC)
    spaces.append(solved_space)
    errors_over_time.append(discrete_pinn.errors_over_time)
    print(f"\nSolving with the {method} method... Done\n")

# Error histories: one panel per error column, one curve per method.
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
error_type = ["L2", "Linf"]
# Assumes each error history has nt + 1 rows -- TODO confirm.
time_grid = jnp.linspace(DOM_T[0], DOM_T[1], nt + 1)

for column, norm_name in enumerate(error_type):
    panel = ax[column]
    for method_name, history in zip(params, errors_over_time):
        panel.semilogy(time_grid, history[:, column], label=method_name)
    panel.set_xlabel("time")
    panel.set_ylabel(f"relative {norm_name} error")
    panel.set_title(f"Relative {norm_name} error vs time")
    panel.legend()
    panel.grid()


def _final_time_sol(x: jnp.ndarray) -> jnp.ndarray:
    """Return the reference solution at the final time ``DOM_T[-1]``."""
    return exact_sol(jnp.ones_like(x) * DOM_T[-1], x)


# Final-time comparison against the reference solution, one figure per method.
for method_name, final_space in zip(params, spaces):
    plot_abstract_approx_spaces(
        [final_space],
        DOM_X,
        components=[{"pressure": 0}, {"velocity": 1}],
        solution=_final_time_sol,
        error=_final_time_sol,
        title=f"solution at final time t={DOM_T[-1]} with {method_name}",
    )

plt.show()