r"""Solves a 1D advection equation in a periodic domain using a discrete PINN.

.. math::

    \partial_t u + a \partial_x u = 0

The initial condition is a bump function, and the space domain is endowed
with periodic boundary conditions.
"""

# %%
import matplotlib.pyplot as plt
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_1d import Segment1D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.collocation_projector import (
    NaturalGradientProjector,
)
from scimba_torch.numerical_solvers.temporal_pde.discrete_pinn import DiscretePINN
from scimba_torch.numerical_solvers.temporal_pde.time_discrete import (
    TimeDiscreteNaturalGradientProjector,
)
from scimba_torch.physical_models.temporal_pde.abstract_temporal_pde import (
    FirstOrderTemporalPDE,
)
from scimba_torch.plots.plot_time_discrete_scheme import plot_time_discrete_scheme
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor


class Transport1DPeriodic(FirstOrderTemporalPDE):
    r"""Implementation of a 1D transport equation with periodic boundary conditions.

    The space operator is the plain derivative :math:`\partial_x u` (unit
    advection speed) and the right-hand side is zero. Periodicity is expressed
    by equating the values restricted to boundary label 0 with those
    restricted to boundary label 1.

    :param space: The approximation space for the problem.
    :type space: AbstractApproxSpace
    :param kwargs: Additional keyword arguments forwarded to the parent class.
    :type kwargs: dict
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        **kwargs,
    ):
        # The equation is declared linear to the parent class.
        super().__init__(space, linear=True, **kwargs)

    def space_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor:
        """Return the space derivative of the unknown.

        :param w: The unknown evaluated at the collocation points.
        :param t: The time variable (unused).
        :param x: The space variable.
        :param mu: The parameters (unused).
        :return: The derivative of the components of ``w`` with respect to ``x``.
        """
        u = w.get_components()
        u_x = self.space.grad(u, x)
        return u_x

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor:
        """Return the time derivative of the unknown.

        :param w: The unknown evaluated at the collocation points.
        :param t: The time variable.
        :param x: The space variable (unused).
        :param mu: The parameters (unused).
        :return: The derivative of the components of ``w`` with respect to ``t``.
        """
        # NOTE(review): this calls ``self.grad`` whereas ``space_operator``
        # calls ``self.space.grad`` -- confirm the parent class provides
        # ``grad``.
        return self.grad(w.get_components(), t)

    def rhs(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Return the (zero) source term of the equation.

        :param w: The unknown evaluated at the collocation points.
        :param t: The time variable (unused).
        :param x: The space variable (unused).
        :param mu: The parameters (unused).
        :return: A tensor of zeros shaped like ``w.w``.
        """
        return torch.zeros_like(w.w)

    def bc_operator(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Return the boundary values, label 0 first and label 1 second.

        :param w: The unknown evaluated at the boundary points.
        :param t: The time variable (unused).
        :param x: The space variable (unused).
        :param n: The normals at the boundary points (unused).
        :param mu: The parameters (unused).
        :return: The restrictions of ``w`` to labels 0 and 1, concatenated
            along dimension 0 in that order.
        """
        w0 = w.get_components()
        # Presumably label 0 is the left end and label 1 the right end of
        # the segment -- verify against the domain's labelling.
        w_l = w.restrict_to_labels(w0, labels=[0])
        w_r = w.restrict_to_labels(w0, labels=[1])
        return torch.cat([w_l, w_r], dim=0)

    def bc_rhs(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Return the boundary values in swapped order (label 1, then label 0).

        Matching this against ``bc_operator`` equates the values at the two
        boundary labels, which is the periodicity condition.

        :param w: The unknown evaluated at the boundary points.
        :param t: The time variable (unused).
        :param x: The space variable (unused).
        :param n: The normals at the boundary points (unused).
        :param mu: The parameters (unused).
        :return: The restrictions of ``w`` to labels 1 and 0, concatenated
            along dimension 0 in that order.
        """
        w0 = w.get_components()
        w_l = w.restrict_to_labels(w0, labels=[0])
        w_r = w.restrict_to_labels(w0, labels=[1])
        return torch.cat([w_r, w_l], dim=0)

    def initial_condition(self, x: LabelTensor, mu: LabelTensor) -> torch.Tensor:
        """Return the initial condition, i.e. ``exact`` evaluated at time zero.

        :param x: The space variable.
        :param mu: The parameters.
        :return: The value of ``exact`` at ``t = 0``.
        """
        return exact(LabelTensor(torch.zeros_like(x.x)), x, mu)


def exact(t: LabelTensor, x: LabelTensor, mu: LabelTensor) -> torch.Tensor:
    """Return the exact solution: a Gaussian bump advected at unit speed.

    :param t: The time variable.
    :param x: The space variable.
    :param mu: The parameters (unused).
    :return: ``1 + exp(-(y - 0.65)**2 / (2 * v**2))`` where
        ``y = (x - a * t) mod 1``, ``a = 1`` and ``v = 0.1``.
    """
    x_ = x.get_components()
    t_ = t.get_components()
    a = 1  # advection speed
    v = 0.1  # width of the Gaussian bump
    # Follow the characteristic back to time 0 and wrap it into [0, 1) so
    # that the solution is 1-periodic in space.
    x_ = (x_ - a * t_) % 1.0
    return 1 + torch.exp(-((x_ - 0.65) ** 2) / (2 * v**2))


# %%
def solve_with_discrete_pinn(T: float, dt: float, scheme: str) -> DiscretePINN:
    """Solve the periodic transport problem with a discrete PINN.

    :param T: Value passed to ``DiscretePINN.solve`` as ``T``.
    :param dt: Value passed to ``DiscretePINN.solve`` as ``dt``.
    :param scheme: Name of the time scheme, passed to ``DiscretePINN``
        (e.g. ``"euler_exp"``).
    :return: The ``DiscretePINN`` object after ``solve`` has been run.
    """
    # Fixed seed so that runs are reproducible.
    torch.random.manual_seed(0)

    domain_x = Segment1D((0, 1), is_main_domain=True)
    domain_mu = []  # no parameters in this problem

    sampler = TensorizedSampler(
        [
            DomainSampler(domain_x),
            UniformParametricSampler(domain_mu),
        ]
    )

    space = NNxSpace(1, 0, GenericMLP, domain_x, sampler, layer_sizes=[30, 30])
    pde = Transport1DPeriodic(space)

    projector_init = NaturalGradientProjector(
        space,
        rhs=pde.initial_condition,
    )
    projector = TimeDiscreteNaturalGradientProjector(
        pde,
        rhs=pde.initial_condition,
        bc_rhs=pde.bc_rhs,
        bc_weight=500.0,
    )

    # A distinct local name keeps the ``scheme`` argument (a string) from
    # being rebound to the solver object.
    pinn = DiscretePINN(
        pde,
        projector=projector,
        projector_init=projector_init,
        scheme=scheme,
        bool_weak_bc=True,
    )

    pinn.initialization(epochs=50, verbose=True, n_collocation=3000)
    pinn.projector.space.load_from_best_approx()

    pinn.solve(dt=dt, T=T, epochs=80, n_collocation=3000, verbose=True)

    return pinn


# %%
if __name__ == "__main__":
    # Alternative run with a smaller final time and time step:
    # solve_with_discrete_pinn(0.03, 0.005, "euler_exp")
    scheme = solve_with_discrete_pinn(0.2, 0.01, "euler_exp")

    plot_time_discrete_scheme(
        scheme,
        solution=exact,
        error=exact,
    )

    plt.show()