r"""Solves a 1D Poisson PDE with Dirichlet boundary conditions using PINNs. .. math:: -\delta u & = f in \Omega where :math:`x \in \Omega = (0, 1)` and :math:`f` such that :math:`u(x) = \sin(\pi x)`, and :math:`g = 0`. The boundary conditions are Dirichlet conditions enforced strongly. The neural network is a simple MLP (Multilayer Perceptron). The optimization is done using a classical PINN with Natural Gradient Descent, then with a FBPINN Natural Gradient Descent. """ # %% import timeit import jax import jax.numpy as jnp import matplotlib.pyplot as plt from scimba_jax.domains.meshless_domains.domains_1d import Segment1D from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import ( ApproximationSpace, FiniteBasisApproximationSpace, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo import ( DomainSampler, TensorizedSampler, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import ( UniformParametricSampler, ) from scimba_jax.nonlinear_approximation.networks.mlp import MLP from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector from scimba_jax.physical_models.elliptic_pde.laplacians import LaplacianDirichletND from scimba_jax.plots.plots_nd import ( plot_abstract_approx_spaces, ) N_COLLOC = 1000 N_BC_COLLOC = 2000 N_EPOCHS = 50 N_EPOCHS_ADAM = 1000 def f_rhs(x: jnp.ndarray, mu: jnp.ndarray): return jnp.pi**2 * jnp.sin(jnp.pi * x) def f_bc(x: jnp.ndarray, n: jnp.ndarray, mu: jnp.ndarray): return x * 0 def exact_sol(x: jnp.ndarray, mu: jnp.ndarray): return jnp.sin(jnp.pi * x) def post_processing(approx: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray): return approx * x * (1 - x) domain_x = (0, 1) domain_mu = [] dx = Segment1D(domain_x, is_main_domain=True) sampler = TensorizedSampler( [DomainSampler(dx), UniformParametricSampler(domain_mu)], bc=False ) # %% print("\n\n") print("@@@@@@@@@@@@@@@ training a classic PINN @@@@@@@@@@@@@@@@@@@@@@") key = jax.random.PRNGKey(0) nn = MLP(in_size=1, out_size=1, hidden_sizes=[32, 32], key=key) space = ApproximationSpace( {"x": 1}, [(nn, "scalar", None)], model_type="x_mu", post_processing=post_processing ) model = LaplacianDirichletND(dx, f_rhs, bc="strong") pinn = Projector(model, space, sampler, linesearch="armijo") start = timeit.default_timer() new_loss, nspace, key, loss_history = pinn.project(space, key, N_EPOCHS, N_COLLOC) pinn.losses.loss_history = loss_history pinn.space = nspace pinn.best_loss = new_loss end = timeit.default_timer() print("best loss: ", new_loss) print("time for %d epochs: " % N_EPOCHS, end - start) # %% print("\n\n") print("@@@@@@@@@@@@@@@ training an FBPINN @@@@@@@@@@@@@@@@@@@@@@") key = jax.random.PRNGKey(0) n_subdomains = 4 overlap = 1.5 def window_function_cos(x, i): center = (i + 1 / 2) / n_subdomains width = (1 / n_subdomains) * overlap / 2 in_subdomain = (center - width < x) & (x < center + width) return (1 + jnp.cos(jnp.pi * (x - center) / width)) ** 2 * in_subdomain def window_function(x, i): sum_windows = 0 for j in range(n_subdomains): sum_windows += window_function_cos(x, j) return window_function_cos(x, i) / sum_windows if plot_window_functions := False: x = jnp.linspace(0, 1, 100) for i in range(n_subdomains): plt.plot(x, window_function(x, i), label=f"window {i}") sum_windows = 0 for i in range(n_subdomains): sum_windows += window_function(x, i) plt.plot(x, sum_windows, label="sum of windows", linestyle="--") plt.legend() dx = Segment1D(domain_x, is_main_domain=True) # for i in range(n_subdomains): # center = (i + 

dx = Segment1D(domain_x, is_main_domain=True)
# for i in range(n_subdomains):
#     center = (i + 1 / 2) / n_subdomains
#     width = (1 / n_subdomains) * overlap / 2
#     x_min = max(0, center - width)
#     x_max = min(1, center + width)
#     subdomain = Segment1D(
#         (x_min, x_max), is_main_domain=False, label_str="subdomain", label_idx=i
#     )
#     dx.add_subdomain(subdomain)

sampler_x = DomainSampler(dx)
sampler_mu = UniformParametricSampler(domain_mu)
sampler = TensorizedSampler([sampler_x, sampler_mu], bc=False)

key = jax.random.PRNGKey(0)
# Use one subkey per subdomain so the subnetworks are not initialized identically.
subkeys = jax.random.split(key, n_subdomains)
nn = [
    MLP(in_size=1, out_size=1, hidden_sizes=[16, 16], key=k) for k in subkeys
]

space2 = FiniteBasisApproximationSpace(
    {"x": 1},
    [(nn, "scalar", None)],
    window_function=window_function,
    model_type="x_mu",
    post_processing=post_processing,
)

model = LaplacianDirichletND(dx, f_rhs, bc="strong")

key, sample_dict = sampler.sample(key, N_COLLOC)

pinn2 = Projector(
    model,
    space2,
    sampler,
    linesearch="armijo",
    block_diagonal_preconditioning=True,
    n_subdomains=n_subdomains,
    truncate_jacobian_svd=False,
    truncate_jacobian_svd_threshold=0.02,
)

start = timeit.default_timer()
new_loss, nspace, key, loss_history = pinn2.project(space2, key, N_EPOCHS, N_COLLOC)
pinn2.losses.loss_history = loss_history
pinn2.best_loss = new_loss
pinn2.space = nspace
end = timeit.default_timer()

print("best loss: ", new_loss)
print("time for %d epochs: " % N_EPOCHS, end - start)

plot_abstract_approx_spaces(
    (pinn.space, pinn2.space),
    dx,
    domain_mu,
    loss=(pinn.losses, pinn2.losses),
    residual=(pinn.model, pinn2.model),
    solution=exact_sol,
    error=exact_sol,
    title="learning the solution of the 1D Laplacian",
    titles=("with PINN", "with FBPINN"),
)
plt.show()

pinn2.plot_gram_matrix()

# %%
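# Illustrative sanity check (an addition, not part of the original example):
# verify with jax.grad that the manufactured solution u(x) = sin(pi x) satisfies
# -u''(x) = f_rhs(x, mu) and vanishes at x = 0 and x = 1, i.e. the problem both
# projectors above were trained on.
def _u_exact_scalar(x):
    return jnp.sin(jnp.pi * x)


_d2u = jax.grad(jax.grad(_u_exact_scalar))
_xs_pde = jnp.linspace(0.05, 0.95, 10)
_pde_residual = jax.vmap(lambda x: -_d2u(x) - f_rhs(x, None))(_xs_pde)
print("max PDE residual of the exact solution:", jnp.max(jnp.abs(_pde_residual)))
print("boundary values:", _u_exact_scalar(0.0), _u_exact_scalar(1.0))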