r"""Solves the advection of a 2D parametric bump function on a disk. ..math:: \partial_t u + a \cdot \nabla u = 0 where :math:`u: \mathbb{R}^2 \times (0, T) \times \mathbb{R}^2 \to \mathbb{R}` is the unknown function, depending on space, time, and two parameters (the position and the variance of the initial bump). The equation is solved using the neural semi-Lagrangian scheme, with either a classical Adam optimizer, or a natural gradient preconditioning. """ # %% import matplotlib.pyplot as plt import torch from scimba_torch import PI from scimba_torch.approximation_space.nn_space import NNxSpace from scimba_torch.domain.meshless_domain.domain_2d import Disk2D from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP from scimba_torch.numerical_solvers.collocation_projector import ( AbstractNonlinearProjector, ) from scimba_torch.numerical_solvers.temporal_pde.neural_semilagrangian import ( Characteristic, NeuralSemiLagrangian, ) from scimba_torch.numerical_solvers.temporal_pde.time_discrete import ( TimeDiscreteCollocationProjector, TimeDiscreteNaturalGradientProjector, ) from scimba_torch.optimizers.optimizers_data import OptimizerData from scimba_torch.physical_models.temporal_pde.advection_diffusion_equation import ( AdvectionReactionDiffusionDirichletStrongForm, ) from scimba_torch.plots.plot_time_discrete_scheme import plot_time_discrete_scheme from scimba_torch.utils.scimba_tensors import LabelTensor def f_ini(x, mu): return exact(LabelTensor(torch.zeros((x.shape[0], 1))), x, mu) def exact(t, x, mu): x1, x2 = x.get_components() t_ = t.get_components() v = 0.1 c = 0.3 x1_t = torch.cos(2 * PI * t_) x2_t = torch.sin(2 * PI * t_) return 1 + torch.exp(-0.5 / v**2 * ((x1 - c * x1_t) ** 2 + (x2 - c * x2_t) ** 2)) def exact_foot( t: torch.Tensor, x: LabelTensor, mu: LabelTensor, dt: float ) -> torch.Tensor: x1, x2 = x.get_components() arg = 2 * PI * t c_n = torch.cos(arg) s_n = torch.sin(arg) k1 = x1 * s_n + x2 * c_n k2 = x1 * c_n - x2 * s_n arg_np1 = 2 * PI * (t + dt) c_np1 = torch.cos(arg_np1) s_np1 = torch.sin(arg_np1) x1_np1 = k1 * s_np1 + k2 * c_np1 x2_np1 = -k2 * s_np1 + k1 * c_np1 return torch.cat((x1_np1, x2_np1), dim=1) def make_projector( pde, projection_type: str = "natural_gradient" ) -> AbstractNonlinearProjector: if projection_type == "natural_gradient": return TimeDiscreteNaturalGradientProjector(pde, rhs=f_ini) else: opt_1 = { "name": "adam", "optimizer_args": {"lr": 2.5e-2, "betas": (0.9, 0.999)}, } opt_2 = { "name": "lbfgs", "switch_at_epoch_ratio": 0.9, "optimizer_args": { "history_size": 50, "max_iter": 50, "tolerance_grad": 1e-11, "tolerance_change": 1e-9, }, } opt = OptimizerData(opt_1, opt_2) return TimeDiscreteCollocationProjector(pde, optimizers=opt) def solve_with_neural_sl( T: float, dt: float, projection_type: str = "natural_gradient" ) -> NeuralSemiLagrangian: torch.random.manual_seed(0) domain_x = Disk2D((0, 0), 1, is_main_domain=True) domain_mu = [] sampler = TensorizedSampler( [ DomainSampler(domain_x), UniformParametricSampler(domain_mu), ] ) def a(t, x, mu): x1, x2 = x.get_components() return torch.cat((-2 * PI * x2, 2 * PI * x1), dim=1) # TODO: implement PeriodicMLP (see branch develop_sl of old scimba) space = NNxSpace(1, 0, GenericMLP, domain_x, sampler, layer_sizes=[8, 16, 8]) pde = AdvectionReactionDiffusionDirichletStrongForm( space, a=a, u0=f_ini, 
constant_advection=False, zero_diffusion=True ) characteristic = Characteristic(pde, exact_foot=exact_foot) projector = make_projector(pde, projection_type) scheme = NeuralSemiLagrangian(characteristic, projector) scheme.initialization(epochs=500, n_collocation=10_000, verbose=True) scheme.solve(dt=dt, final_time=T, epochs=750, n_collocation=10_000, verbose=True) return scheme # %% if __name__ == "__main__": scheme = solve_with_neural_sl(0.5, 0.25, projection_type="natural_gradient") plot_time_discrete_scheme( scheme, solution=exact, error=exact, ) plt.show() # %%
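# Sanity check: a minimal sketch (not part of the original example) verifying
# that `exact_foot` is consistent with `exact`. For this rotating flow, the
# semi-Lagrangian identity u(t + dt, x) = u(t, X(x)) should hold exactly,
# where X(x) is the foot of the characteristic arriving at x at time t + dt.
# The sample size, test points, and tolerance below are arbitrary choices;
# `mu` is unused by both functions, so `None` is passed in its place.
if __name__ == "__main__":
    n, t0, dt_check = 64, 0.1, 0.05
    x_check = LabelTensor(0.5 * torch.randn(n, 2))  # points near the disk center
    foot = LabelTensor(exact_foot(torch.tensor(t0), x_check, None, dt_check))
    t_np1 = LabelTensor(torch.full((n, 1), t0 + dt_check))
    t_n = LabelTensor(torch.full((n, 1), t0))
    err = (exact(t_np1, x_check, None) - exact(t_n, foot, None)).abs().max().item()
    assert err < 1e-5, f"semi-Lagrangian identity violated: max error {err:.2e}"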