r"""Solve a 4D Poisson problem with homogeneous Dirichlet boundary conditions using PINNs.

.. math::

    -\Delta u & = f \quad \text{in } \Omega, \\
    u & = 0 \quad \text{on } \partial \Omega,

where :math:`x = (x_1, x_2, x_3, x_4) \in \Omega = (-1, 1)^4` and :math:`f` is
chosen such that the exact solution is

.. math::

    u(x) = \prod_{i=1}^{4} \sin(\pi x_i),

which gives :math:`f(x) = 4\pi^2 \prod_{i=1}^{4} \sin(\pi x_i)`. Since
:math:`\sin(\pm \pi) = 0`, this exact solution satisfies the homogeneous
Dirichlet boundary condition.

Two PINNs are trained:

1. one where the boundary condition is enforced strongly, via a post-processing
   of the network output that vanishes on :math:`\partial \Omega`;
2. one where the boundary condition is enforced weakly (``bc="weak"``), using
   additional boundary collocation points.

In both cases the neural network is a simple MLP (Multilayer Perceptron).
"""

# %%

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_nd import HypercubeND
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.elliptic_pde.laplacians import LaplacianDirichletND

N_COLLOC = 10000  # number of interior collocation points
N_BC_COLLOC = 10000  # number of boundary collocation points (weak-BC PINN only)
N_EPOCHS = 200
N_TEST = 5000  # number of random interior points used to estimate the L2 error

# Root key. Every consumer below gets its own split/advanced key: reusing a JAX
# key for several random draws makes those draws correlated.
key = jax.random.PRNGKey(0)
dim = 4
domain_x = [(-1.0, 1.0)] * dim


def exact_sol(x: jnp.ndarray) -> jnp.ndarray:
    """Exact solution u = prod_i sin(pi * x_i), batched over leading axes.

    Coordinates are read along the last axis; a trailing axis of size 1 is kept.
    """
    return jnp.prod(jnp.sin(jnp.pi * x), axis=-1, keepdims=True)


def f_rhs(x: jnp.ndarray) -> jnp.ndarray:
    """Right-hand side f = -Laplacian(u) = dim * pi^2 * prod_i sin(pi * x_i)."""
    return dim * jnp.pi**2 * exact_sol(x)


def post_processing(approx: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Enforce the homogeneous Dirichlet BC strongly.

    The network output is multiplied by prod_i (1 - x_i^2), which vanishes on
    the boundary of (-1, 1)^dim. The product is taken over the last axis only
    (keeping it, like ``exact_sol``), so that a batch of shape ``(n, dim)`` gets
    one factor per point instead of a single factor for the whole batch.

    NOTE(review): assumes ``approx`` carries a trailing axis of size 1, like
    the ``out_size=1`` MLP output; confirm against the ApproximationSpace caller.
    """
    return approx * jnp.prod(1.0 - x**2, axis=-1, keepdims=True)


dx = HypercubeND(domain_x, is_main_domain=True)
sampler = TensorizedSampler([DomainSampler(dx)], bc=False)
# Interior-only sampler kept aside to draw test points for both PINNs, because
# ``sampler`` is later rebound to a boundary-aware sampler for the weak-BC PINN.
test_sampler = sampler


def compute_l2_error(pinn, key, n_test=N_TEST):
    """Estimate the error of ``pinn`` against ``exact_sol`` on random interior points.

    Returns ``(key, error)``, where ``key`` is the advanced PRNG key and
    ``error`` is the root-mean-square of ``pinn - exact_sol`` over ``n_test``
    points, i.e. a Monte Carlo estimate of the L2 error normalized by the
    square root of the domain volume.
    """
    key, sample_test = test_sampler.sample(key, n_test)
    # sample_test maps domain labels to tuples of inputs; in this
    # parameter-free problem each tuple unpacks as the 1-tuple ``(x,)``.
    (x_test,) = next(iter(sample_test.values()))
    diff = pinn.evaluate(x_test) - exact_sol(x_test)
    return key, jnp.sqrt(jnp.mean(diff**2))


# %%

print(
    "@@@@@@@@@@@@@@@ create a PINN with strong BC (4D Laplacian) @@@@@@@@@@@@@@@@@@@@@"
)

key, init_key = jax.random.split(key)
nn = MLP(in_size=dim, out_size=1, hidden_sizes=[24, 24, 24], key=init_key)
space = ApproximationSpace(
    {"x": dim},
    [(nn, "scalar", None)],
    model_type="x",
    post_processing=post_processing,
)
model = LaplacianDirichletND(dx, lambda *args: f_rhs(*args), bc="strong")
pinn = Projector(model, space, sampler)

key, sample = sampler.sample(key, N_COLLOC)
loss = pinn.evaluate_loss(space, sample)
print("initial loss: ", loss)

# %%

print("@@@@@@@@@@@@@@@ train with Energy Natural Gradient @@@@@@@@@@@@@@@@@@@@@")

start = timeit.default_timer()
key, pinn = pinn.project(key, space, N_EPOCHS, N_COLLOC)
end = timeit.default_timer()

print("best loss: ", pinn.best_loss)
print(f"time for {N_EPOCHS} epochs: {end - start:.2f}s")

# %%


def plot_2d_proj(pinn, n_visu=128, val_fix=0.5):
    """Plot the (x1, x2) cut of the PINN solution with x3 = x4 = ``val_fix``.

    Shows, side by side, the PINN prediction, the exact solution and their
    absolute difference on an ``n_visu`` x ``n_visu`` grid of [-1, 1]^2.

    Args:
        pinn: projector exposing ``evaluate(x)`` on a batch of 4D points.
        n_visu: number of grid points per direction.
        val_fix: value at which x3 and x4 are frozen. Avoid 0 and +-1:
            sin(pi * 0) = sin(+-pi) = 0 makes the exact solution identically
            zero on the cut.
    """
    lin = jnp.linspace(-1, 1, n_visu)
    X1, X2 = jnp.meshgrid(lin, lin)
    x_flat = jnp.stack([X1.ravel(), X2.ravel()], axis=-1)
    # Append the two frozen coordinates (x3, x4) to every grid point.
    x_cut12 = jnp.concatenate(
        [x_flat, val_fix * jnp.ones((x_flat.shape[0], 2))], axis=-1
    )

    u_pred = pinn.evaluate(x_cut12).reshape(n_visu, n_visu)
    u_exact = exact_sol(x_cut12).reshape(n_visu, n_visu)
    error = jnp.abs(u_pred - u_exact)

    cut = f"x3=x4={val_fix}"
    panels = [
        (u_pred, "turbo", f"PINN solution ({cut})"),
        (u_exact, "turbo", f"Exact solution ({cut})"),
        (error, "inferno", f"Absolute error (max={float(error.max()):.2e})"),
    ]

    _, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, (values, cmap, title) in zip(axes, panels):
        im = ax.contourf(X1, X2, values, levels=32, cmap=cmap)
        plt.colorbar(im, ax=ax)
        ax.set_title(title)
        ax.set_xlabel("x1")
        ax.set_ylabel("x2")

    plt.suptitle(f"4D Laplacian — cut at {cut}")
    plt.tight_layout()
    plt.show()


# %%

# Error on random interior points, drawn from a fresh key (not the one used
# for network initialization or training).
key, l2_error = compute_l2_error(pinn, key)
print(f"L2 error on {N_TEST} test points: {float(l2_error):.4e}")

plot_2d_proj(pinn)

sampler = TensorizedSampler([DomainSampler(dx)], bc=True)

# %%

print("@@@@@@@@@@@@@@@ create a PINN with weak BC (4D Laplacian) @@@@@@@@@@@@@@@@@@@@@")

key, init_key = jax.random.split(key)
nn = MLP(in_size=dim, out_size=1, hidden_sizes=[24, 24, 24], key=init_key)
space = ApproximationSpace(
    {"x": dim},
    [(nn, "scalar", None)],
    model_type="x",
)
model = LaplacianDirichletND(dx, lambda *args: f_rhs(*args), bc="weak")
pinn = Projector(model, space, sampler, matrix_regularization=1e-4)

key, sample = sampler.sample(key, N_COLLOC, N_BC_COLLOC)
losses = pinn.evaluate_losses(space, sample)
print("initial losses: ", losses)

print("@@@@@@@@@@@@@@@ train with Energy Natural Gradient @@@@@@@@@@@@@@@@@@@@@")

start = timeit.default_timer()
key, pinn = pinn.project(key, space, N_EPOCHS, N_COLLOC, N_BC_COLLOC)
end = timeit.default_timer()

print("best loss: ", pinn.best_loss)
print(f"time for {N_EPOCHS} epochs: {end - start:.2f}s")

key, l2_error = compute_l2_error(pinn, key)
print(f"L2 error on {N_TEST} test points: {float(l2_error):.4e}")

plot_2d_proj(pinn)