r"""Solves a parametric 2D Poisson PDE with Dirichlet boundary conditions using PINNs.

.. math::

    -\mu \Delta u & = f \quad \text{in } \Omega \times M, \\
    u & = g \quad \text{on } \partial \Omega \times M,

where :math:`x = (x_1, x_2) \in \Omega = (0, 1) \times (0, 1)`,
:math:`\mu \in M = [1, 2]`, :math:`g = 0`, and :math:`f` is chosen so that the
exact solution is :math:`u(x_1, x_2, \mu) = \mu \sin(2\pi x_1) \sin(2\pi x_2)`,
i.e. :math:`f = 8\pi^2 \mu^2 \sin(2\pi x_1) \sin(2\pi x_2)`.

The boundary condition is enforced weakly, through a penalty term weighted by
``bc_weight``. A supervised data loss on the exact solution, sampled on a
regular grid of :math:`(x_1, x_2, \mu)` points, is added to the PINN loss.

The approximation is a simple MLP (Multilayer Perceptron). Three training
strategies are compared:

* ``PinnsElliptic`` with the ``"ssbfgs"`` optimizer;
* ``NaturalGradientPinnsElliptic`` with its default natural-gradient
  algorithm (results saved under the ``"ENG"`` tag);
* ``NaturalGradientPinnsElliptic`` with ``ng_algo="ANaGRAM"``.
"""

# %%

import matplotlib.pyplot as plt
import numpy as np
import torch

from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_2d import Square2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.elliptic_pde.pinns import (
    NaturalGradientPinnsElliptic,
    PinnsElliptic,
)
from scimba_torch.optimizers.losses import DataLoss
from scimba_torch.physical_models.elliptic_pde.laplacians import (
    Laplacian2DDirichletStrongForm,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor

torch.manual_seed(0)

# Weight of the (weakly enforced) boundary-condition term in the loss.
bc_weight = 40.0


def f_rhs(x: LabelTensor, mu: LabelTensor):
    """Right-hand side f(x, mu) = 8 pi^2 mu^2 sin(2 pi x1) sin(2 pi x2).

    Chosen so that -mu * Laplacian(u) = f for the solution given by
    ``exact_sol``.
    """
    x1, x2 = x.get_components()
    mu1 = mu.get_components()
    return (
        mu1
        * mu1
        * 8.0
        * torch.pi
        * torch.pi
        * torch.sin(2.0 * torch.pi * x1)
        * torch.sin(2.0 * torch.pi * x2)
    )


def exact_sol(x: LabelTensor, mu: LabelTensor):
    """Exact solution u(x, mu) = mu sin(2 pi x1) sin(2 pi x2).

    It vanishes on the boundary of the unit square, matching g = 0.
    """
    x1, x2 = x.get_components()
    mu1 = mu.get_components()
    return mu1 * torch.sin(2.0 * torch.pi * x1) * torch.sin(2.0 * torch.pi * x2)


# Spatial domain (unit square) and parameter range for mu.
box_x = [(0.0, 1.0), (0.0, 1.0)]
parameters_domain = [(1.0, 2.0)]
domain_x = Square2D(box_x, is_main_domain=True)
# Collocation points are drawn jointly in space and parameter.
sampler = TensorizedSampler(
    [DomainSampler(domain_x), UniformParametricSampler(parameters_domain)]
)

# %%
### generate data

# Supervised data: the exact solution evaluated on a regular tensor grid with
# n_samples_per_dim points along each of x1, x2 and mu.
n_samples_per_dim = 40
box = box_x + parameters_domain
linspaces = [np.linspace(b[0], b[1], n_samples_per_dim) for b in box]
meshgrid = np.meshgrid(*linspaces)
# Flatten the grid into an (n_samples_per_dim ** len(box), len(box)) array of
# (x1, x2, mu) rows. np.linspace yields float64, which torch.tensor keeps;
# cast to torch's default dtype so the data presumably matches the dtype of
# the network parameters (a no-op when the default is already float64).
mesh = torch.tensor(
    np.stack(meshgrid, axis=-1).reshape((-1, len(box))),
    dtype=torch.get_default_dtype(),
)
# First two columns are the spatial coordinates, the last one is mu.
x_data = LabelTensor(mesh[:, :2])
mu_data = LabelTensor(mesh[:, 2:3])
y_data = exact_sol(x_data, mu_data)
dloss = DataLoss((mesh[:, :2], mesh[:, 2:3]), y_data, torch.nn.MSELoss())


def _make_space_and_pde():
    """Build a new MLP approximation space and the Poisson problem posed on it.

    A distinct space object is created per call, so the solvers below do not
    share networks.
    """
    new_space = NNxSpace(1, 1, GenericMLP, domain_x, sampler, layer_sizes=[40])
    return new_space, Laplacian2DDirichletStrongForm(new_space, f=f_rhs)


def _solve_or_load(solver, tag: str, epochs: int, force: bool) -> None:
    """Train ``solver`` and save it under ``tag``, unless a saved run is reused.

    When ``force`` is False, ``solver.load(__file__, tag)`` is tried first and
    training is skipped if it returns a truthy value (presumably meaning a
    saved state was found and loaded). When ``force`` is True, loading is not
    attempted at all.
    """
    if force or not solver.load(__file__, tag):
        solver.solve(epochs=epochs, n_collocation=3000, n_bc_collocation=1600)
        solver.save(__file__, tag)


# Set to False to reuse previously saved solvers instead of retraining them.
new_solve = True

# 1) Standard PINNs trained with the "ssbfgs" optimizer.
space, pde = _make_space_and_pde()
pinns = PinnsElliptic(
    pde,
    bc_type="weak",
    bc_weight=bc_weight,
    data_losses=[dloss],
    dl_weights=[1.0],
    optimizers="ssbfgs",
)
_solve_or_load(pinns, "ssbfgs", epochs=1000, force=new_solve)

# 2) Natural-gradient PINNs with the default algorithm (saved as "ENG").
space2, pde2 = _make_space_and_pde()
pinns2 = NaturalGradientPinnsElliptic(
    pde2,
    bc_type="weak",
    bc_weight=bc_weight,
    data_losses=[dloss],
    dl_weights=[1.0],
)
_solve_or_load(pinns2, "ENG", epochs=200, force=new_solve)

# 3) Natural-gradient PINNs with the ANaGRAM algorithm.
space3, pde3 = _make_space_and_pde()
pinns3 = NaturalGradientPinnsElliptic(
    pde3,
    bc_type="weak",
    bc_weight=bc_weight,
    data_losses=[dloss],
    dl_weights=[1.0],
    ng_algo="ANaGRAM",
    svd_threshold=5e-1,
)
_solve_or_load(pinns3, "ANaGRAM", epochs=200, force=new_solve)

# for plotting pinns
plot_abstract_approx_spaces(
    (pinns.space, pinns2.space, pinns3.space),  # the approximation spaces
    domain_x,  # the spatial domain
    parameters_domain,  # the parameter's domain
    loss=(
        pinns.losses,
        pinns2.losses,
        pinns3.losses,
    ),  # for plot of the loss: the losses
    residual=(pde, pde2, pde3),  # for plot of the residual: the pdes
    # solution=exact_sol,  # for plot of the exact sol: sol
    error=exact_sol,  # for plot of the error with respect to a func: the func
    draw_contours=True,
    n_drawn_contours=20,
    parameters_values="random",
)

plt.show()