r"""Solves a 2D Poisson PDE with Dirichlet and Neumann BCs using PINNs.

.. math::

    -\Delta u & = f \quad \text{in } \Omega \times M \\
    u & = g_D \quad \text{on } \Gamma_D \\
    \frac{\partial u}{\partial n} & = g_N \quad \text{on } \Gamma_N

where :math:`x = (x_1, x_2) \in \Omega = \mathcal{D}` (with
:math:`\mathcal{D} = [0,1]^2`), :math:`f = 1`, :math:`g_D = \mu` on the
Dirichlet boundary, :math:`g_N = 0` on the Neumann boundary, and
:math:`\mu \in M = [1, 2]`.

The mixed boundary conditions (Neumann on the top boundary of the square,
Dirichlet on the other ones) are enforced weakly.

The neural network used is a simple MLP (Multilayer Perceptron), and the
optimization is done using SSBFGS and Natural Gradient Descent (ENG).
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
    UniformParametricSampler,
)
from scimba_jax.nonlinear_approximation.model_class.funcparam_scalar import (
    ParamScalarFunction,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.abstract_physical_model import AbstractPhysicalModel
from scimba_jax.physical_models.abstract_residuals import InteriorResidual
from scimba_jax.physical_models.boundary_residuals import (
    DirichletResidual,
    NeumannResidual,
)
from scimba_jax.plots.plots_nd import (
    plot_abstract_approx_spaces,
)

# Collocation-point counts and epoch budgets for the two optimizers.
# BUG FIX: N_COLLOC was commented out in the original, yet it is used below in
# sampler.sample(...) and both pinn.project(...) calls, which raised a
# NameError at runtime.  Restored to its original value.
N_COLLOC = 2000
N_BC_COLLOC = 3000
N_EPOCHS_ENG = 200
N_EPOCHS_SSBFGS = 1500

key = jax.random.PRNGKey(0)


def f_rhs(xy: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    """Interior right-hand side f = 1 (constant source term); mu is unused.

    NOTE(review): indexes ``xy[0:1]`` while the BC functions index
    ``x[..., 0:1]`` — presumably equivalent for the sample shapes used here;
    verify against the sampler's output layout.
    """
    return jnp.ones_like(xy[0:1])


def f_bc_rhs_d(x, n, mu):
    """Dirichlet boundary data: g_D = mu (constant over the boundary)."""
    return mu[0:1] * jnp.ones_like(x[..., 0:1])


def f_bc_rhs_n(x, n, mu):
    """Neumann boundary data: g_N = 0 (homogeneous)."""
    return jnp.zeros_like(x[..., 0:1])


# Map each boundary label to its boundary-condition right-hand side.
f_bc_rhs = {"bc D": f_bc_rhs_d, "bc N": f_bc_rhs_n}


class ParametricLaplacianResidual(InteriorResidual):
    """Interior residual -Δu for the parametric Poisson problem."""

    def __init__(self, domain, f_rhs):
        super().__init__(domain=domain, size=1, model_type="x_mu", f_rhs=f_rhs)

    def construct_residual(self, *vars):
        # The first variable is the scalar unknown u(x; mu); the PDE residual
        # is its negative spatial Laplacian (the f_rhs side is handled by the
        # base class).
        rho = vars[0]
        assert isinstance(rho, ParamScalarFunction)
        return -rho.laplacian_x()


class MixedBCLaplacian(AbstractPhysicalModel):
    """Poisson model with Dirichlet ("bc D") and Neumann ("bc N") boundaries.

    Both boundary conditions are enforced weakly via dedicated residuals.
    """

    def __init__(self, main_domain, f_rhs, f_bc_rhs):
        super().__init__(main_domain=main_domain)
        self.physical_residuals = {
            "interior": ParametricLaplacianResidual(main_domain, f_rhs),
            "bc D": DirichletResidual(
                domain=self.boundaries["bc D"], f_rhs=f_bc_rhs["bc D"]
            ),
            "bc N": NeumannResidual(
                domain=self.boundaries["bc N"], f_rhs=f_bc_rhs["bc N"]
            ),
        }


# Parameter domain M = [1, 2] and spatial domain [0,1]^2; the north edge is
# Neumann, the other three edges are Dirichlet.
domain_mu = [(1.0, 2.0)]
domain_x = Square2D([[0, 1], [0, 1]], is_main_domain=True)
domain_x.set_boundaries_dict(
    {
        "bc N": ["bc north"],
        "bc D": ["bc east", "bc south", "bc west"],
    }
)

sampler = TensorizedSampler(
    [
        DomainSampler(domain_x),
        UniformParametricSampler(domain_mu),
    ],
    bc=True,
)

# Re-seeding with the same value as above is a no-op here (the key has not
# been consumed in between); kept for parity with the original script.
key = jax.random.PRNGKey(0)
nn = MLP(in_size=3, out_size=1, hidden_sizes=[20] * 3, key=key)
space = ApproximationSpace({"x": 2, "mu": 1}, [(nn, "scalar", 1)], model_type="x_mu")

model = MixedBCLaplacian(domain_x, f_rhs, f_bc_rhs)

# Loss weights: boundary residuals are weighted 30x relative to the interior.
weights = {"interior": [1.0], "bc N": [30.0], "bc D": [30.0]}

print("\n\n")
print("@@@@@@@@@@@@@@@ test SS-BFGS @@@@@@@@@@@@@@@@@@@@@@")
pinn = Projector(
    model,
    space,
    sampler,
    optimizer="SS-BFGS",
    weights=weights,
)

key, sample_dict = sampler.sample(key, N_COLLOC, N_BC_COLLOC)
loss0 = pinn.evaluate_loss(space, sample_dict)
print("initial loss: ", loss0)

start = timeit.default_timer()
new_loss, nspace, key, loss_history = pinn.project(
    space, key, N_EPOCHS_SSBFGS, N_COLLOC, N_BC_COLLOC
)
pinn.losses.loss_history = loss_history
pinn.best_loss = new_loss
pinn.space = nspace
end = timeit.default_timer()
print("best loss: ", new_loss)
print("time for %d epochs: " % N_EPOCHS_SSBFGS, end - start)

print("\n\n")
print("@@@@@@@@@@@@@@@ test ENG @@@@@@@@@@@@@@@@@@@@@@")
# Fresh network/space for the second run; `key` has been advanced by the
# SS-BFGS run above, so this MLP gets a different initialization.
nn = MLP(in_size=3, out_size=1, hidden_sizes=[20] * 3, key=key)
space2 = ApproximationSpace({"x": 2, "mu": 1}, [(nn, "scalar", 1)], model_type="x_mu")

# NOTE(review): no explicit optimizer= kwarg — presumably ENG is Projector's
# default; confirm against the Projector API.
# NOTE(review): "matrix_reguarization" looks like a typo of
# "matrix_regularization", but it is kept as-is because it must match the
# actual keyword in Projector's signature — verify there.
pinn2 = Projector(
    model,
    space2,
    sampler,
    matrix_reguarization=5e-4,
    weights=weights,
)

start = timeit.default_timer()
new_loss, nspace, key, loss_history = pinn2.project(
    space2, key, N_EPOCHS_ENG, N_COLLOC, N_BC_COLLOC
)
pinn2.losses.loss_history = loss_history
pinn2.best_loss = new_loss
pinn2.space = nspace
end = timeit.default_timer()
print("best loss: ", new_loss)
print("time for %d epochs: " % N_EPOCHS_ENG, end - start)

plot_abstract_approx_spaces(
    (pinn.space, pinn2.space),  # the approximation space
    domain_x,
    domain_mu,
    loss=(pinn.losses, pinn2.losses),
    residual=(pinn.model, pinn2.model),
    draw_contours=True,
    n_drawn_contours=20,
)

plt.show()