r"""Solves the viscous Burgers advection equation in 1D using a discrete PINN. .. math:: \partial_t u + \partial_x \frac {u^2}{2} - \sigma \partial_{xx} u & = f in \Omega \times (0, T) \\ u & = g on \partial \Omega \times (0, T) \\ u & = u_0 on \Omega \times {0} where :math:`u: \partial \Omega \times (0, T) \to \mathbb{R}` is the unknown function, :math:`\Omega \subset \mathbb{R}` is the spatial domain and :math:`(0, T) \subset \mathbb{R}` is the time domain. Homogeneous Dirichlet boundary conditions are prescribed, and the exact solution is :math:`u(t, x) = \sin(2 \pi x) e^{-t}`. The equation is solved on a segment domain; weak boundary conditions are used. The natural gradient optimizer is used; explicit and implicit time integrators are compared. """ import jax import jax.numpy as jnp import matplotlib.pyplot as plt from scimba_jax.domains.meshless_domains.base import VolumetricDomain from scimba_jax.domains.meshless_domains.domains_1d import Segment1D from scimba_jax.linear_approximation.time_integrators.butcher_tableau import ( build_dirk_4_5_tableau, build_rk4_tableau, ) from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import ( ApproximationSpace, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo import ( DomainSampler, TensorizedSampler, ) from scimba_jax.nonlinear_approximation.networks.mlp import MLP from scimba_jax.nonlinear_approximation.numerical_solvers.discrete_pinns import ( DiscretePINN, ) from scimba_jax.physical_models.abstract_physical_model import ( PHYSICAL_RESIDUALS_TYPE, AbstractPhysicalModel, ) from scimba_jax.physical_models.abstract_residuals import ( PARAM_FUNC_TYPE, InteriorResidual, ) from scimba_jax.physical_models.boundary_residuals import DirichletResidual from scimba_jax.plots.plots_nd import plot_abstract_approx_spaces from scimba_jax.utils.typing_protocols import NDARRAYS_FUNC_TYPE N_COLLOC = 500 N_EPOCHS_INIT = 100 N_EPOCHS = 10 SIGMA = 1e-2 DOM_X = Segment1D((0.0, 1.0), is_main_domain=True) DOM_T = (0.0, 1.0) SAMPLER = TensorizedSampler([DomainSampler(DOM_X)], model_type="x") class ViscousBurgersResidual(InteriorResidual): def __init__( self, domain: VolumetricDomain, time_domain: tuple[float, float], f_rhs: NDARRAYS_FUNC_TYPE | None = None, model_type: str = "x", ): super().__init__( domain=domain, size=1, model_type=model_type, f_rhs=f_rhs, time_domain=time_domain, ) def construct_residual(self, *vars: PARAM_FUNC_TYPE) -> PARAM_FUNC_TYPE: rho = vars[0] rho2_x = ((rho**2) / 2.0).gradient_x() lap_x = rho.laplacian_x() return rho2_x - SIGMA * lap_x class ViscousBurgers(AbstractPhysicalModel): def __init__( self, main_domain: VolumetricDomain, time_domain: tuple[float, float], f_rhs: NDARRAYS_FUNC_TYPE | None = None, bc: str = "strong", f_bc_rhs: NDARRAYS_FUNC_TYPE | None = None, model_type: str = "x", ): super().__init__(main_domain=main_domain, time_domain=time_domain) self.physical_residuals: PHYSICAL_RESIDUALS_TYPE = { self.main_domain.get_label(): ViscousBurgersResidual( domain=main_domain, time_domain=time_domain, f_rhs=f_rhs, model_type=model_type, ), } if bc == "weak": for boundary in self.boundaries: self.physical_residuals[boundary] = DirichletResidual( domain=self.boundaries[boundary], size=1, model_type=model_type, f_rhs=f_bc_rhs, time_domain=time_domain, ) def exact_sol(t: jnp.ndarray, x: jnp.ndarray): return jnp.sin(2 * jnp.pi * x) * jnp.exp(-t) def f_init(x: jnp.ndarray): return exact_sol(jnp.zeros_like(x), x) def post_processing(approx: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray: return approx * x * (1.0 - x) def viscous_burgers_rhs(t, *args) -> jnp.ndarray: x = args[0] # print(f"\n\n t in viscous_burgers_rhs: {t} \n") exp_neg_t = jnp.exp(-t) sin_x = jnp.sin(2 * jnp.pi * x) cos_x = jnp.cos(2 * jnp.pi * x) return ( exp_neg_t * sin_x * (2 * jnp.pi * (cos_x * exp_neg_t + 2 * jnp.pi * SIGMA) - 1.0) ) # create the discrete PINN parameters params = { "rk4": { "tableau": build_rk4_tableau(), "explicit_pde": ViscousBurgers(DOM_X, DOM_T, viscous_burgers_rhs), "implicit_pde": None, }, "dirk_4_5": { "tableau": build_dirk_4_5_tableau(), "explicit_pde": None, "implicit_pde": ViscousBurgers(DOM_X, DOM_T, viscous_burgers_rhs), }, } nt = 10 in_size = 1 out_size = 1 space = None spaces = [] for method, param in params.items(): butcher_tableau = param["tableau"] explicit_pde = param["explicit_pde"] implicit_pde = param["implicit_pde"] key = jax.random.PRNGKey(0) discrete_pinn = DiscretePINN( DOM_X, DOM_T, SAMPLER, out_size, nt, butcher_tableau, explicit_pde, implicit_pde ) if space is None: # train the initial condition only for the first tableau nn = MLP(in_size=in_size, out_size=out_size, hidden_sizes=[12] * 3, key=key) space = ApproximationSpace( {"x": 1}, [(nn, "scalar", None)], model_type="x", post_processing=post_processing, ) print("Initializing the discrete PINN...") key, space = discrete_pinn.initialize( key, space, f_init, N_EPOCHS_INIT, N_COLLOC ) print("Initializing the discrete PINN... Done\n") plot_abstract_approx_spaces( [space], DOM_X, solution=f_init, error=f_init, title="initial condition" ) print(f"\nSolving with the {method} method...") key, space_ = discrete_pinn.solve(key, space, N_EPOCHS, N_COLLOC) spaces.append(space_) print(f"\nSolving with the {method} method... Done\n") # %% plot_abstract_approx_spaces( spaces, DOM_X, solution=lambda x: exact_sol(jnp.ones_like(x) * DOM_T[-1], x), error=lambda x: exact_sol(jnp.ones_like(x) * DOM_T[-1], x), title=f"solution at final time t={DOM_T[-1]}", titles=[f"solution with {method} method" for method in params.keys()], ) plt.show() # %%