r"""Learns the exact solution of a linearized Euler equation in 1D. The goal is to learn a function from :math:`\mathbb{R} \times \mathbb{R} \times \mathbb{R} \to \mathbb{R}^2`. This example compares three projection methods: Adam, SS-BFGS and natural gradient preconditioning. """ import timeit import jax import jax.numpy as jnp import matplotlib.pyplot as plt from scimba_jax.domains.meshless_domains.domains_1d import Segment1D from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import ( ApproximationSpace, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo import ( DomainSampler, TensorizedSampler, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import ( UniformParametricSampler, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import ( UniformTimeSampler, ) from scimba_jax.nonlinear_approximation.networks.mlp import MLP from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import ( Projector, ) from scimba_jax.physical_models.function_approximator.function_approximator import ( FunctionApproximator, ) from scimba_jax.plots.plots_nd import ( plot_abstract_approx_spaces, ) N_EPOCHS = 50 N_EPOCHS_ADAM = 2000 N_EPOCHS_SSBFGS = 2000 N_COLLOC = 900 def exact_solution(t, x, mu): D = 0.02 coeff = 1 / (4 * jnp.pi * D) ** 0.5 p_plus_u = coeff * jnp.exp(-((x - t - 1) ** 2) / (4 * D)) p_minus_u = coeff * jnp.exp(-((x + t - 1) ** 2) / (4 * D)) p = (p_plus_u + p_minus_u) / 2 u = (p_plus_u - p_minus_u) / 2 return jnp.concatenate((p, u), axis=-1) dx = Segment1D((-1.0, 3.0), is_main_domain=True) domain_t = (0.0, 0.5) domain_mu = [] sampler = TensorizedSampler( [ UniformTimeSampler(domain_t), DomainSampler(dx), UniformParametricSampler(domain_mu), ], model_type="t_x_mu", bc=False, ic=False, ) key = jax.random.PRNGKey(0) print("\n@@@@@@@@@@@@@@@ with Adam @@@@@@@@@@@@@@@@@@@@@@") nn = MLP(in_size=2, out_size=2, hidden_sizes=[32, 32], key=key) space = ApproximationSpace({"x": 1}, [(nn, "vec", 2)], model_type="t_x_mu") model = FunctionApproximator( main_domain=dx, size=2, model_type="t_x_mu", f_rhs=lambda *args: exact_solution(*args), ) projector = Projector(model, space, sampler, optimizer="Adam") start = timeit.default_timer() new_loss, nspace, key, loss_history = projector.project( space, key, N_EPOCHS_ADAM, N_COLLOC ) projector.losses.loss_history = loss_history projector.best_loss = new_loss projector.space = nspace end = timeit.default_timer() print("best loss: ", new_loss) print("time for %d epochs: " % N_EPOCHS_ADAM, end - start) print("\n@@@@@@@@@@@@@@@ with SS-BFGS @@@@@@@@@@@@@@@@@@@@@@") nn = MLP(in_size=2, out_size=2, hidden_sizes=[32, 32], key=key) space2 = ApproximationSpace({"x": 1}, [(nn, "vec", 2)], model_type="t_x_mu") projector2 = Projector(model, space2, sampler, optimizer="SS-BFGS") start = timeit.default_timer() new_loss, nspace, key, loss_history = projector2.project( space2, key, N_EPOCHS_SSBFGS, N_COLLOC ) projector2.losses.loss_history = loss_history projector2.best_loss = new_loss projector2.space = nspace end = timeit.default_timer() print("best loss: ", new_loss) print("time for %d epochs: " % N_EPOCHS_ADAM, end - start) # projector.plot(exact_sol=exact_solution, error=exact_solution) print("@@@@@@@@@@@@@@@ with ENG @@@@@@@@@@@@@@@@@@@@@@") nn = MLP(in_size=2, out_size=2, hidden_sizes=[32, 32], key=key) space3 = ApproximationSpace({"x": 1}, [(nn, "vec", 2)], model_type="t_x_mu") projector3 = Projector(model, space3, sampler) start = timeit.default_timer() new_loss, 
new_loss, nspace, key, loss_history = projector3.project(
    space3, key, N_EPOCHS, N_COLLOC
)
projector3.losses.loss_history = loss_history
projector3.best_loss = new_loss
projector3.space = nspace
end = timeit.default_timer()
print("best loss: ", new_loss)
print("time for %d epochs: " % N_EPOCHS, end - start)

plot_abstract_approx_spaces(
    (projector.space, projector2.space, projector3.space),
    dx,
    domain_mu,
    domain_t,
    loss=(projector.losses, projector2.losses, projector3.losses),
    solution=exact_solution,
    error=exact_solution,
)
plt.show()
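
# --- Optional sanity check (not part of the original example) -------------
# The exact solution decomposes into characteristic variables w± = p ± u,
# which should satisfy the advection equations dw±/dt ± dw±/dx = 0
# (equivalently, the linearized Euler system dp/dt + du/dx = 0,
# du/dt + dp/dx = 0). Below is a minimal sketch that checks this with
# jax.grad; the helpers _w_plus/_w_minus and the test point (t0, x0) are
# illustrative and not part of scimba_jax.
def _w_plus(t, x):
    p, u = exact_solution(jnp.array([t]), jnp.array([x]), None)
    return p + u


def _w_minus(t, x):
    p, u = exact_solution(jnp.array([t]), jnp.array([x]), None)
    return p - u


t0, x0 = 0.25, 0.8
res_plus = jax.grad(_w_plus, 0)(t0, x0) + jax.grad(_w_plus, 1)(t0, x0)
res_minus = jax.grad(_w_minus, 0)(t0, x0) - jax.grad(_w_minus, 1)(t0, x0)
print("advection residuals (should be ~0):", res_plus, res_minus)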