r"""Solve the heat equation in 1D using a PINN.

.. math::

    \partial_t u - \partial_{xx} u & = f \quad \text{in } \Omega \times (0, T) \\
    u & = g \quad \text{on } \partial \Omega \times (0, T) \\
    u & = u_0 \quad \text{on } \Omega \times \{0\}

where :math:`u: \Omega \times (0, T) \to \mathbb{R}` is the unknown function,
:math:`\Omega \subset \mathbb{R}` is the spatial domain and
:math:`(0, T) \subset \mathbb{R}` is the time domain.
Dirichlet boundary conditions are prescribed, and the initial condition is smooth.

The equation is solved on the segment :math:`\Omega = (0, 1)` with :math:`T = 1`,
:math:`f = 0`, :math:`g = 0` and :math:`u_0(x) = \sin(\pi x)`. Boundary conditions
are imposed strongly (the network output is multiplied by :math:`x (1 - x)`, which
vanishes on :math:`\partial \Omega`) and the initial condition weakly (a penalty
term weighted by ``ic_weight``).

NOTE: the reference solution :math:`u(t, x) = e^{-\pi^2 t / 4} \sin(\pi x)` satisfies
:math:`\partial_t u - \tfrac{1}{4} \partial_{xx} u = 0`, i.e. it corresponds to a
diffusion coefficient of :math:`1/4` rather than the unit coefficient written above.
The single physical parameter :math:`\mu` is sampled in a narrow band around
:math:`0.25` (presumably that coefficient -- confirm against
``HeatEquation1DDirichletStrongForm``).

Three training strategies are compared: standard PINNs, PINNs with energy natural
gradient preconditioning and PINNs with Anagram preconditioning.
"""

# %%

import matplotlib.pyplot as plt
import torch

from scimba_torch.approximation_space.nn_space import NNxtSpace
from scimba_torch.domain.meshless_domain.domain_1d import Segment1D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.integration.monte_carlo_time import UniformTimeSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.temporal_pde.pinns import (
    AnagramTemporalPinns,
    NaturalGradientTemporalPinns,
    TemporalPinns,
)
from scimba_torch.optimizers.optimizers_data import OptimizerData
from scimba_torch.physical_models.temporal_pde.heat_equation import (
    HeatEquation1DDirichletStrongForm,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor

# %%

bc_weight = 10.0
ic_weight = 500.0

# Physical parameter range: a band of width 1e-5, so mu is effectively fixed at 0.25.
mu_min, mu_max = 0.25, 0.25 + 1e-5
# Time interval (0, T) with T = 1.
t_min, t_max = 0.0, 1.0


def exact(t, x, mu):
    """Return the reference solution ``exp(-pi**2 * t / 4) * sin(pi * x)``.

    ``mu`` is accepted to match the callback signature but is not used.
    """
    x1 = x.get_components()
    t1 = t.get_components()
    return torch.exp(-(t1 * torch.pi**2) / 4.0) * torch.sin(torch.pi * x1)


def f_rhs(w, t, x, mu):
    """Return the source term f = 0, built from the parameter components."""
    mu1 = mu.get_components()
    return 0 * mu1


def f_bc(w, t, x, n, mu):
    """Return the Dirichlet data g = 0, built from the spatial components."""
    x1 = x.get_components()
    return 0 * x1


def f_ini(x, mu):
    """Return the initial condition u_0(x) = exact(0, x, mu) = sin(pi x)."""
    t = LabelTensor(torch.zeros_like(x.x))
    return exact(t, x, mu)


domain_x = Segment1D((0, 1), is_main_domain=True)


def post_processing(
    inputs: torch.Tensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
):
    """Multiply the raw network output by ``x (1 - x)``.

    The factor vanishes at x = 0 and x = 1, so the homogeneous Dirichlet
    condition (g = 0) holds exactly: this is the strong BC ansatz.
    """
    x1 = x.get_components()
    return inputs * x1 * (1.0 - x1)


def functional_post_processing(func, *args: torch.Tensor) -> torch.Tensor:
    """Functional counterpart of :func:`post_processing`.

    It is passed to the natural-gradient and Anagram solvers below.
    NOTE(review): assumes ``args`` is ordered (t, x, mu) and that ``args[1][0]``
    is the x coordinate of a single point -- TODO confirm against the solvers'
    calling convention.
    """
    return func(*args) * args[1][0] * (1.0 - args[1][0])


sampler = TensorizedSampler(
    [
        UniformTimeSampler((t_min, t_max)),
        DomainSampler(domain_x),
        UniformParametricSampler([(mu_min, mu_max)]),
    ]
)


def _make_space(layer_sizes):
    """Build an MLP approximation space in (t, x) that uses the strong-BC ansatz."""
    return NNxtSpace(
        1,
        1,
        GenericMLP,
        domain_x,
        sampler,
        layer_sizes=layer_sizes,
        post_processing=post_processing,
    )


def _solve_or_load(solver, tag, *, resume, **solve_kwargs):
    """Train ``solver`` and save it under ``tag``, or reload a saved run.

    When ``resume`` is true, ``solver.load`` is skipped and training always runs.
    Otherwise a saved run is loaded first, and training (followed by saving) only
    happens when ``solver.load`` returns a falsy value. In every case
    ``solver.space.load_from_best_approx()`` is called at the end.
    """
    if resume or not solver.load(__file__, tag):
        solver.solve(**solve_kwargs)
        solver.save(__file__, tag)
    solver.space.load_from_best_approx()


# True: always retrain and overwrite the saved runs; False: reuse saved runs when present.
resume_solve = True

# Baseline: plain PINN trained with Adam (strong BC via post_processing, weak IC).
space = _make_space([20, 40, 20])
pde = HeatEquation1DDirichletStrongForm(space, init=f_ini, f=f_rhs, g=f_bc)

opt_1 = {
    "name": "adam",
    "optimizer_args": {"lr": 1.8e-2, "betas": (0.9, 0.999)},
}

pinn = TemporalPinns(
    pde,
    bc_type="strong",
    ic_type="weak",
    optimizers=OptimizerData(opt_1),
    bc_weight=bc_weight,
    ic_weight=ic_weight,
)

_solve_or_load(
    pinn,
    "no_precond",
    resume=resume_solve,
    epochs=1000,
    n_collocation=2000,
    n_bc_collocation=50,
    n_ic_collocation=1500,
    verbose=True,
)

# PINN with energy natural gradient (ENG) preconditioning.
space2 = _make_space([64])
pde2 = HeatEquation1DDirichletStrongForm(space2, init=f_ini, f=f_rhs, g=f_bc)

pinn2 = NaturalGradientTemporalPinns(
    pde2,
    bc_type="strong",
    ic_type="weak",
    ic_weight=ic_weight,
    matrix_regularization=1e-6,
    functional_post_processing=functional_post_processing,
)

_solve_or_load(
    pinn2,
    "ENG",
    resume=resume_solve,
    epochs=200,
    n_collocation=900,
    n_ic_collocation=30,
    verbose=True,
)

# PINN with Anagram preconditioning.
space3 = _make_space([64])
pde3 = HeatEquation1DDirichletStrongForm(space3, init=f_ini, f=f_rhs, g=f_bc)

pinn3 = AnagramTemporalPinns(
    pde3,
    bc_type="strong",
    ic_type="weak",
    ic_weight=ic_weight,
    svd_threshold=5e-5,
    functional_post_processing=functional_post_processing,
)

_solve_or_load(
    pinn3,
    "Anagram",
    resume=resume_solve,
    epochs=200,
    n_collocation=900,
    n_ic_collocation=30,
    verbose=True,
)

plot_abstract_approx_spaces(
    (pinn.space, pinn2.space, pinn3.space),
    domain_x,
    ([(mu_min, mu_max)],),
    ((t_min, t_max),),
    parameters_values=([mu_min],),
    time_values=([1e-2], [1e-2], [1e-2]),
    loss=(pinn.losses, pinn2.losses, pinn3.losses),
    residual=(pde, pde2, pde3),
    solution=exact,
    error=exact,
)

plt.show()