Deep Ritz: div A grad u 2D
Solves a parametric 2D elliptic PDE with Dirichlet boundary conditions using PINNs and the Deep Ritz method, with and without energy natural gradient (ENG) preconditioning.
\[\begin{split}\left\{\begin{array}{rl}-\mu \nabla \cdot (A \nabla u) & = f \text{ in } \Omega \times M \\
u & = g \text{ on } \partial \Omega \times M\end{array}\right.\end{split}\]
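where \(x = (x_1, x_2) \in \Omega = (0, 1) \times (0, 1)\), \(A = \begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}\), \(g = 0\), \(\mu \in M = [0.5, 1]\), and
\(f = 8\mu^2\pi^2\left(\sin(2\pi x_1)\sin(2\pi x_2) - \cos(2\pi x_1)\cos(2\pi x_2)\right)\) is chosen so that the exact solution is \(u(x_1, x_2, \mu) = \mu \sin(2\pi x_1) \sin(2\pi x_2)\).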
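Boundary conditions are enforced strongly: the network output is multiplied by \(x_1(1 - x_1)x_2(1 - x_2)\), which vanishes on \(\partial \Omega\).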
[1]:
import matplotlib.pyplot as plt
import torch
from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_2d import Square2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.elliptic_pde.deep_ritz import (
DeepRitzElliptic,
NaturalGradientDeepRitzElliptic,
)
from scimba_torch.numerical_solvers.elliptic_pde.pinns import (
NaturalGradientPinnsElliptic,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
DivAGradUPDE,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor
def exact_sol(x: LabelTensor, mu: LabelTensor):
    """Return the exact solution u(x1, x2, mu) = mu * sin(2 pi x1) * sin(2 pi x2).

    This is the manufactured solution that ``f_rhs`` is built from; it is
    used as the reference when plotting errors.
    """
    two_pi = 2.0 * torch.pi
    first, second = x.get_components()
    amplitude = mu.get_components()
    return amplitude * torch.sin(two_pi * first) * torch.sin(two_pi * second)
def f_rhs(x: LabelTensor, mu: LabelTensor):
    """Return the source term f = -mu * div(A grad u) for the exact solution.

    With u = mu sin(2 pi x1) sin(2 pi x2) and A the all-ones 2x2 matrix,
    div(A grad u) = u_11 + 2 u_12 + u_22, which yields

        f = 8 mu^2 pi^2 (sin(2 pi x1) sin(2 pi x2) - cos(2 pi x1) cos(2 pi x2)).
    """
    x1, x2 = x.get_components()
    mu1 = mu.get_components()
    arg1 = 2.0 * torch.pi * x1
    arg2 = 2.0 * torch.pi * x2
    bracket = torch.sin(arg1) * torch.sin(arg2) - torch.cos(arg1) * torch.cos(arg2)
    return 8.0 * mu1 * mu1 * (torch.pi**2) * bracket
def f_bc(x: LabelTensor, mu: LabelTensor):
    """Return the homogeneous Dirichlet data g(x, mu) = 0.

    The zeros are shaped like the first spatial component of ``x``;
    ``mu`` is not used.
    """
    first, _second = x.get_components()
    return torch.zeros_like(first)
def A(x: torch.Tensor) -> torch.Tensor: # noqa: N802
return torch.ones(
x.shape[0],
2,
2,
dtype=torch.get_default_dtype(),
device=torch.get_default_device(),
)
# Spatial domain Omega = (0, 1) x (0, 1), flagged as the main domain.
domain_x = Square2D([(0.0, 1), (0.0, 1)], is_main_domain=True)

# Parameter domain M = [0.5, 1.0] for mu, given as a single [low, high] pair.
domain_mu = [[0.5, 1.0]]

# Joint sampler over Omega x M: a spatial domain sampler combined with a
# parametric sampler over domain_mu (uniform, going by the class name).
sampler = TensorizedSampler(
    [
        DomainSampler(domain_x),
        UniformParametricSampler(domain_mu),
    ]
)
def post_processing(inputs: torch.Tensor, x: LabelTensor, mu: LabelTensor):
    """Impose u = 0 on the boundary by multiplying by x1(1-x1)x2(1-x2).

    The factor vanishes on every edge of the unit square, so the returned
    approximation satisfies the homogeneous Dirichlet condition exactly.
    ``mu`` is not used.
    """
    x1, x2 = x.get_components()
    one_minus_x1 = 1.0 - x1
    one_minus_x2 = 1.0 - x2
    return inputs * x1 * one_minus_x1 * x2 * one_minus_x2
def functional_post_processing(
    func, x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor
) -> torch.Tensor:
    """Functional counterpart of ``post_processing`` used by the ENG solvers.

    Evaluates ``func(x, mu, theta)`` and multiplies it by the boundary
    factor x[0](1-x[0])x[1](1-x[1]), which vanishes on the boundary of the
    unit square.

    NOTE(review): ``x`` is indexed as ``x[0]``, ``x[1]``, so it presumably
    holds a single point (shape (2,)) under functional evaluation; confirm.
    """
    a, b = x[0], x[1]
    cutoff = a * (1.0 - a) * b * (1.0 - b)
    return func(x, mu, theta) * cutoff
[2]:
def _make_space_and_pde():
    """Build a fresh NN approximation space and the DivAGradUPDE posed on it.

    Every solver below gets its own, freshly initialised space, so the three
    methods are trained independently. The network (``GenericMLP`` with
    ``layer_sizes=[64]``), the sampler and the strong-BC ``post_processing``
    are the same for all of them.

    Returns:
        A ``(space, pde)`` tuple.
    """
    new_space = NNxSpace(
        1,
        1,
        GenericMLP,
        domain_x,
        sampler,
        layer_sizes=[64],
        post_processing=post_processing,
    )
    return new_space, DivAGradUPDE(new_space, 2, f=f_rhs, g=f_bc, A=A)


# Each space is created right before its solver trains, which keeps the
# original order of network initialisation and training.

# PINN with ENG preconditioning.
space, pde = _make_space_and_pde()
pinns = NaturalGradientPinnsElliptic(
    pde,
    bc_type="strong",
    functional_post_processing=functional_post_processing,
)
pinns.solve(epochs=200, n_collocation=900, verbose=False)

# Deep Ritz with no preconditioning.
space2, pde2 = _make_space_and_pde()
ritz = DeepRitzElliptic(
    pde2,
    bc_type="strong",
)
ritz.solve(epochs=3000, n_collocation=40000, verbose=False)

# Deep Ritz with ENG preconditioning.
space3, pde3 = _make_space_and_pde()
ritz2 = NaturalGradientDeepRitzElliptic(
    pde3,
    bc_type="strong",
    functional_post_processing=functional_post_processing,
)
ritz2.solve(epochs=200, n_collocation=40000, verbose=False)
[3]:
# Compare the three trained approximations: their loss histories, the error
# with respect to exact_sol, and contour plots. parameters_values="mean"
# presumably evaluates at the mean of domain_mu; confirm against scimba docs.
plot_abstract_approx_spaces(
    (pinns.space, ritz.space, ritz2.space),
    (domain_x,),
    (domain_mu,),
    loss=(pinns.losses, ritz.losses, ritz2.losses),
    error=exact_sol,
    draw_contours=True,
    n_drawn_contours=20,
    parameters_values="mean",
    # Raw strings keep the LaTeX backslashes literal. The previous non-raw
    # literal relied on invalid escapes (\m, \p, \l, \s, \c), which raise
    # SyntaxWarning on Python >= 3.12. The resulting text is unchanged.
    title=(
        r"Solving $-\mu\nabla.(A\nabla u)="
        r"8\mu^2\pi^2\left(\sin(2\pi x)\sin(2\pi y)-\cos(2\pi x)\cos(2\pi y)\right)$"
    ),
    titles=(
        "PINN with ENG preconditioning",
        "RITZ with no preconditioning",
        "RITZ with ENG preconditioning",
    ),
)
plt.show()
[ ]: