r"""Solves a time-independent Schrödinger equation in 2D.

.. math::

    -\Delta u & = f \quad \text{in } \Omega, \\
    u & = g \quad \text{on } \partial \Omega,

where :math:`u: \Omega \to \mathbb{C}` is the unknown function,
:math:`\Omega \subset \mathbb{R}^2` is the spatial domain, and :math:`f` and
:math:`g` are computed from the exact solution

.. math::

    u(\mathbf{x}) = A \exp(-B \|\mathbf{x}\|^2)
    \exp(\mathrm{i} \langle \mathbf{K}, \mathbf{x} \rangle).

The complex unknown is represented by its real and imaginary parts, so the
problem is solved as a system of two real equations. The equation is solved on
the unit disk; strong boundary conditions are used.
"""

from typing import Callable, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_2d import Disk2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.elliptic_pde.pinns import (
    NaturalGradientPinnsElliptic,
)
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    StrongFormEllipticPDE,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable

# Parameters of the exact solution u(x) = A exp(-B |x|^2) exp(i <K, x>).
A_CST = 1
B_CST = 1
K_CST = [1, 1]
# Radius of the disk domain, centred at the origin. The strong-BC lifting in
# ``post_processing`` / ``functional_post_processing`` depends on it, so the
# domain and the lifting both read it from here to stay consistent.
R_CST = 1


class Schrödinger2D(StrongFormEllipticPDE):
    r"""Strong-form problem :math:`-\Delta u = f`, :math:`u = g` on the boundary.

    The complex unknown ``u`` is handled as two real components (real part,
    imaginary part), hence ``residual_size=2`` and ``bc_residual_size=2``.

    Args:
        space: Approximation space used for ``u``.
        f: Source term, called as ``f(x, mu)``; returns the (real, imag) pair.
        g: Dirichlet data, called as ``g(x, mu)``; returns the (real, imag) pair.
        **kwargs: Forwarded to ``StrongFormEllipticPDE``.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        f: Callable,
        g: Callable,
        **kwargs,
    ):
        super().__init__(
            space, linear=True, residual_size=2, bc_residual_size=2, **kwargs
        )
        self.space = space
        self.f = f
        self.g = g

    def grad(
        self,
        w: torch.Tensor | MultiLabelTensor,
        y: torch.Tensor | LabelTensor,
    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
        """Differentiate ``w`` with respect to ``y`` via the approximation space."""
        return self.space.grad(w, y)

    def rhs(
        self, w: MultiLabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> Tuple[torch.Tensor, ...]:
        """Return the source term ``f(x, mu)``; ``w`` is not used."""
        return self.f(x, mu)

    def bc_rhs(
        self, w: MultiLabelTensor, x: LabelTensor, n: LabelTensor, mu: LabelTensor
    ) -> Tuple[torch.Tensor, ...]:
        """Return the Dirichlet data ``g(x, mu)``; ``w`` and ``n`` are not used."""
        return self.g(x, mu)

    def _minus_laplacian(
        self, u: torch.Tensor, xs: LabelTensor
    ) -> torch.Tensor:
        """Return ``-(u_xx + u_yy)`` computed by repeated ``self.grad`` calls."""
        u_x, u_y = self.grad(u, xs)
        u_xx, _ = self.grad(u_x, xs)
        _, u_yy = self.grad(u_y, xs)
        return -u_xx - u_yy

    def operator(
        self, w: MultiLabelTensor, xs: LabelTensor, mu: LabelTensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ``-Laplacian`` to the real and imaginary parts of ``w``.

        Returns:
            ``(-Δ Re u, -Δ Im u)``.
        """
        u_real, u_imag = w.get_components()
        return self._minus_laplacian(u_real, xs), self._minus_laplacian(u_imag, xs)

    def functional_operator(
        self,
        func: VarArgCallable,
        x: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Functional (``torch.func``) version of ``operator``.

        Assumes ``func(x, mu, theta)`` returns a length-2 vector (real, imag)
        for a single point ``x`` — TODO confirm against the solver's calling
        convention. The Hessian is then indexed as ``hess[component, i, j]``.
        """
        hess_func = torch.func.jacfwd(torch.func.jacrev(func, argnums=0), argnums=0)
        hess = hess_func(x, mu, theta)
        # Trace of each component's Hessian = its Laplacian.
        laplacian_real = hess[0, 0, 0] + hess[0, 1, 1]
        laplacian_imag = hess[1, 0, 0] + hess[1, 1, 1]
        return torch.stack([-laplacian_real, -laplacian_imag])

    def bc_operator(
        self, w: MultiLabelTensor, x: LabelTensor, n: LabelTensor, mu: LabelTensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Dirichlet operator: return the (real, imag) components of ``w`` as is."""
        u, v = w.get_components()
        return u, v

    def functional_operator_bc(
        self,
        func: VarArgCallable,
        x: torch.Tensor,
        n: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Functional Dirichlet operator: evaluate ``func`` at ``x``."""
        return func(x, mu, theta)


def exact_solution(xs: LabelTensor, mu: LabelTensor) -> torch.Tensor:
    """Evaluate the exact solution, real and imaginary parts concatenated.

    Returns ``A exp(-B |x|^2) [cos(<K, x>), sin(<K, x>)]`` concatenated along the
    last dimension.
    """
    x, y = xs.get_components()
    arg = K_CST[0] * x + K_CST[1] * y
    real_part = torch.cos(arg)
    imag_part = torch.sin(arg)
    return (
        A_CST
        * torch.exp(-B_CST * (x**2 + y**2))
        * torch.cat([real_part, imag_part], dim=-1)
    )


def f_rhs(xs: LabelTensor, mu: LabelTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Source term ``f = -Δu`` for the exact solution, as (real, imag).

    With ``phi = A exp(-B |x|^2)`` and ``theta = <K, x>``::

        Δu = phi e^{i theta} (4 B^2 |x|^2 - 4 B - |K|^2 - 4 i B <K, x>)

    The real and imaginary parts of ``-Δu`` are expanded term by term below.
    """
    x, y = xs.get_components()
    kx, ky = K_CST
    exp = torch.exp(-B_CST * x**2 - B_CST * y**2)
    sin = torch.sin(ky * y + kx * x)
    cos = torch.cos(ky * y + kx * x)

    f_real = -(
        A_CST
        * exp
        * (
            4 * B_CST * ky * y * sin
            + 4 * B_CST * kx * x * sin
            + 4 * B_CST**2 * y**2 * cos
            + 4 * B_CST**2 * x**2 * cos
            - ky**2 * cos
            - kx**2 * cos
            - 4 * B_CST * cos
        )
    )
    f_imag = -(
        A_CST
        * exp
        * (
            4 * B_CST**2 * y**2 * sin
            + 4 * B_CST**2 * x**2 * sin
            - ky**2 * sin
            - kx**2 * sin
            - 4 * B_CST * sin
            - 4 * B_CST * ky * y * cos
            - 4 * B_CST * kx * x * cos
        )
    )
    return f_real, f_imag


def f_bc(x: LabelTensor, mu: LabelTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dirichlet data ``g``: the exact solution split into (real, imag) columns."""
    ex = exact_solution(x, mu)
    return ex[..., 0:1], ex[..., 1:2]


def post_processing(inputs: torch.Tensor, xs: LabelTensor, mu: LabelTensor):
    """Lift raw network outputs so the Dirichlet condition holds exactly.

    On the circle ``|x| = R_CST`` the exact solution equals
    ``A exp(-B R^2) e^{i <K, x>}``; ``phi = |x|^2 - R^2`` vanishes there, so
    ``res + inputs * phi`` matches the boundary data for any ``inputs``.
    Assumes ``inputs`` holds the (real, imag) network outputs in its last
    dimension — TODO confirm shape against NNxSpace.
    """
    x, y = xs.get_components()
    arg = K_CST[0] * x + K_CST[1] * y
    real_part = torch.cos(arg)
    imag_part = torch.sin(arg)
    res = (
        A_CST
        * np.exp(-B_CST * R_CST**2)
        * torch.cat([real_part, imag_part], dim=-1)
    )
    phi = x**2 + y**2 - R_CST**2
    return res + inputs * phi


def functional_post_processing(
    func, xs: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor
) -> torch.Tensor:
    """Functional (``torch.func``) version of ``post_processing``.

    ``xs`` is indexed as ``xs[0], xs[1]``, i.e. presumably a single 2D point
    (e.g. under ``vmap``) — verify against the solver.
    """
    x, y = xs[0], xs[1]
    arg = K_CST[0] * x + K_CST[1] * y
    real_part = torch.cos(arg)
    imag_part = torch.sin(arg)
    res = A_CST * np.exp(-B_CST * R_CST**2) * torch.stack([real_part, imag_part])
    phi = x**2 + y**2 - R_CST**2
    return res + func(xs, mu, theta) * phi


# No physical parameters: the parametric domain is empty.
domain_mu = []
domain_x = Disk2D([0, 0], R_CST, is_main_domain=True)

sampler = TensorizedSampler(
    [DomainSampler(domain_x), UniformParametricSampler(domain_mu)]
)
space = NNxSpace(
    2,
    0,
    GenericMLP,
    domain_x,
    sampler,
    layer_sizes=[16, 16],
    post_processing=post_processing,
)

pde = Schrödinger2D(space, f_rhs, f_bc)

pinn = NaturalGradientPinnsElliptic(
    pde,
    bc_type="strong",
    one_loss_by_equation=True,
    matrix_regularization=1e-4,
    functional_post_processing=functional_post_processing,
)

pinn.solve(epochs=25, n_collocation=3000)

plot_abstract_approx_spaces(
    (pinn.space,),
    domain_x,
    domain_mu,
    loss=(pinn.losses,),
    residual=(pde,),
    error=exact_solution,
    draw_contours=True,
    n_drawn_contours=20,
    title="solving Schrödinger 2D",
    titles=("ENG preconditioning",),
)

plt.show()