r"""Solve the advection equation in 1D using a discrete PINN.

.. math::

    \partial_t u + c \, \partial_x u & = f \quad \text{in } \Omega \times (0, T), \\
    u & = g \quad \text{on } \partial \Omega \times (0, T), \\
    u & = u_0 \quad \text{on } \Omega \times \{0\},

where :math:`u: \Omega \times (0, T) \to \mathbb{R}` is the unknown function,
:math:`\Omega \subset \mathbb{R}` is the spatial domain and
:math:`(0, T) \subset \mathbb{R}` is the time domain.

In this example :math:`\Omega = (0, 1)`, :math:`T = 0.5` and
:math:`c = 0.1 + 0.9 = 1`. The reference solution ``exact_sol`` is the initial
Gaussian pulse translated at speed :math:`c`, i.e. the exact solution for
:math:`f = 0`. Weak homogeneous Dirichlet boundary conditions
(:math:`g = 0`) and natural gradient preconditioning are used.

We solve the equation using:

- an explicit method (RK2),
- an implicit method (Pareschi-Russo),
- an IMEX method (ARS222), where :math:`c` is split into a slow component
  (``SLOW_VELOCITY``, treated explicitly) and a fast component
  (``FAST_VELOCITY``, treated implicitly).
"""

import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.linear_approximation.time_integrators.butcher_tableau import (
    build_ars222_tableau,
    build_pareschi_russo_tableau,
    build_rk2_tableau,
)
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.discrete_pinns import (
    DiscretePINN,
)
from scimba_jax.physical_models.temporal_pde.advection import SemiDiscreteAdvectionND
from scimba_jax.plots.plots_nd import plot_abstract_approx_spaces

# Sizes passed to `initialize` / `solve` (collocation counts and epoch counts).
N_COLLOC = 1000
N_BC_COLLOC = 200
N_EPOCHS_INIT = 100
N_EPOCHS = 10

DOM_X = Segment1D((0.0, 1.0), is_main_domain=True)
DOM_T = (0.0, 0.5)
SAMPLER = TensorizedSampler([DomainSampler(DOM_X)], model_type="x")

# The IMEX scheme treats the slow part explicitly and the fast part implicitly;
# the fully explicit / fully implicit schemes use the total velocity.
SLOW_VELOCITY = 0.1
FAST_VELOCITY = 0.9
VELOCITY = SLOW_VELOCITY + FAST_VELOCITY


def exact_sol(t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Return the exact solution: a Gaussian centred at 0.25 moving at speed VELOCITY."""
    return jnp.exp(-((x - (0.25 + VELOCITY * t)) ** 2) * 150)


def f_init(x: jnp.ndarray) -> jnp.ndarray:
    """Return the initial condition u_0(x) = exact_sol(0, x)."""
    t = jnp.zeros_like(x)
    return exact_sol(t, x)


# create the discrete PINN parameters: one Butcher tableau per method, plus the
# semi-discrete operator(s) integrated explicitly and/or implicitly.
params = {
    "rk2": {
        "tableau": build_rk2_tableau(),
        "explicit_pde": SemiDiscreteAdvectionND(DOM_X, DOM_T, VELOCITY),
        "implicit_pde": None,
    },
    "pareschi_russo": {
        "tableau": build_pareschi_russo_tableau(),
        "explicit_pde": None,
        "implicit_pde": SemiDiscreteAdvectionND(DOM_X, DOM_T, VELOCITY),
    },
    "ars222": {
        "tableau": build_ars222_tableau(),
        "explicit_pde": SemiDiscreteAdvectionND(DOM_X, DOM_T, SLOW_VELOCITY),
        "implicit_pde": SemiDiscreteAdvectionND(DOM_X, DOM_T, FAST_VELOCITY),
    },
}

nt = 50  # number of time steps, passed to DiscretePINN
in_size = 1
out_size = 1

# `space` holds the network fitted to the initial condition; it is trained once
# (with the first method's DiscretePINN) and reused as the starting point of
# every method's time integration.
space = None
spaces = []

for method, param in params.items():
    butcher_tableau = param["tableau"]
    explicit_pde = param["explicit_pde"]
    implicit_pde = param["implicit_pde"]

    # NOTE(review): the key is re-seeded for every method, but on the first
    # iteration it is then replaced by the key returned by `initialize`, so the
    # first solve may not start from the same key as the others.
    key = jax.random.PRNGKey(0)

    discrete_pinn = DiscretePINN(
        DOM_X, DOM_T, SAMPLER, out_size, nt, butcher_tableau, explicit_pde, implicit_pde
    )

    if space is None:
        # train the initial condition only for the first tableau
        nn = MLP(in_size=in_size, out_size=out_size, hidden_sizes=[12] * 2, key=key)
        space = ApproximationSpace({"x": 1}, [(nn, "scalar", None)], model_type="x")

        print("Initializing the discrete PINN...")
        key, space = discrete_pinn.initialize(
            key, space, f_init, N_EPOCHS_INIT, N_COLLOC, n_bc_colloc=N_BC_COLLOC
        )
        print("Initializing the discrete PINN... Done\n")

        plot_abstract_approx_spaces(
            [space], DOM_X, solution=f_init, error=f_init, title="initial condition"
        )

    print(f"\nSolving with the {method} method...")
    start = time.perf_counter()
    key, space_ = discrete_pinn.solve(
        key, space, N_EPOCHS, N_COLLOC, n_bc_colloc=N_BC_COLLOC
    )
    end = time.perf_counter()
    print(f"\nSolving with the {method} method... Done in {end - start:.2f} seconds\n")

    spaces.append(space_)

# %%


def exact_sol_final(x: jnp.ndarray) -> jnp.ndarray:
    """Return the exact solution at the final time DOM_T[-1]."""
    return exact_sol(jnp.ones_like(x) * DOM_T[-1], x)


plot_abstract_approx_spaces(
    spaces,
    DOM_X,
    solution=exact_sol_final,
    error=exact_sol_final,
    title=f"solution at final time t={DOM_T[-1]}",
    titles=[f"solution with {name} method" for name in params],
)

plt.show()

# %%