r"""Solves a grad-div system in 2D with Dirichlet boundary conditions using a PINN.

.. math::

    \nabla (\nabla \cdot u) + u & = f \quad \text{in } \Omega \times M \\
    u & = g \quad \text{on } \partial \Omega \times M

where :math:`u: \Omega \times M \to \mathbb{R}^2` is the unknown function,
:math:`\Omega \subset \mathbb{R}^2` is the spatial domain and
:math:`M \subset \mathbb{R}` is the parametric domain.

The equation is solved on a square domain; weak boundary conditions are used.
PINNs are used with energy natural gradient preconditioning, and four sampling
strategies are compared (resample at each epoch, or sample from a set of
pre-sampled points in the domain, on the boundary, or both).

NOTE(review): the operator implemented below is ``+grad(div(u)) + u`` (no
leading minus sign); ``f_rhs`` and ``exact_solution`` are written to be
consistent with that convention.
"""

# %%
from typing import Callable, Tuple

import matplotlib.pyplot as plt
import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_2d import Square2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.elliptic_pde.pinns import (
    NaturalGradientPinnsElliptic,
)
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    StrongFormEllipticPDE,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable


class GradDiv2D(StrongFormEllipticPDE):
    r"""Strong form of the grad-div PDE ``grad(div(u)) + u = f`` in 2D.

    Args:
        space: approximation space for the two-component unknown ``u``.
        f: volume right-hand side, called as ``f(x, mu)``.
        g: Dirichlet boundary data, called as ``g(x, mu)``.
        **kwargs: forwarded to :class:`StrongFormEllipticPDE`.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        f: Callable,
        g: Callable,
        **kwargs,
    ):
        super().__init__(
            space, linear=True, residual_size=2, bc_residual_size=2, **kwargs
        )
        self.space = space
        self.f = f
        self.g = g

    def grad(
        self,
        w: torch.Tensor | MultiLabelTensor,
        y: torch.Tensor | LabelTensor,
    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
        """Differentiate ``w`` with respect to the coordinates in ``y``."""
        return self.space.grad(w, y)

    def rhs(
        self, w: MultiLabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> Tuple[torch.Tensor, ...]:
        """Volume right-hand side ``f(x, mu)``, one tensor per equation."""
        return self.f(x, mu)

    def bc_rhs(
        self, w: MultiLabelTensor, x: LabelTensor, n: LabelTensor, mu: LabelTensor
    ) -> Tuple[torch.Tensor, ...]:
        """Boundary right-hand side ``g(x, mu)``, one tensor per component."""
        return self.g(x, mu)

    def operator(
        self, w: MultiLabelTensor, xs: LabelTensor, mu: LabelTensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the two components of ``grad(div(u)) + u``.

        With ``u = (u, v)``: ``div(u) = u_x + v_y`` and
        ``grad(div(u)) = (u_xx + v_yx, u_xy + v_yy)``.
        """
        u, v = w.get_components()
        u_x, _u_y = self.grad(u, xs)
        u_xx, u_xy = self.grad(u_x, xs)
        _v_x, v_y = self.grad(v, xs)
        v_yx, v_yy = self.grad(v_y, xs)
        return u_xx + v_yx + u, u_xy + v_yy + v

    def restrict_to_component(self, i: int, func: Callable) -> Callable:
        """Wrap ``func`` so that only output component ``i`` is kept."""
        return lambda *args: func(*args)[i : i + 1, ...]

    def functional_operator(
        self,
        func: VarArgCallable,
        x: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Pointwise (``torch.func``) version of :meth:`operator`."""
        uv_x = func(x, mu, theta)
        # Jacobians of each solution component with respect to x.
        grad_u = self.restrict_to_component(0, torch.func.jacrev(func, 0))
        grad_v = self.restrict_to_component(1, torch.func.jacrev(func, 0))
        hessian_u = torch.func.jacrev(grad_u, 0)(x, mu, theta).squeeze()
        hessian_v = torch.func.jacrev(grad_v, 0)(x, mu, theta).squeeze()
        # grad(div(u)) + u, relying on symmetry of second derivatives.
        res = hessian_u[..., 0] + hessian_v[..., 1] + uv_x
        return res

    # Dirichlet conditions
    def bc_operator(
        self, w: MultiLabelTensor, x: LabelTensor, n: LabelTensor, mu: LabelTensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Trace of the solution on the boundary (Dirichlet condition)."""
        u, v = w.get_components()
        return u, v

    def functional_operator_bc(
        self,
        func: VarArgCallable,
        x: torch.Tensor,
        n: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Pointwise version of :meth:`bc_operator`."""
        return func(x, mu, theta)


def exact_solution(xs: LabelTensor, mu: LabelTensor) -> torch.Tensor:
    """Manufactured solution ``u = (sin sin, alpha * sin sin)``."""
    x, y = xs.get_components()
    alpha = mu.get_components()
    return torch.cat(
        (
            torch.sin(2.0 * torch.pi * x) * torch.sin(2.0 * torch.pi * y),
            alpha * torch.sin(2.0 * torch.pi * x) * torch.sin(2.0 * torch.pi * y),
        ),
        dim=-1,
    )


def f_rhs(xs: LabelTensor, mu: LabelTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-hand side such that ``exact_solution`` solves ``grad(div u) + u = f``."""
    x, y = xs.get_components()
    alpha = mu.get_components()
    PI = torch.pi
    cos_x = torch.cos(2.0 * PI * x)
    cos_y = torch.cos(2.0 * PI * y)
    sin_x = torch.sin(2.0 * PI * x)
    sin_y = torch.sin(2.0 * PI * y)
    # u = sin*sin:       u_xx + v_yx + u   = (1 - 4pi^2) sin sin + 4pi^2 alpha cos cos
    f1 = (1 - 4 * PI**2) * sin_x * sin_y + 4 * PI**2 * alpha * cos_x * cos_y
    # v = alpha*sin*sin: u_xy + v_yy + v   = alpha (1 - 4pi^2) sin sin + 4pi^2 cos cos
    # BUG FIX: this previously read ``(1 - 4 * PI**2 * alpha) * sin_x * sin_y``,
    # which only matches ``exact_solution`` when alpha == 1 (here alpha ~ 0.75).
    f2 = alpha * (1 - 4 * PI**2) * sin_x * sin_y + 4 * PI**2 * cos_x * cos_y
    return f1, f2


def f_bc(xs: LabelTensor, mu: LabelTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Homogeneous Dirichlet data: both components vanish on the boundary."""
    x, _ = xs.get_components()
    return torch.zeros_like(x), torch.zeros_like(x)


bc_weight = 10.0
# Nearly-degenerate parametric interval around alpha = 0.75.
domain_mu = [(0.75, 0.75 + 1e-4)]
domain_x = Square2D([(0.0, 1), (0.0, 1)], is_main_domain=True)


def plot_pinn(pinn, title):
    """Plot the approximate solution, loss history and residual of ``pinn``."""
    plot_abstract_approx_spaces(
        [pinn.space],
        (domain_x,),
        (domain_mu,),
        loss=[pinn.losses],
        residual=[pinn.pde],
        draw_contours=True,
        n_drawn_contours=20,
        title=title,
    )
    plt.show()


N_COLLOCATION = 2_500
N_BC_COLLOCATION = 1_000
N_PRE_SAMPLE = 25_000
N_PRE_SAMPLE_BC = 10_000
N_EPOCHS = 100


def create_and_plot(sampler_x, title):
    """Build the space/PDE/PINN for ``sampler_x``, solve, and plot the results."""
    sampler = TensorizedSampler([sampler_x, UniformParametricSampler(domain_mu)])
    space = NNxSpace(2, 1, GenericMLP, domain_x, sampler, layer_sizes=[64])
    pde = GradDiv2D(space, f_rhs, f_bc)
    pinn = NaturalGradientPinnsElliptic(
        pde,
        bc_type="weak",
        bc_weight=bc_weight,
        one_loss_by_equation=True,
        matrix_regularization=1e-6,
    )
    pinn.solve(
        epochs=N_EPOCHS,
        n_collocation=N_COLLOCATION,
        n_bc_collocation=N_BC_COLLOCATION,
    )
    plot_pinn(pinn, title)


# %% 1/ RE-SAMPLE POINTS AT EACH EPOCH

sampler_x = DomainSampler(domain_x)
create_and_plot(sampler_x, "re-sampling at each epoch")

# %% 2/ DEFINE SOME PRE-SAMPLED POINTS IN DOMAIN

sampler_x = DomainSampler(
    domain_x,
    pre_sampling=True,
    n_pre_sampled_points=N_PRE_SAMPLE,
)
create_and_plot(sampler_x, "pre-sampling in domain")

# %% 3/ DEFINE SOME PRE-SAMPLED POINTS ON BOUNDARY

sampler_x = DomainSampler(
    domain_x,
    pre_sampling_bc=True,
    n_pre_sampled_points_bc=N_PRE_SAMPLE_BC,
)
create_and_plot(sampler_x, "pre-sampling on boundary")

# %% 4/ DEFINE SOME PRE-SAMPLED POINTS IN DOMAIN AND ON BOUNDARY

sampler_x = DomainSampler(
    domain_x,
    pre_sampling=True,
    n_pre_sampled_points=N_PRE_SAMPLE,
    pre_sampling_bc=True,
    n_pre_sampled_points_bc=N_PRE_SAMPLE_BC,
)
create_and_plot(sampler_x, "pre-sampling in domain and on boundary")

# %%