r"""Solve a parametric 2D Poisson problem with Dirichlet BCs using PINNs and Deep Ritz.

.. math::

    -\mu \Delta u & = f \quad \text{in } \Omega \times M, \\
    u & = g \quad \text{on } \partial \Omega \times M,

where :math:`x = (x_1, x_2) \in \Omega = \mathcal{D}`, with :math:`\mathcal{D}` the
disk of radius 1 centered at :math:`(x_1^0, x_2^0) = (-0.5, 0.5)`, :math:`g = 0`,
:math:`\mu \in M = [0.5, 1]`, and :math:`f = \mu^2`, chosen so that the exact
solution is

.. math::

    u(x_1, x_2, \mu) = 0.25 \mu \left(1 - (x_1 - x_1^0)^2 - (x_2 - x_2^0)^2\right).

Boundary conditions are enforced strongly: the network output is multiplied by
:math:`\varphi(x) = (x_1 - x_1^0)^2 + (x_2 - x_2^0)^2 - 1`, which vanishes on
:math:`\partial \Omega`.

Note: the script currently samples the degenerate parameter interval
:math:`[1, 1]`, i.e. :math:`\mu = 1`. Set ``domain_mu = [[0.5, 1.0]]`` to train
over the full range :math:`M`.

Three solvers are compared:

- a PINN with natural-gradient (ENG) preconditioning;
- Deep Ritz without preconditioning;
- Deep Ritz with ENG preconditioning.

Each solver is restored with ``solver.load(__file__, tag)`` when possible.
Otherwise it is trained and stored with ``solver.save(__file__, tag)``.
Set ``new_solve = True`` to force retraining.
"""

import matplotlib.pyplot as plt
import torch

from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_2d import Disk2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.elliptic_pde.deep_ritz import (
    DeepRitzElliptic,
    NaturalGradientDeepRitzElliptic,
)
from scimba_torch.numerical_solvers.elliptic_pde.pinns import (
    NaturalGradientPinnsElliptic,
)
from scimba_torch.physical_models.elliptic_pde.laplacians import (
    Laplacian2DDirichletRitzForm,
    Laplacian2DDirichletStrongForm,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor

# Center (x_1^0, x_2^0) of the disk-shaped domain Omega; its radius is 1.
center = (-0.5, 0.5)


def exact_sol(x: LabelTensor, mu: LabelTensor):
    """Return the exact solution u(x, mu) = 0.25 * mu * (1 - |x - center|^2)."""
    x1, x2 = x.get_components()
    mu1 = mu.get_components()
    return 0.25 * mu1 * (1 - (x1 - center[0]) ** 2 - (x2 - center[1]) ** 2)


def f_rhs(x: LabelTensor, mu: LabelTensor):
    """Return the right-hand side f = mu^2.

    For the exact solution, each second derivative equals -0.5 * mu. Hence
    Delta u = -mu and -mu * Delta u = mu^2, which does not depend on ``x``.
    """
    mu1 = mu.get_components()
    return mu1**2


def f_bc(x: LabelTensor, mu: LabelTensor):
    """Return the homogeneous Dirichlet data g = 0, shaped like the first x component."""
    x1, _ = x.get_components()
    return torch.zeros_like(x1)


domain_x = Disk2D(center, 1, is_main_domain=True)

# Parameter domain for mu, presumably one [low, high] pair per parameter.
# The problem is posed for mu in [0.5, 1], but mu is currently pinned to 1.
# To cover the full range, use [[0.5, 1.0]] and set `new_solve = True`:
# `load` may otherwise return models that were trained with mu = 1.
domain_mu = [[1.0, 1.0]]

sampler = TensorizedSampler(
    [DomainSampler(domain_x), UniformParametricSampler(domain_mu)]
)


def post_processing(inputs: torch.Tensor, x: LabelTensor, mu: LabelTensor):
    """Enforce u = 0 on the boundary by multiplying the network output by phi.

    phi(x) = |x - center|^2 - 1 vanishes on the circle of radius 1 around
    ``center``, which is the boundary of ``domain_x``.
    """
    x1, x2 = x.get_components()
    phi = (x1 - center[0]) ** 2 + (x2 - center[1]) ** 2 - 1.0
    return inputs * phi


def functional_post_processing(
    func, x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor
) -> torch.Tensor:
    """Functional counterpart of ``post_processing`` for the natural-gradient solvers.

    Multiplies ``func(x, mu, theta)`` by the same factor phi as ``post_processing``.

    NOTE(review): ``x`` is indexed as ``x[0]``, ``x[1]``, so it is presumably a
    single point (e.g. per-sample functional evaluation); confirm against
    scimba_torch.
    """
    phi = (x[0] - center[0]) ** 2 + (x[1] - center[1]) ** 2 - 1.0
    return func(x, mu, theta) * phi


def _make_space(**overrides) -> NNxSpace:
    """Build an approximation space with the settings shared by every solver here.

    Each call creates a fresh network. Keyword arguments override or extend the
    defaults, e.g. ``layer_sizes=[20, 20, 20, 20, 20]`` or ``activation_type="sinh"``.
    """
    options = {"layer_sizes": [64], "post_processing": post_processing, **overrides}
    return NNxSpace(1, 1, GenericMLP, domain_x, sampler, **options)


def _load_or_solve(solver, tag: str, *, new_solve: bool = False, **solve_kwargs) -> None:
    """Restore ``solver`` from the state saved under ``tag``, or train and save it.

    Training runs when ``new_solve`` is true or when ``solver.load`` returns a
    falsy value (presumably: nothing saved yet under ``tag``). ``solve_kwargs``
    are forwarded unchanged to ``solver.solve``.
    """
    if new_solve or not solver.load(__file__, tag):
        solver.solve(**solve_kwargs)
        solver.save(__file__, tag)


# Set to True to retrain every solver instead of loading saved ones. To retrain
# a single solver, pass new_solve=True to its own _load_or_solve call instead.
new_solve = False

###### first space: PINN with ENG preconditioning ############
space = _make_space()
pde = Laplacian2DDirichletStrongForm(space, f=f_rhs, g=f_bc)
pinns = NaturalGradientPinnsElliptic(
    pde,
    bc_type="strong",
    functional_post_processing=functional_post_processing,
)
_load_or_solve(
    pinns,
    "pinn_strong_ENG",
    new_solve=new_solve,
    epochs=200,
    n_collocation=900,
    n_bc_collocation=200,
    verbose=True,
)

###### second space: Deep Ritz without preconditioning ############
space2 = _make_space()
pde2 = Laplacian2DDirichletRitzForm(space2, f=f_rhs, g=f_bc)
ritz = DeepRitzElliptic(
    pde2,
    bc_type="strong",
)
# The unpreconditioned solver is given 10x more epochs than the ENG ones.
_load_or_solve(
    ritz,
    "ritz_strong_no_precond",
    new_solve=new_solve,
    epochs=2000,
    n_collocation=40000,
    n_bc_collocation=1000,
    verbose=True,
)

###### third space: Deep Ritz with ENG preconditioning ############
space3 = _make_space()
pde3 = Laplacian2DDirichletRitzForm(space3, f=f_rhs, g=f_bc)
ritz2 = NaturalGradientDeepRitzElliptic(
    pde3,
    bc_type="strong",
    functional_post_processing=functional_post_processing,
)
_load_or_solve(
    ritz2,
    "ritz_strong_ENG",
    new_solve=new_solve,
    epochs=200,
    n_collocation=40000,
    n_bc_collocation=200,
    verbose=True,
)
# ---------------------------------------------------------------------------
# Alternative experiments with weakly enforced BCs, kept for reference.
# - Names are updated to the current imports (Laplacian2DDirichletRitzForm,
#   Laplacian2DDirichletStrongForm, plot_abstract_approx_spaces).
# - Variables are renamed so that, once uncommented, they do not overwrite the
#   space2/pde2/space3/pde3/ritz2 objects plotted below.
# NOTE(review): `OptimizerData` and `opt_2` are neither defined nor imported in
# this file; provide them before enabling that option.
# ---------------------------------------------------------------------------
# bc_weight = 15.0
#
# ###### weak-BC Deep Ritz without preconditioning ############
# space_weak_ritz = NNxSpace(
#     1,
#     1,
#     GenericMLP,
#     domain_x,
#     sampler,
#     layer_sizes=[64],
#     # layer_sizes=[20, 20, 20, 20, 20],
#     # post_processing=post_processing,
# )
#
# pde_weak_ritz = Laplacian2DDirichletRitzForm(space_weak_ritz, f=f_rhs, g=f_bc)
#
# ritz_weak = DeepRitzElliptic(
#     pde_weak_ritz,
#     bc_type="weak",
#     bc_weight=bc_weight,
#     # optimizers=OptimizerData(opt_2),
#     # functional_post_processing=functional_post_processing,
# )
#
# new_solve = True
# if new_solve or not ritz_weak.load(__file__, "ritz_weak_no_precond"):
#     ritz_weak.solve(
#         epochs=2000, n_collocation=3000, n_bc_collocation=1000, verbose=True
#     )
#     ritz_weak.save(__file__, "ritz_weak_no_precond")
#
# ###### weak-BC PINN with ENG preconditioning ############
# space_weak_pinn = NNxSpace(
#     1,
#     1,
#     GenericMLP,
#     domain_x,
#     sampler,
#     layer_sizes=[64],
#     # post_processing=post_processing,
# )
#
# pde_weak_pinn = Laplacian2DDirichletStrongForm(space_weak_pinn, f=f_rhs, g=f_bc)
#
# pinns2 = NaturalGradientPinnsElliptic(
#     pde_weak_pinn,
#     bc_type="weak",
#     # functional_post_processing=functional_post_processing,
# )
#
# new_solve = True
# if new_solve or not pinns2.load(__file__, "pinn_weak_ENG"):
#     pinns2.solve(epochs=200, n_collocation=900, n_bc_collocation=200, verbose=True)
#     pinns2.save(__file__, "pinn_weak_ENG")

# Compare the three strong-BC solvers side by side: their loss histories and
# their error against `exact_sol`. parameters_values="mean" presumably selects
# the midpoint of `domain_mu`.
plot_abstract_approx_spaces(
    (
        pinns.space,
        ritz.space,
        ritz2.space,
    ),  # an Iterable of AbstractSpace
    (
        domain_x,
    ),  # a VolumetricDomain, or an Iterable of them of length 1 or len(first argument)
    (domain_mu,),
    loss=(
        pinns.losses,
        ritz.losses,
        ritz2.losses,
    ),  # same convention as above; a single entry is used for all spaces
    # residual=(
    #     pinns.pde,
    #     ritz.pde,
    # ),  # same convention as above; a single entry is used for all spaces
    error=exact_sol,
    draw_contours=True,
    n_drawn_contours=20,
    parameters_values="mean",
    title=r"Solving $-\mu\Delta u = \mu^2$ on the unit disk",
    titles=(
        "PINN with ENG preconditioning",
        "RITZ with no preconditioning",
        "RITZ with ENG preconditioning",
    ),
)
plt.show()

# Plot the weak-BC PINN. This requires the `pinns2` experiment above.
# NOTE(review): the parameter domain [(1.0, 2.0)] and the values 1.0, 1.5 and
# 2.0 do not match `domain_mu`. The live call above also passes tuples rather
# than a single space and domain. Adjust both before enabling this plot.
# plot_abstract_approx_spaces(
#     pinns2.space,
#     domain_x,  # a VolumetricDomain
#     [(1.0, 2.0)],  # List[Sequence[float]]
#     parameters_values=([1.0], [1.5], [2.0]),
#     draw_contours=True,
#     n_drawn_contours=20,
# )
# plt.show()