r"""Solves the linearized Euler equations in 1D using a PINN. .. math:: \partial_t p + \partial_x u & = f_1 in \Omega \times (0, T) \\ \partial_t u + \partial_x p & = f_2 in \Omega \times (0, T) \\ p & = g_1 on \partial \Omega \times (0, T) \\ u & = g_2 on \partial \Omega \times (0, T) \\ p & = p_0 on \Omega \times {0} \\ u & = u_0 on \Omega \times {0} where :math:`p: \partial \Omega \times (0, T) \to \mathbb{R}` and :math:`u: \partial \Omega \times (0, T) \to \mathbb{R}` are the unknown functions, :math:`\Omega \subset \mathbb{R}` is the spatial domain and :math:`(0, T) \subset \mathbb{R}` is the time domain. Dirichlet boundary conditions are prescribed. The equation is solved on a segment domain; strong boundary condition and weak initial conditions are used. Two training strategies are compared: standard PINNs with SS-BFGS optimizer and PINNs with energy natural gradient preconditioning. """ import timeit import jax import jax.numpy as jnp import matplotlib.pyplot as plt from scimba_jax.domains.meshless_domains.domains_1d import Segment1D from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import ( ApproximationSpace, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo import ( DomainSampler, TensorizedSampler, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import ( UniformParametricSampler, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import ( UniformTimeSampler, ) from scimba_jax.nonlinear_approximation.networks.mlp import MLP from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector from scimba_jax.physical_models.temporal_pde.linearized_euler import LinearizedEuler from scimba_jax.plots.plots_nd import ( plot_abstract_approx_spaces, ) # sol exacte : u(x) = mu*sin(2*pi*x1)*sin(2*pi*x2) def exact_solution(t, x, mu): D = 0.02 coeff = 1 / (4 * jnp.pi * D) ** 0.5 p_plus_u = coeff * jnp.exp(-((x - t - 1) ** 2) / (4 * D)) p_minus_u = coeff 
* jnp.exp(-((x + t - 1) ** 2) / (4 * D)) p = (p_plus_u + p_minus_u) / 2 u = (p_plus_u - p_minus_u) / 2 return jnp.concatenate((p, u), axis=-1) def initial_solution(x, mu): return exact_solution(jnp.zeros_like(x), x, mu) def post_processing( approx: jnp.ndarray, t: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray ) -> jnp.ndarray: phi = (x - (-1.0)) * (x - 3.0) return approx * phi N_EPOCHS = 50 N_EPOCHS_SS_BFGS = 500 N_COLLOC = 1000 N_IC_COLLOC = 2000 t_min, t_max = 0.0, 0.5 domain_t = (t_min, t_max) domain_x = [(-1.0, 3.0)] dx = Segment1D(domain_x[0], is_main_domain=True) domain_mu = [] sampler = TensorizedSampler( [ UniformTimeSampler(domain_t), DomainSampler(dx), UniformParametricSampler(domain_mu), ], model_type="t_x_mu", bc=False, ic=True, ) # create the model model = LinearizedEuler( main_domain=dx, time_domain=domain_t, bc="strong", ic="weak", f_ic_rhs=lambda *args: initial_solution(*args), ) # create the approximation space key = jax.random.PRNGKey(0) nn = MLP(in_size=2, out_size=2, hidden_sizes=[32, 32], key=key) space = ApproximationSpace({"x": 1}, [(nn, "vec", 2)], model_type="t_x_mu") print("\n\n") print("@@@@@@@@@@@@@@@ test ENG @@@@@@@@@@@@@@@@@@@@@@") pinn = Projector(model, space, sampler) key, sample_dict = sampler.sample(key, N_COLLOC, N_IC_COLLOC) loss0 = pinn.evaluate_loss(space, sample_dict) print("initial loss: ", loss0) start = timeit.default_timer() new_loss, nspace, key, loss_history = pinn.project( space, key, N_EPOCHS, N_COLLOC, N_IC_COLLOC ) pinn.losses.loss_history = loss_history pinn.best_loss = new_loss pinn.space = nspace end = timeit.default_timer() print("best loss: ", new_loss) print("time for %d epochs: " % N_EPOCHS, end - start) print("\n\n") print("@@@@@@@@@@@@@@@ test SS-BFGS @@@@@@@@@@@@@@@@@@@@@@") key = jax.random.PRNGKey(0) nn = MLP(in_size=2, out_size=2, hidden_sizes=[40] * 4, key=key) space2 = ApproximationSpace({"x": 1}, [(nn, "vec", 2)], model_type="t_x_mu") pinn2 = Projector(model, space2, sampler, optimizer="SS-BFGS") 
# Evaluate the SS-BFGS run on the same collocation samples as the ENG run
# so the initial losses are directly comparable.
loss0 = pinn2.evaluate_loss(space2, sample_dict)
print("initial loss: ", loss0)

start = timeit.default_timer()
new_loss, nspace, key, loss_history = pinn2.project(
    space2, key, N_EPOCHS_SS_BFGS, N_COLLOC, N_IC_COLLOC
)
pinn2.losses.loss_history = loss_history
pinn2.best_loss = new_loss
pinn2.space = nspace
end = timeit.default_timer()

print("best loss: ", new_loss)
print(f"time for {N_EPOCHS_SS_BFGS} epochs: ", end - start)

# Side-by-side comparison of the two trained spaces: solution, error against
# the exact solution, loss histories, PDE residuals and derivatives at t_max.
plot_abstract_approx_spaces(
    (
        pinn.space,
        pinn2.space,
    ),
    dx,
    domain_mu,
    domain_t,
    time_values=[
        t_max,
    ],
    loss=(
        pinn.losses,
        pinn2.losses,
    ),
    residual=(
        pinn.model,
        pinn2.model,
    ),
    solution=exact_solution,
    error=exact_solution,
    derivatives=["ux", "ut"],
    title="solving LinearizedEuler with TemporalPinns, strong boundary conditions",
    titles=("ENG preconditioning", "SS-BFGS"),
)
plt.show()