Deep Ritz: div A grad u 2D

Solves a 2D elliptic PDE with Dirichlet BCs using the Deep Ritz method and PINNs.

\[\begin{split}\left\{\begin{array}{rl}-\mu \nabla \cdot (A \nabla u) & = f \text{ in } \Omega \times M \\ u & = g \text{ on } \partial \Omega \times M\end{array}\right.\end{split}\]

where \(x = (x_1, x_2) \in \Omega = (0, 1) \times (0, 1)\), \(A = [[1, 1], [1, 1]]\),

\(f\) such that \(u(x_1, x_2, \mu) = \mu \sin(2\pi x_1) \sin(2\pi x_2)\), \(g = 0\) and \(\mu \in M = [0.5, 1]\).

Boundary conditions are enforced strongly.

[1]:
import matplotlib.pyplot as plt
import torch

from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_2d import Square2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.elliptic_pde.deep_ritz import (
    DeepRitzElliptic,
    NaturalGradientDeepRitzElliptic,
)
from scimba_torch.numerical_solvers.elliptic_pde.pinns import (
    NaturalGradientPinnsElliptic,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
    DivAGradUPDE,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor

def exact_sol(x: LabelTensor, mu: LabelTensor):
    """Evaluate the reference solution ``mu * sin(2*pi*x1) * sin(2*pi*x2)``.

    Args:
        x: Spatial points; unpacked into two components ``x1`` and ``x2``.
        mu: Parameter values, read through ``get_components()``.

    Returns:
        The product of the parameter and the two sine factors.
    """
    x1, x2 = x.get_components()
    mu1 = mu.get_components()
    two_pi = 2.0 * torch.pi
    sin_x1 = torch.sin(two_pi * x1)
    sin_x2 = torch.sin(two_pi * x2)
    return mu1 * sin_x1 * sin_x2


def f_rhs(x: LabelTensor, mu: LabelTensor):
    """Evaluate the source term of the PDE.

    Computes ``8 * mu^2 * pi^2 * (sin(2*pi*x1) * sin(2*pi*x2)
    - cos(2*pi*x1) * cos(2*pi*x2))``.

    Args:
        x: Spatial points; unpacked into two components ``x1`` and ``x2``.
        mu: Parameter values, read through ``get_components()``.

    Returns:
        The right-hand side evaluated at the given points and parameters.
    """
    x1, x2 = x.get_components()
    mu1 = mu.get_components()

    angle_1 = 2.0 * torch.pi * x1
    angle_2 = 2.0 * torch.pi * x2
    sin_product = torch.sin(angle_1) * torch.sin(angle_2)
    cos_product = torch.cos(angle_1) * torch.cos(angle_2)

    amplitude = 8.0 * mu1 * mu1 * (torch.pi**2)
    return amplitude * (sin_product - cos_product)


def f_bc(x: LabelTensor, mu: LabelTensor):
    """Return the homogeneous Dirichlet boundary data (identically zero).

    Args:
        x: Spatial points; the first of its two components fixes the shape,
            dtype and device of the output.
        mu: Parameter values (unused, the boundary data does not depend on it).

    Returns:
        A zero tensor shaped like the first component of ``x``.
    """
    first_component, _ = x.get_components()
    zeros = torch.zeros_like(first_component)
    return zeros


def A(x: torch.Tensor) -> torch.Tensor:  # noqa: N802
    return torch.ones(
        x.shape[0],
        2,
        2,
        dtype=torch.get_default_dtype(),
        device=torch.get_default_device(),
    )


# Spatial domain: the square with bounds (0, 1) in each of the two directions.
domain_x = Square2D([(0.0, 1), (0.0, 1)], is_main_domain=True)
# Parameter domain: a single parameter mu with bounds [0.5, 1.0].
domain_mu = [[0.5, 1.0]]
# Joint sampler built from a spatial sampler on domain_x and a uniform
# sampler on domain_mu.
sampler = TensorizedSampler(
    [DomainSampler(domain_x), UniformParametricSampler(domain_mu)]
)


def post_processing(inputs: torch.Tensor, x: LabelTensor, mu: LabelTensor):
    """Multiply the raw network output by ``x1 * (1 - x1) * x2 * (1 - x2)``.

    The multiplier vanishes whenever ``x1`` or ``x2`` equals 0 or 1, so the
    returned value is zero there regardless of ``inputs``.

    Args:
        inputs: Raw values to be scaled.
        x: Spatial points; unpacked into two components ``x1`` and ``x2``.
        mu: Parameter values (unused).

    Returns:
        ``inputs`` scaled by the vanishing multiplier.
    """
    x1, x2 = x.get_components()
    scaled = inputs
    # Apply the factors one at a time, left to right.
    for factor in (x1, 1.0 - x1, x2, 1.0 - x2):
        scaled = scaled * factor
    return scaled


def functional_post_processing(
    func, x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor
) -> torch.Tensor:
    """Scale ``func(x, mu, theta)`` by ``x[0] * (1 - x[0]) * x[1] * (1 - x[1])``.

    Args:
        func: Callable invoked as ``func(x, mu, theta)``.
        x: Point whose entries ``x[0]`` and ``x[1]`` are the two coordinates.
        mu: Parameter value forwarded to ``func``.
        theta: Third argument forwarded to ``func`` unchanged.

    Returns:
        The value of ``func`` multiplied by the vanishing multiplier.
    """
    x1 = x[0]
    x2 = x[1]
    vanishing_factor = x1 * (1.0 - x1) * x2 * (1.0 - x2)
    raw_value = func(x, mu, theta)
    return raw_value * vanishing_factor
[2]:
# --- Solver 1: PINN trained with the natural-gradient solver ---------------
# Approximation space backed by a GenericMLP with layer_sizes=[64].
# NOTE(review): the two leading 1s are presumably the number of unknowns and
# the number of parameters -- confirm against the NNxSpace signature.
space = NNxSpace(
    1,
    1,
    GenericMLP,
    domain_x,
    sampler,
    layer_sizes=[64],
    post_processing=post_processing,
)

# PDE definition with source term f_rhs, boundary data f_bc and matrix A.
# NOTE(review): the positional 2 is presumably the spatial dimension -- confirm.
pde = DivAGradUPDE(space, 2, f=f_rhs, g=f_bc, A=A)

# bc_type="strong": the boundary condition is carried by the post-processing
# functions rather than by an extra loss term (see the notebook introduction).
pinns = NaturalGradientPinnsElliptic(
    pde,
    bc_type="strong",
    functional_post_processing=functional_post_processing,
)

pinns.solve(epochs=200, n_collocation=900, verbose=False)

# --- Solver 2: Deep Ritz with the plain DeepRitzElliptic solver ------------
# A fresh approximation space with the same arguments as the other solvers,
# so each solver trains its own network.
space2 = NNxSpace(
    1,
    1,
    GenericMLP,
    domain_x,
    sampler,
    layer_sizes=[64],
    post_processing=post_processing,
)

# Same PDE data (f_rhs, f_bc, A), attached to the new space.
pde2 = DivAGradUPDE(space2, 2, f=f_rhs, g=f_bc, A=A)

# No functional_post_processing is passed to this solver.
ritz = DeepRitzElliptic(
    pde2,
    bc_type="strong",
)

# This run uses 3000 epochs and 40000 collocation points.
ritz.solve(epochs=3000, n_collocation=40000, verbose=False)

# --- Solver 3: Deep Ritz trained with the natural-gradient solver ----------
# A third, independent approximation space built with the same arguments.
space3 = NNxSpace(
    1,
    1,
    GenericMLP,
    domain_x,
    sampler,
    layer_sizes=[64],
    post_processing=post_processing,
)

# Same PDE data (f_rhs, f_bc, A), attached to the new space.
pde3 = DivAGradUPDE(space3, 2, f=f_rhs, g=f_bc, A=A)

# The natural-gradient solver is additionally given the functional form of
# the boundary-enforcing post-processing.
ritz2 = NaturalGradientDeepRitzElliptic(
    pde3,
    bc_type="strong",
    functional_post_processing=functional_post_processing,
)

# This run uses 200 epochs and 40000 collocation points.
ritz2.solve(epochs=200, n_collocation=40000, verbose=False)
[3]:
# Compare the three trained approximations side by side. The spaces, loss
# histories and subplot titles are given in the same order: PINN, Deep Ritz,
# natural-gradient Deep Ritz.
plot_abstract_approx_spaces(
    (
        pinns.space,
        ritz.space,
        ritz2.space,
    ),
    (domain_x,),
    (domain_mu,),
    loss=(
        pinns.losses,
        ritz.losses,
        ritz2.losses,
    ),
    # NOTE(review): presumably used as the reference solution for the error
    # plots -- confirm against plot_abstract_approx_spaces.
    error=exact_sol,
    draw_contours=True,
    n_drawn_contours=20,
    parameters_values="mean",
    # Raw string: the LaTeX backslashes (\mu, \pi, \left, \sin, \cos, ...) must
    # reach matplotlib unchanged. In a non-raw literal they are invalid escape
    # sequences, which raise a SyntaxWarning on Python 3.12+. The resulting
    # text is identical to the previous mixed-escape literal.
    title=r"Solving $-\mu\nabla.(A\nabla u)=8\mu^2\pi^2\left(\sin(2\pi x)\sin(2\pi y)-\cos(2\pi x)\cos(2\pi y)\right)$",
    titles=(
        "PINN with ENG preconditioning",
        "RITZ with no preconditioning",
        "RITZ with ENG preconditioning",
    ),
)
plt.show()
../../_images/example_notebooks_deep_ritz_div_a_grad_u_2d_3_0.png
[ ]: