r"""Learns the exact solution of a linearized Euler equation in 1D. The goal is to learn a function from :math:`\mathbb{R} \times \mathbb{R} \times \mathbb{R} \to \mathbb{R}^2`. This example compares three projection methods: Adam, SS-BFGS and natural gradient preconditioning. """ import timeit import jax import jax.numpy as jnp import matplotlib.pyplot as plt from scimba_jax.domains.meshless_domains.domains_1d import Segment1D from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import ( ApproximationSpace, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo import ( DomainSampler, TensorizedSampler, ) from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import ( UniformTimeSampler, ) from scimba_jax.nonlinear_approximation.networks.mlp import MLP from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import ( Projector, ) from scimba_jax.physical_models.function_approximator.function_approximator import ( FunctionApproximator, ) from scimba_jax.plots.plots_nd import ( plot_abstract_approx_spaces, ) N_EPOCHS = 50 N_COLLOC = 900 def exact_solution(t, x): D = 0.02 coeff = 1 / (4 * jnp.pi * D) ** 0.5 p_plus_u = coeff * jnp.exp(-((x - t - 1) ** 2) / (4 * D)) p_minus_u = coeff * jnp.exp(-((x + t - 1) ** 2) / (4 * D)) p = (p_plus_u + p_minus_u) / 2 u = (p_plus_u - p_minus_u) / 2 return jnp.concatenate((p, u), axis=-1) def create_domains(): dom_x = Segment1D((-1.0, 3.0), is_main_domain=True) dom_t = (0.0, 0.5) return dom_x, dom_t def create_and_train_proj(embedding=None, **kwargs) -> Projector: dom_x, dom_t = create_domains() sampler = TensorizedSampler( [ UniformTimeSampler(dom_t), DomainSampler(dom_x), ], model_type="t_x", bc=False, ic=False, ) key = jax.random.PRNGKey(0) nn = MLP( in_size=2, out_size=2, hidden_sizes=[12, 12], key=key, embedding=embedding, **kwargs, ) model = FunctionApproximator( main_domain=dom_x, size=2, model_type="t_x", f_rhs=lambda *args: exact_solution(*args), time_domain=dom_t, ) space = ApproximationSpace({"x": 1}, [(nn, "vec", 2)], model_type="t_x") projector = Projector(model, space, sampler) start = timeit.default_timer() key, projector = projector.project(key, space, N_EPOCHS, N_COLLOC) end = timeit.default_timer() print("best loss: ", projector.best_loss) print("time for %d epochs: " % N_EPOCHS, end - start) return projector dom_x, dom_t = create_domains() proj = create_and_train_proj(embedding=None) proj_periodic = create_and_train_proj( embedding="periodic", periods=[4.0], embedding_axes=[0] ) proj_fourier = create_and_train_proj( embedding="fourier", n_fourier_features=12, fourier_features_std=1.0 ) # Fourier features break periodicity # The following combination just showcases the flexibility of the framework, # but is not expected to perform better than just Fourier features alone proj_periodic_fourier = create_and_train_proj( embedding=("periodic", "fourier"), periods=[4.0], embedding_axes=[0], n_fourier_features=12, fourier_features_std=1.0, ) projs = [proj, proj_periodic, proj_fourier, proj_periodic_fourier] titles = [ "No embedding", "Periodic embedding", "Fourier features", "Periodic + Fourier", ] plot_abstract_approx_spaces( [p.space for p in projs], dom_x, time_domains=dom_t, loss=[p.losses for p in projs], solution=exact_solution, error=exact_solution, titles=titles, ) plt.show()