r"""Solves the viscous Burgers advection equation in 1D using a PINN. .. math:: \partial_t u + \partial_x \frac {u^2}{2} - \sigma \partial_{xx} u & = f in \Omega \times (0, T) \\ u & = g on \partial \Omega \times (0, T) \\ u & = u_0 on \Omega \times {0} where :math:`u: \partial \Omega \times (0, T) \to \mathbb{R}` is the unknown function, :math:`\Omega \subset \mathbb{R}` is the spatial domain and :math:`(0, T) \subset \mathbb{R}` is the time domain. Dirichlet boundary conditions are prescribed. The equation is solved on a segment domain; weak boundary and initial conditions are used. Three training strategies are used: PINNs with Adam optimizer, PINNs with SS-BFGS optimizer, and PINNs with energy natural gradient preconditioning. """ import timeit import jax import jax.numpy as jnp import matplotlib.pyplot as plt from scimba_jax.domains.meshless_domains.domains_1d import Segment1D from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import ( ApproximationSpace, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo import ( DomainSampler, TensorizedSampler, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import ( UniformParametricSampler, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import ( UniformTimeSampler, ) from scimba_jax.nonlinear_approximation.networks.mlp import MLP from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector from scimba_jax.physical_models.temporal_pde.viscous_burgers import ViscousBurgers1D from scimba_jax.plots.plots_nd import ( plot_abstract_approx_space, plot_abstract_approx_spaces, ) N_COLLOC = 2000 N_BC_COLLOC = 2000 N_IC_COLLOC = 2000 N_EPOCHS = 200 N_EPOCHS_ADAM = 1000 N_EPOCHS_SSBFGS = 500 def f_rhs(t: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray, sigma: float) -> jnp.ndarray: exp_neg_t = jnp.exp(-t) sin_x = jnp.sin(2 * jnp.pi * x) cos_x = jnp.cos(2 * jnp.pi * x) return ( exp_neg_t * sin_x * (2 * jnp.pi * (cos_x * exp_neg_t + 2 * jnp.pi * sigma) - 1.0) ) def exact_sol(t: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray): return jnp.sin(2 * jnp.pi * x) * jnp.exp(-t) def f_init(x: jnp.ndarray, mu: jnp.ndarray): t = jnp.zeros_like(x) return exact_sol(t, x, mu) domain_t = (0.0, 1.0) domain_x = [(-1.0, 1.0)] domain_mu = [] dx = Segment1D(domain_x[0], is_main_domain=True) sampler = TensorizedSampler( [ UniformTimeSampler(domain_t), DomainSampler(dx), UniformParametricSampler(domain_mu), ], model_type="t_x_mu", bc=True, ic=True, ) # create the model model = ViscousBurgers1D( main_domain=dx, time_domain=domain_t, bc="weak", ic="weak", f_rhs=f_rhs, f_ic_rhs=lambda *args: f_init(*args), sigma=1e-2 / jnp.pi, ) # create the approximation space key = jax.random.PRNGKey(42) nn = MLP(in_size=2, out_size=1, hidden_sizes=[30] * 3, key=key) space = ApproximationSpace( {"x": 1}, [(nn, "scalar", None)], model_type="t_x_mu", ) key = jax.random.PRNGKey(42) # create the pinn print("\n\n") print("@@@@@@@@@@@@@@@ train with Adam @@@@@@@@@@@@@@@@@@@@@@") # pinn = Projector(model, space, sampler, optimizer="SS-BFGS") pinn = Projector(model, space, sampler, optimizer="Adam", learning_rate=1e-4) start = timeit.default_timer() new_loss, nspace, key, loss_history = pinn.project( space, key, N_EPOCHS_SSBFGS, N_COLLOC, N_BC_COLLOC, N_IC_COLLOC ) pinn.losses.loss_history = loss_history pinn.best_loss = new_loss pinn.space = nspace end = timeit.default_timer() print("best loss: ", new_loss) print("time for %d epochs: " % N_EPOCHS_ADAM, end - start) plot_abstract_approx_space( pinn.space, dx, domain_mu, domain_t, solution=exact_sol, error=exact_sol, time_values=[0.0, 0.5, 1.0], loss=pinn.losses, residual=pinn.model, ) plt.show() print("@@@@@@@@@@@@@@@ train with SS-BFGS @@@@@@@@@@@@@@@@@@@@@@") nn2 = MLP(in_size=2, out_size=1, hidden_sizes=[16, 32, 16], key=key) space2 = ApproximationSpace( {"x": 1}, [(nn2, "scalar", None)], model_type="t_x_mu", ) pinn2 = Projector(model, space2, sampler, optimizer="SS-BFGS") start = timeit.default_timer() new_loss, nspace, key, loss_history = pinn2.project( space2, key, N_EPOCHS_SSBFGS, N_COLLOC, N_BC_COLLOC, N_IC_COLLOC ) pinn2.losses.loss_history = loss_history pinn2.best_loss = new_loss pinn2.space = nspace end = timeit.default_timer() print("best loss: ", new_loss) print("time for %d epochs: " % N_EPOCHS_SSBFGS, end - start) print("@@@@@@@@@@@@@@@ train with ENG @@@@@@@@@@@@@@@@@@@@@@") nn3 = MLP(in_size=2, out_size=1, hidden_sizes=[16, 32, 16], key=key) space3 = ApproximationSpace( {"x": 1}, [(nn3, "scalar", None)], model_type="t_x_mu", ) pinn3 = Projector(model, space3, sampler) start = timeit.default_timer() new_loss, nspace, key, loss_history = pinn3.project( space3, key, N_EPOCHS, N_COLLOC, N_BC_COLLOC, N_IC_COLLOC ) pinn3.losses.loss_history = loss_history pinn3.best_loss = new_loss pinn3.space = nspace end = timeit.default_timer() print("best loss: ", new_loss) print("time for %d epochs: " % N_EPOCHS, end - start) plot_abstract_approx_spaces( ( pinn.space, pinn2.space, pinn3.space, ), dx, domain_mu, domain_t, solution=exact_sol, error=exact_sol, loss=(pinn.losses, pinn2.losses, pinn3.losses), residual=(pinn.model, pinn2.model, pinn3.model), titles=("Adam", "SS-BFGS", "Energy Natural Gradient"), ) plt.show()