r"""Solve the heat equation in nD using a discrete PINN.

.. math::

    \partial_t u - \Delta u & = f \quad \text{in } \Omega \times (0, T), \\
    u & = 0 \quad \text{on } \partial \Omega \times (0, T), \\
    u & = u_0 \quad \text{on } \Omega \times \{0\},

where :math:`u: \Omega \times (0, T) \to \mathbb{R}` is the unknown function,
:math:`\Omega \subset \mathbb{R}^d` is the spatial domain and
:math:`(0, T) \subset \mathbb{R}` is the time domain.

The equation is solved on the hypercube :math:`\Omega = (-1, 1)^d`. The
homogeneous Dirichlet boundary condition is imposed strongly: the network
output is multiplied by a function that vanishes on :math:`\partial \Omega`.
Natural gradient preconditioning is used.

The reference solution is

.. math::

    u(t, x) = e^{-d \pi^2 t} \prod_{i=1}^{d} \sin(\pi x_i),

and the final time :math:`T` is chosen so that its amplitude
:math:`e^{-d \pi^2 T}` equals ``FINAL_AMPLITUDE``.

The equation is discretized in time with a DIRK(4,5) method, and the resulting
spatial problems are solved using a PINN.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_nd import HypercubeND
from scimba_jax.linear_approximation.time_integrators.butcher_tableau import (
    build_dirk_4_5_tableau,
)
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.discrete_pinns import (
    DiscretePINN,
)
from scimba_jax.physical_models.temporal_pde.heat_equations import SemiDiscreteHeatND
from scimba_jax.plots.plots_nd import plot_abstract_approx_spaces

DIMENSION = 4
N_COLLOC = 1500 * DIMENSION
N_EPOCHS_INIT = 250 + 50 * DIMENSION
N_EPOCHS = 5

X_MIN, X_MAX = -1.0, 1.0
DOM_X = HypercubeND([(X_MIN, X_MAX)] * DIMENSION, is_main_domain=True)
SAMPLER = TensorizedSampler([DomainSampler(DOM_X)], bc=False)

# Choose T so that the decay factor exp(-d * pi^2 * T) of the exact solution
# equals FINAL_AMPLITUDE, i.e. T = -log(FINAL_AMPLITUDE) / (d * pi^2).
FINAL_AMPLITUDE = 0.2
FINAL_TIME = -jnp.log(FINAL_AMPLITUDE).item() / (DIMENSION * jnp.pi**2)
DOM_T = (0.0, FINAL_TIME)


def exact_sol(t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Evaluate the reference solution ``exp(-d pi^2 t) * prod_i sin(pi x_i)``.

    Args:
        t: Time. Either a scalar or an array broadcastable against
            ``x.shape[:-1] + (1,)`` (e.g. shape ``(n, 1)``).
        x: Spatial points. The product over coordinates runs over the last
            axis.

    Returns:
        The broadcast of the time factor against the spatial factor of shape
        ``x.shape[:-1] + (1,)``.
    """
    sines = jnp.prod(jnp.sin(jnp.pi * x), axis=-1, keepdims=True)
    decay = jnp.exp(-t * DIMENSION * jnp.pi**2)
    return decay * sines


def f_init(x: jnp.ndarray) -> jnp.ndarray:
    """Initial condition ``u_0(x) = u(0, x)`` taken from the reference solution."""
    return exact_sol(0, x)


def exact_sol_final(x: jnp.ndarray) -> jnp.ndarray:
    """Reference solution at ``t = FINAL_TIME``, evaluated at the points ``x``.

    The time is passed as an ``(n, 1)`` column, where ``n = x.shape[0]``.
    """
    return exact_sol(jnp.ones((x.shape[0], 1)) * FINAL_TIME, x)


def post_processing(approx: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Multiply the network output by a function that vanishes on the boundary.

    ``phi(x) = prod_i (x_i - X_MIN) * (X_MAX - x_i)`` is zero on every face of
    the hypercube, so the returned approximation satisfies the homogeneous
    Dirichlet condition exactly (strong imposition).
    """
    phi = jnp.prod((x - X_MIN) * (X_MAX - x), axis=-1, keepdims=True)
    return approx * phi


# create the model
in_size = DIMENSION
out_size = 1
nt = 10  # NOTE(review): presumably the number of time steps over DOM_T — confirm

pde = SemiDiscreteHeatND(main_domain=DOM_X, time_domain=DOM_T, bc="strong")
discrete_pinn = DiscretePINN(
    DOM_X, DOM_T, SAMPLER, out_size, nt, build_dirk_4_5_tableau(), None, pde
)

# Split the root key so the network initialisation and the key handed to
# `discrete_pinn.initialize` (which presumably drives collocation sampling)
# draw from independent streams. Reusing one JAX key for both would
# correlate them.
key = jax.random.PRNGKey(0)
key, nn_key = jax.random.split(key)
nn = MLP(
    in_size=in_size,
    out_size=out_size,
    hidden_sizes=[12] * DIMENSION,
    key=nn_key,
)
space = ApproximationSpace(
    {"x": in_size},
    [(nn, "scalar", None)],
    model_type="x",
    post_processing=post_processing,
)

# fit the approximation space to the initial condition
key, space = discrete_pinn.initialize(
    key, space, f_init, N_EPOCHS_INIT, N_COLLOC, matrix_regularization=1e-8
)

if DIMENSION <= 2:
    plot_abstract_approx_spaces(
        [space], DOM_X, solution=f_init, error=f_init, title="initial condition"
    )

# train the discrete pinn for the PDE
key, new_space = discrete_pinn.solve(key, space, N_EPOCHS, N_COLLOC)

# %%
key, l2, linf = discrete_pinn.compute_relative_error(
    key, space, 0, exact_solution=exact_sol
)
print(
    f"initial condition: relative L2 error = {l2:.2e}, relative Linf error = {linf:.2e}"
)

key, l2, linf = discrete_pinn.compute_relative_error(
    key, new_space, FINAL_TIME, exact_solution=exact_sol
)
print(
    f"DIRK(4,5) method: relative L2 error = {l2:.2e}, relative Linf error = {linf:.2e}"
)

if DIMENSION <= 2:
    plot_abstract_approx_spaces(
        [new_space],
        DOM_X,
        solution=exact_sol_final,
        error=exact_sol_final,
        title=f"solution at final time t={FINAL_TIME}",
        titles=["solution with DIRK(4,5) method"],
    )

plt.show()

# %%