# %%
r"""Solve the heat equation in :math:`d` space dimensions using a PINN.

.. math::

    \partial_t u - \Delta u & = 0 \quad \text{in } \Omega \times (0, T), \\
    u & = 0 \quad \text{on } \partial \Omega \times (0, T), \\
    u & = u_0 \quad \text{on } \Omega \times \{0\},

where :math:`u: \Omega \times (0, T) \to \mathbb{R}` is the unknown function,
:math:`\Omega = (-1, 1)^d \subset \mathbb{R}^d` is the spatial domain
(:math:`d` = ``DIMENSION``) and :math:`(0, T) \subset \mathbb{R}` is the time
domain.

The initial condition is a product of sines,
:math:`u_0(x) = \prod_{i=1}^d \sin(\pi x_i)`, so the exact solution is
:math:`u(t, x) = e^{-d \pi^2 t} u_0(x)`; it is used as the reference to
measure the error of the trained network. The final time :math:`T` is chosen
so that the solution has decayed to ``FINAL_AMPLITUDE`` times its initial
amplitude.

The equation is solved on a hypercube domain. The boundary and initial
conditions are imposed strongly through ``post_processing``, which writes the
approximation as
:math:`u_\theta(t, x) = u_0(x) + t \, \varphi(x) \, N_\theta(t, x)` with
:math:`\varphi(x) = \prod_{i=1}^d (x_i - x_{\min})(x_{\max} - x_i)`, a factor
that vanishes on :math:`\partial \Omega`. The network is trained with the
``Projector`` solver of ``scimba_jax``.
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_nd import HypercubeND
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.model_class.funcparam_vectorial import (
    ParamVecFunction,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.temporal_pde.heat_equations import HeatND
from scimba_jax.plots.plots_nd import (
    plot_abstract_approx_space,
)

DIMENSION = 4
N_COLLOC = 1500 * (DIMENSION + 1)
N_EPOCHS = 50 * 5 + 250 + 50 * (DIMENSION + 1)

X_MIN, X_MAX = -1.0, 1.0
DOM_X = HypercubeND([(X_MIN, X_MAX)] * DIMENSION, is_main_domain=True)

# The exact solution decays like exp(-d * pi^2 * t); choose T such that
# exp(-d * pi^2 * T) == FINAL_AMPLITUDE, i.e. T = -log(A) / (d * pi^2).
FINAL_AMPLITUDE = 0.2
FINAL_TIME = (-jnp.log(FINAL_AMPLITUDE) / (DIMENSION * jnp.pi**2)).item()
DOM_T = (0.0, FINAL_TIME)

# Samples (t, x) in DOM_T x DOM_X. bc/ic sampling is disabled -- presumably
# because both conditions are imposed strongly (see post_processing).
SAMPLER = TensorizedSampler(
    [UniformTimeSampler(DOM_T), DomainSampler(DOM_X)],
    model_type="t_x",
    bc=False,
    ic=False,
)


def exact_sol(t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Return the exact solution ``exp(-d pi^2 t) * prod_i sin(pi x_i)``.

    Args:
        t: Evaluation times; must broadcast against the ``(..., 1)`` product
            of sines (a plain scalar is accepted, see ``f_init``).
        x: Points whose last axis holds the ``DIMENSION`` space coordinates.

    Returns:
        The solution values, with a trailing axis of size 1.
    """
    sines = jnp.prod(jnp.sin(jnp.pi * x), axis=-1, keepdims=True)
    exp = jnp.exp(-t * DIMENSION * jnp.pi**2)
    return exp * sines


def f_init(x: jnp.ndarray) -> jnp.ndarray:
    """Return the initial condition ``u_0(x) = prod_i sin(pi x_i)``."""
    return exact_sol(0, x)


def post_processing(
    approx: jnp.ndarray, t: jnp.ndarray, x: jnp.ndarray
) -> jnp.ndarray:
    """Impose the initial and boundary conditions strongly.

    The raw network output ``approx`` is multiplied by ``t`` (so the result
    equals ``f_init(x)`` at ``t = 0``) and by ``phi``, which vanishes on the
    faces of the hypercube (so the result equals ``f_init(x)`` there, which
    is zero for the default domain ``(-1, 1)^d``).

    Args:
        approx: Raw network output.
        t: Times, with a trailing axis of size 1.
        x: Points whose last axis holds the space coordinates.

    Returns:
        The post-processed approximation.
    """
    phi = jnp.prod((x - X_MIN) * (X_MAX - x), axis=-1, keepdims=True)
    return f_init(x) + t * approx * phi


# %% create and train the PINN

# create the model
# NOTE(review): exact_sol only solves this problem if HeatND defaults to a
# unit diffusivity and a zero source term -- confirm against HeatND.
model = HeatND(
    main_domain=DOM_X,
    time_domain=DOM_T,
    bc="strong",
    ic="strong",
    f_ic_rhs=lambda *args: f_init(*args),
)

# create the approximation space
key = jax.random.PRNGKey(0)
nn = MLP(
    in_size=DIMENSION + 1,
    out_size=1,
    hidden_sizes=[12] * (DIMENSION + 1),
    key=key,
)
space = ApproximationSpace(
    {"x": DIMENSION, "t": 1},
    [(nn, "scalar", None)],
    model_type="t_x",
    post_processing=post_processing,
)

pinn = Projector(model, space, SAMPLER, matrix_regulazition=1e-8)

start = timeit.default_timer()
key, pinn = pinn.project(key, space, N_EPOCHS, N_COLLOC)
end = timeit.default_timer()
# Wall-clock time; JAX dispatches work asynchronously, so device work still
# pending when project() returns may not be included.
print(f"PINN: training time = {end - start:.2f} s")

# %% exploit the results


def get_t_eval(t, x):
    """Return a ``(x.shape[0], 1)`` column filled with the time ``t``."""
    return jnp.ones((x.shape[0], 1)) * t


def compute_relative_error(space, key, t=FINAL_TIME, n_error=150_000):
    """Compute the relative L2 and Linf errors of ``space`` at time ``t``.

    Spatial points are drawn from the interior samples of ``SAMPLER``; the
    sampled times are discarded and replaced by the fixed time ``t``. Both
    norms are discrete approximations over the ``n_error`` sampled points.

    Args:
        space: Trained approximation space to evaluate.
        key: JAX PRNG key used to sample the evaluation points.
        t: Evaluation time (defaults to the final time).
        n_error: Number of evaluation points.

    Returns:
        ``(relative_l2, relative_linf, key)``, where ``key`` is the PRNG key
        returned by the sampler, to be used for subsequent random draws.
    """
    key, subkey = jax.random.split(key)
    key, sample = SAMPLER.sample(subkey, n_error)
    _, x = sample["interior"]
    t = get_t_eval(t, x)

    variables = space.create_variables()
    all_together = ParamVecFunction.cat(variables)
    batched_func = all_together.vmap_on_physical_variables()

    u_pred = jax.device_get(batched_func(space, t, x))
    u_exact = exact_sol(t, x)
    error = jnp.abs(u_pred - u_exact)

    relative_l2 = jnp.sqrt(jnp.mean(error**2)) / jnp.sqrt(jnp.mean(u_exact**2))
    relative_linf = jnp.max(error) / jnp.max(jnp.abs(u_exact))
    return relative_l2, relative_linf, key


l2, linf, key = compute_relative_error(pinn.space, key)
print(f"PINN: relative L2 error = {l2:.2e}, relative Linf error = {linf:.2e}")

# Plots are only produced in low dimension (presumably the plotting helper
# only supports d <= 2).
if DIMENSION <= 2:
    plot_abstract_approx_space(
        pinn.space,
        DOM_X,
        time_domain=DOM_T,
        time_values=(DOM_T[0], DOM_T[1] / 2, DOM_T[1]),
        loss=pinn.losses,
        solution=exact_sol,
        error=exact_sol,
        title=f"learning sol of {DIMENSION}D heat equation with TemporalPinns",
    )
    plt.show()

# %%