r"""Solve a 1D Poisson problem with Dirichlet boundary conditions using FBPINNs.

.. math::

    -\Delta u & = f \quad \text{in } \Omega, \\
    u & = g \quad \text{on } \partial \Omega,

where :math:`x \in \Omega = (0, 1)`, :math:`f` is chosen such that
:math:`u(x) = \sin(\pi x)`, and :math:`g = 0`.

The Dirichlet boundary conditions are enforced strongly: the network output is
multiplied by :math:`x (1 - x)`, which vanishes at both ends of the interval.

The approximation space is a finite-basis PINN (FBPINN): one small MLP per
subdomain, blended by normalized (partition-of-unity) cosine window functions.
Three natural-gradient variants are trained and compared:

1. natural gradient with explicit Gram matrix and block-diagonal preconditioning;
2. matrix-free natural gradient ("MFENG") without block-diagonal preconditioning;
3. matrix-free natural gradient ("MFENG") with block-diagonal preconditioning.
"""

# %%

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    FiniteBasisApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.elliptic_pde.laplacians import LaplacianDirichletND
from scimba_jax.plots.plots_nd import (
    plot_abstract_approx_spaces,
)

N_COLLOC = 1000
N_BC_COLLOC = 200
N_EPOCHS = 50
HIDDEN_SIZES = [16, 16]

# Plot the window functions (and their sum) before training.
PLOT_WINDOW_FUNCTIONS = True


def f_rhs(x: jnp.ndarray) -> jnp.ndarray:
    """Source term f(x) = pi^2 sin(pi x), so that u(x) = sin(pi x) solves -u'' = f."""
    return jnp.pi**2 * jnp.sin(jnp.pi * x)


def f_bc(x: jnp.ndarray, n: jnp.ndarray) -> jnp.ndarray:
    """Dirichlet data g = 0, returned with the shape of ``x``.

    Not used below: the boundary condition is built into ``post_processing``.
    ``n`` is ignored (presumably the boundary normal — confirm against scimba_jax).
    """
    return x * 0


def exact_sol(x: jnp.ndarray) -> jnp.ndarray:
    """Exact solution u(x) = sin(pi x)."""
    return jnp.sin(jnp.pi * x)


def post_processing(approx: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Multiply the raw approximation by x (1 - x).

    The factor vanishes at x = 0 and x = 1, so u = 0 holds exactly on the
    boundary (strong enforcement of the Dirichlet condition).
    """
    return approx * x * (1 - x)


domain_x = (0, 1)
dx = Segment1D(domain_x, is_main_domain=True)
sampler = TensorizedSampler([DomainSampler(dx)], bc=False)

n_subdomains = 4
overlap = 1.5

# With overlap <= 1 neighbouring windows are both exactly zero at their common
# edge (or leave a gap), so the normalization in ``window_function`` would
# compute 0 / 0 there.
if overlap <= 1:
    raise ValueError(f"overlap must be > 1 for the windows to cover (0, 1), got {overlap}")


def _print_banner(text: str) -> None:
    """Print a blank-line separator followed by an ``@``-framed section title."""
    print("\n\n")
    print(f"@@@@@@@@@@@@@@@ {text} @@@@@@@@@@@@@@@@@@@@@@")


_print_banner(f"{n_subdomains:d} subdomains, {overlap:.1f} overlap")


def window_function_cos(x, i):
    """Unnormalized smooth cosine bump attached to subdomain ``i``.

    Subdomain ``i`` is centred at (i + 1/2) / n_subdomains with half-width
    overlap / (2 n_subdomains). Strictly inside that interval the window equals
    (1 + cos(pi (x - center) / width))**2, which decays to zero with zero slope
    at the interval ends; outside the interval it is exactly zero.
    """
    center = (i + 1 / 2) / n_subdomains
    width = (1 / n_subdomains) * overlap / 2
    in_subdomain = (center - width < x) & (x < center + width)
    return (1 + jnp.cos(jnp.pi * (x - center) / width)) ** 2 * in_subdomain


def window_function(x, i):
    """Window of subdomain ``i`` divided by the sum of all windows.

    The normalized windows sum to one wherever at least one raw window is
    nonzero, i.e. they form a partition of unity on (0, 1) when overlap > 1.
    """
    sum_windows = sum(window_function_cos(x, j) for j in range(n_subdomains))
    return window_function_cos(x, i) / sum_windows


if PLOT_WINDOW_FUNCTIONS:
    x_plot = jnp.linspace(0, 1, 100)
    for i in range(n_subdomains):
        plt.plot(x_plot, window_function(x_plot, i), label=f"window {i}")
    sum_of_windows = sum(window_function(x_plot, i) for i in range(n_subdomains))
    plt.plot(x_plot, sum_of_windows, label="sum of windows", linestyle="--")
    plt.title(f"Window functions: {n_subdomains:d} subdomains, {overlap:.1f} overlap")
    plt.legend()
    plt.show()


def train_fbpinn(title: str, **projector_kwargs):
    """Build a fresh FBPINN, train it for N_EPOCHS epochs and report the result.

    Every run restarts from ``jax.random.PRNGKey(0)``. All variants therefore
    start from the same initial weights and PRNG stream, which keeps the
    comparison fair.

    Args:
        title: text printed in the section banner.
        **projector_kwargs: extra keyword arguments forwarded to ``Projector``
            (optimizer choice, preconditioning options, ...).

    Returns:
        The trained ``Projector`` returned by ``Projector.project``.
    """
    _print_banner(title)

    key = jax.random.PRNGKey(0)
    # One independent key per subdomain network. JAX keys must never be reused.
    # Sharing a single key would also make every subnetwork start identical.
    key, *net_keys = jax.random.split(key, n_subdomains + 1)
    nets = [
        MLP(in_size=1, out_size=1, hidden_sizes=HIDDEN_SIZES, key=net_key)
        for net_key in net_keys
    ]
    space = FiniteBasisApproximationSpace(
        {"x": 1},
        [(nets, "scalar", None)],
        window_function=window_function,
        model_type="x",
        post_processing=post_processing,
    )
    model = LaplacianDirichletND(dx, f_rhs, bc="strong")
    # The sampled points are not used here; the call only advances ``key``.
    key, _ = sampler.sample(key, N_COLLOC)

    pinn = Projector(model, space, sampler, **projector_kwargs)

    start = timeit.default_timer()
    key, pinn = pinn.project(key, space, N_EPOCHS, N_COLLOC, verbose=False)
    elapsed = timeit.default_timer() - start

    print("best loss: ", pinn.best_loss)
    print(f"time for {N_EPOCHS:d} epochs: ", elapsed)
    return pinn


# Options shared by the runs that use block-diagonal preconditioning.
block_precond_kwargs = dict(
    block_diagonal_preconditioning=True,
    n_subdomains=n_subdomains,
    truncate_jacobian_svd=False,
    truncate_jacobian_svd_threshold=0.02,
)
# Options shared by the matrix-free natural-gradient runs
# (linesearch="armijo" can also be passed here).
matrix_free_kwargs = dict(
    optimizer="MFENG",
    matrix_free=True,
    debug=False,
    rtol=1e-10,
)

pinn = train_fbpinn(
    "training an FBPINN, ng with Gram matrix computation and bloc diagonal precond",
    **block_precond_kwargs,
)

pinn2 = train_fbpinn(
    "training an FBPINN with matrix free ng without bloc diagonal precond",
    **matrix_free_kwargs,
    block_diagonal_preconditioning=False,
)
print("nb cg it: ", pinn2.optimizer.total_cg_exit)

pinn3 = train_fbpinn(
    "training an FBPINN with matrix free ng with bloc diagonal precond",
    **matrix_free_kwargs,
    **block_precond_kwargs,
)
print("nb cg it: ", pinn3.optimizer.total_cg_exit)

plot_abstract_approx_spaces(
    (pinn.space, pinn2.space, pinn3.space),
    dx,
    loss=(pinn.losses, pinn2.losses, pinn3.losses),
    residual=(pinn.model, pinn2.model, pinn3.model),
    solution=exact_sol,
    error=exact_sol,
    title="learning sol of 1D laplacian",
    titles=(
        "NG with Gram matrix computation, with bloc diagonal preconditioning",
        "NG matrix free, without bloc diagonal preconditioning",
        "NG matrix free, with bloc diagonal preconditioning",
    ),
)
plt.show()

pinn.plot_gram_matrix()

# %%