r"""Solves an N-dimensional heat equation in a ball, with and without pre-processing.

The problem is :math:`\partial_t u - \Delta u = 0` on the ball of radius
``RADIUS`` centered at the origin. Its exact Gaussian solution

.. math::

    u(t, x) = \left(\frac{\sigma^2}{\sigma^2 + 2t}\right)^{d/2}
              \exp\left(-\frac{|x|^2}{2(\sigma^2 + 2t)}\right)

provides the initial condition, the Dirichlet boundary condition and the
reference for the error computation.

The neural network is a simple MLP (Multilayer Perceptron). Both the initial
and the boundary conditions are handled strongly through post-processing. The
optional pre-processing feeds the network ``(t, |x|^2)`` instead of the raw
coordinates, which exploits the radial symmetry of the solution. The
optimization is done using Natural Gradient Descent.

When run as a script, the problem is solved for d = 2, ..., 6 and an error
table comparing both variants is printed.
"""

# %%

import warnings

import matplotlib.pyplot as plt
import torch

from scimba_torch.approximation_space.nn_space import NNxtSpace
from scimba_torch.domain.meshless_domain.domain_nd import BallND
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.integration.monte_carlo_time import UniformTimeSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.temporal_pde.pinns import (
    NaturalGradientTemporalPinns,
)
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    StrongFormEllipticPDE,
)
from scimba_torch.physical_models.temporal_pde.abstract_temporal_pde import (
    FirstOrderTemporalPDE,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor

# No boundary condition is attached to the BallND domain: the Dirichlet
# condition is imposed strongly by the post-processing instead, and the
# library warns about the missing BC.
# NOTE: the filter below silences *all* warnings, not only that one.
warnings.filterwarnings("ignore")

torch.manual_seed(0)

SIGMA = 1  # width of the initial Gaussian exp(-|x|^2 / (2 SIGMA^2))
RADIUS = 1  # radius of the ball domain

# %%


def func_exact_one_x(t, x, mu):
    """Exact solution at a single space point (used by the functional API).

    Args:
        t: Time; a scalar or a tensor broadcastable against the result.
        x: A single space point; the dimension is read from ``x.shape[0]``.
        mu: Physical parameters (unused).

    Returns:
        The exact solution, keeping a trailing dimension of size 1.
    """
    d = x.shape[0]
    denom = SIGMA**2 + 2 * t
    fac = (SIGMA**2 / denom) ** (d / 2)
    r2 = torch.sum(x**2, dim=-1, keepdim=True)
    return fac * torch.exp(-r2 / (2 * denom))


def func_exact(t, x, mu):
    """Exact solution on a batch of points.

    Args:
        t: Times, broadcastable against a ``(n, 1)`` tensor.
        x: Space points of shape ``(n, d)``.
        mu: Physical parameters (unused).

    Returns:
        The exact solution, of shape ``(n, 1)``.
    """
    d = x.shape[1]
    denom = SIGMA**2 + 2 * t
    fac = (SIGMA**2 / denom) ** (d / 2)
    r2 = torch.sum(x**2, dim=-1)[:, None]
    return fac * torch.exp(-r2 / (2 * denom))


def exact(t, x, mu):
    """Exact solution evaluated on LabelTensors (unwraps their ``.x`` data)."""
    return func_exact(t.x, x.x, mu.x)


def f_ini(x, mu):
    """Initial condition: the exact solution at t = 0."""
    t = LabelTensor(torch.zeros(x.shape[0], 1))
    return exact(t, x, mu)


def f_bc_rhs(w, t, x, n, mu):
    """Dirichlet boundary data: the exact solution on the boundary points."""
    return exact(t, x, mu)


class LaplacianNDDirichletStrongFormNoParam(StrongFormEllipticPDE):
    """Spatial part of the heat equation: -Laplacian(u) with Dirichlet BC."""

    def __init__(self, space, f, g, **kwargs):
        super().__init__(
            space, linear=True, residual_size=1, bc_residual_size=1, **kwargs
        )
        self.f = f
        self.g = g

    def rhs(self, w, x, mu):
        return self.f(x, mu)

    def operator(self, w, x, mu):
        """Return -Laplacian(u), computed by repeated autodiff gradients."""
        u = w.get_components()
        grad_u = torch.cat(tuple(self.grad(u, x)), dim=-1)
        space_dim = grad_u.shape[1]
        # i-th component of grad(du/dx_i) is the pure second derivative
        laplacian_u = [tuple(self.grad(grad_u[:, i], x))[i] for i in range(space_dim)]
        laplacian_u = torch.sum(torch.stack(laplacian_u, dim=-1), dim=-1)
        return -laplacian_u

    def bc_rhs(self, w, x, n, mu):
        return self.g(x, mu)

    def bc_operator(self, w, x, n, mu):
        return w.get_components()


def zeros_rhs(w, t, x, mu, nb_func: int = 1):
    """Zero right-hand side of shape ``(x.shape[0], nb_func)``."""
    return torch.zeros(x.shape[0], nb_func)


def zeros_bc_rhs(w, t, x, n, mu, nb_func: int = 1):
    """Zero boundary data of shape ``(x.shape[0], nb_func)``."""
    return torch.zeros(x.shape[0], nb_func)


class HeatEquationNDDirichletStrongFormNoParam(FirstOrderTemporalPDE):
    """Heat equation du/dt - Laplacian(u) = f with Dirichlet BC, no parameters."""

    def __init__(self, space, init, f=zeros_rhs, g=zeros_bc_rhs, **kwargs):
        super().__init__(space, linear=True, **kwargs)
        # Only `operator` and `bc_operator` of the spatial component are used
        # below; its own `rhs`/`bc_rhs` call f(x, mu) / g(x, mu), which does not
        # match the temporal signatures of `f` and `g` used here.
        self.space_component = LaplacianNDDirichletStrongFormNoParam(space, f, g)
        self.f = f
        self.g = g
        self.init = init
        self.ic_residual_size = 1

    def space_operator(self, w, t, x, mu):
        return self.space_component.operator(w, x, mu)

    def bc_operator(self, w, t, x, n, mu):
        return self.space_component.bc_operator(w, x, n, mu)

    def rhs(self, w, t, x, mu: LabelTensor):
        return self.f(w, t, x, mu)

    def bc_rhs(self, w, t, x, n, mu):
        return self.g(w, t, x, n, mu)

    def initial_condition(self, x, mu):
        return self.init(x, mu)

    def functional_operator(self, func, t, x, mu, theta):
        """Residual du/dt - Laplacian(u) at a single point, via torch.func."""
        dim = x.shape[-1]
        time_op = torch.func.jacrev(func, 0)(t, x, mu, theta)
        grad_u = torch.func.jacrev(func, 1)
        # reshape (instead of squeeze) keeps a (dim, dim) Hessian even when
        # dim == 1, so that the trace below is always well defined
        grad_grad_u = torch.func.jacrev(grad_u, 1)(t, x, mu, theta).reshape(dim, dim)
        laplacian_u = torch.einsum("ii", grad_grad_u)
        return time_op[0] - laplacian_u[None]

    def functional_operator_bc(self, func, t, x, n, mu, theta):
        return func(t, x, mu, theta)

    def functional_operator_ic(self, func, x, mu, theta):
        # t = 0 with a single time component (not one zero per space coordinate)
        t = torch.zeros_like(x[..., :1])
        return func(t, x, mu, theta)


def post_processing(
    inputs: torch.Tensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
):
    """Impose the initial and boundary conditions strongly.

    Builds u = u0(x) + g(t) - g(0) + t * phi(x) * inputs, where:

    - u0 is the initial condition;
    - g(t) is the exact solution at the boundary point (RADIUS, 0, ..., 0);
    - phi(x) = 1 - |x|^2 / RADIUS^2 vanishes on the sphere.

    At t = 0 this gives u0. On the sphere, u0(x) equals g(0) because the exact
    solution is radial, so u reduces to g(t).
    """
    x_ = x.get_components()
    t1 = t.get_components()

    x_max_tensor = torch.zeros_like(x.x)
    x_max_tensor[:, 0] = RADIUS
    x_max_tensor = LabelTensor(x_max_tensor)

    u_ini = f_ini(x, mu)
    u_bc = exact(t, x_max_tensor, mu)
    u_bc_ini = f_ini(x_max_tensor, mu)

    level_set = 1 - torch.sum(
        torch.cat([x_i**2 for x_i in x_], dim=-1) / RADIUS**2, dim=-1, keepdim=True
    )

    return u_ini + u_bc - u_bc_ini + inputs * t1 * level_set


def func_f_ini(x, mu):
    """Initial condition on raw tensors of shape ``(n, d)``."""
    # one time column, so that the result has shape (n, 1)
    t = torch.zeros_like(x[..., :1])
    return func_exact(t, x, mu)


def functional_post_processing(func, t, x, mu, theta) -> torch.Tensor:
    """Single-point (torch.func) counterpart of ``post_processing``."""
    u_ini = func_exact_one_x(0, x, mu)

    x_max_tensor = torch.zeros_like(x)
    x_max_tensor[0] = RADIUS

    u_bc = func_exact_one_x(t, x_max_tensor, mu)
    u_bc_ini = func_exact_one_x(0, x_max_tensor, mu)

    inputs = func(t, x, mu, theta)
    level_set = 1 - torch.sum(x**2, dim=-1, keepdim=True) / RADIUS**2

    return u_ini + u_bc - u_bc_ini + inputs * t * level_set


def pre_processing(t, x: LabelTensor, mu: LabelTensor):
    """Map (t, x) to the network features (t, |x|^2)."""
    r2 = torch.sum(x.x**2, dim=-1, keepdim=True)
    return torch.cat([t.x, r2], dim=-1)


def functional_pre_processing(*args):
    """Single-point counterpart of ``pre_processing``: args are (t, x, ...)."""
    r2 = torch.sum(args[1] ** 2, dim=-1, keepdim=True)
    return torch.cat([args[0], r2], dim=-1)


def solve(d: int, use_pre_processing: bool) -> NaturalGradientTemporalPinns:
    """Solves the problem for the heat equation in dimension d.

    The time interval is [0, 1/d]. When ``d == 2``, the approximation, loss
    history, residual and error are plotted.

    Args:
        d: The dimension of the space domain.
        use_pre_processing: Whether to use pre-processing or not, i.e. whether
            the network sees (t, |x|^2) instead of the raw coordinates.

    Returns:
        The trained solver.
    """
    domain_x = BallND(center=(0,) * d, radius=RADIUS, is_main_domain=True)
    t_min, t_max = 0.0, 1 / d
    sampler = TensorizedSampler(
        [
            UniformTimeSampler((t_min, t_max)),
            DomainSampler(domain_x),
            UniformParametricSampler([]),  # no physical parameters
        ]
    )

    # Options shared by both variants; pre-processing only adds to them.
    space_kwargs = {"layer_sizes": [16, 16], "post_processing": post_processing}
    pinn_kwargs = {
        "ic_type": "strong",
        "bc_type": "strong",
        "matrix_regularization": 1e-6,
        "functional_post_processing": functional_post_processing,
    }
    if use_pre_processing:
        space_kwargs["pre_processing"] = pre_processing
        space_kwargs["pre_processing_out_size"] = 2  # t and r^2
        pinn_kwargs["functional_pre_processing"] = functional_pre_processing

    # presumably (1 output component, 0 physical parameters) - TODO confirm
    space = NNxtSpace(1, 0, GenericMLP, domain_x, sampler, **space_kwargs)
    pde = HeatEquationNDDirichletStrongFormNoParam(space, init=f_ini, g=f_bc_rhs)
    pinn = NaturalGradientTemporalPinns(pde, **pinn_kwargs)

    # the number of collocation points doubles with each extra dimension
    pinn.solve(epochs=100, n_collocation=1000 * 2 ** (d - 1))

    if d == 2:
        plot_abstract_approx_spaces(
            pinn.space,
            domain_x,
            [],
            (t_min, t_max),
            loss=pinn.losses,
            residual=pde,
            error=exact,
        )
        plt.show()

    return pinn


# %%


def compute_errors(pinn, pinn_no_pp, n_test=25_000):
    """Compare both solvers against the exact solution on shared random points.

    Args:
        pinn: Solver trained with pre-processing.
        pinn_no_pp: Solver trained without pre-processing.
        n_test: Number of sample points drawn from the first solver's sampler.

    Returns:
        ``(l2_pp, l2_no_pp, linf_pp, linf_no_pp)``. The "L2" values are the
        Euclidean norms of the pointwise error over the ``n_test`` samples. They
        are not normalized by ``n_test``, so only compare them at equal
        ``n_test``.
    """
    sampler = pinn.pde.space.integrator
    t, x, mu = sampler.sample(n_test)

    y_exact = exact(t, x, mu)
    y_pp = pinn.evaluate(t, x, mu).w
    y_no_pp = pinn_no_pp.evaluate(t, x, mu).w

    err_pp = (y_pp - y_exact).abs()
    err_no_pp = (y_no_pp - y_exact).abs()

    return (
        torch.norm(err_pp).item(),
        torch.norm(err_no_pp).item(),
        err_pp.max().item(),
        err_no_pp.max().item(),
    )


def make_error_table(results):
    """Print a table of ``(d, l2_pp, l2_no_pp, linf_pp, linf_no_pp)`` rows."""
    header = (
        f"{'d':>3} | "
        f"{'L2 (pp)':>10} | "
        f"{'L2 (no)':>10} | "
        f"{'Linf (pp)':>10} | "
        f"{'Linf (no)':>10}"
    )
    sep = "-" * len(header)

    print("\nError Summary")
    print(sep)
    print(header)
    print(sep)

    for d, l2_pp, l2_no_pp, linf_pp, linf_no_pp in results:
        print(
            f"{d:3d} | "
            f"{l2_pp:10.2e} | "
            f"{l2_no_pp:10.2e} | "
            f"{linf_pp:10.2e} | "
            f"{linf_no_pp:10.2e}"
        )

    print(sep)


# %%

if __name__ == "__main__":
    all_dims = [2, 3, 4, 5, 6]
    results = []

    for d in all_dims:
        print(f"\nSolving dimension {d} (no pre-processing)...")
        pinn_no_pp = solve(d=d, use_pre_processing=False)

        print(f"Solving dimension {d} (with pre-processing)...")
        pinn_pp = solve(d=d, use_pre_processing=True)

        res = compute_errors(pinn_pp, pinn_no_pp)
        results.append([d, *res])

    make_error_table(results)

# %%