r"""Learns the initial condition of a transport equation in 1D.

The goal is to learn a function from
:math:`\mathbb{R} \times \mathbb{R} \times \mathbb{R}^2 \to \mathbb{R}`
(time, 1D position, 2D velocity; the MLP input size is 4). The network is
an MLP and the weights are fit with Natural Gradient Descent.
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.domains.meshless_domains.domains_2d import Circle2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
    UniformParametricSampler,
    UniformVelocitySampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import (
    Projector,
)
from scimba_jax.physical_models.function_approximator.function_approximator import (
    FunctionApproximator,
)
from scimba_jax.plots.plots_nd import (
    plot_abstract_approx_space,
)


def exact(t, x, v, mu):
    """Exact target function used for the projection.

    Args:
        t: time value(s).
        x: 1D space value(s).
        v: velocity array; only the sign of its first component
           ``v[..., 0]`` is used.
        mu: extra parameters (unused — the parametric domain is empty).

    Returns:
        A Gaussian bump centered at ``x = sign(v_x) * t``, scaled by a
        velocity-dependent factor.
    """
    v_x = v[..., 0]
    # Collapse the 2D velocity to a transport direction of -1 or +1.
    v_a = jnp.where(v_x <= 0, -1.0, 1.0)
    # Gaussian profile advected with speed v_a.
    exp_1 = jnp.exp(-((x - v_a * t) ** 2 * 10))
    # Mild velocity-dependent amplitude modulation.
    exp_2 = jnp.exp(-((v_a - 1.0) ** 2 * 0.01))
    return exp_1 * exp_2


# Batched version of the exact solution for plotting.
exact_for_plot = jax.vmap(exact)

# Optimization / sampling hyperparameters.
N_EPOCHS = 50
N_EPOCHS_SSBFGS = 2000
N_COLLOC = 900

# Computational domains: time interval, 1D space segment, 2D velocity
# disc, and an empty parametric domain (no extra parameters).
domain_t = (0.0, 0.1)
domain_x = Segment1D((-1.0, 1.0), is_main_domain=True)
domain_v = Circle2D((0, 0), 1.0)
domain_mu = []

# Build the approximation space: a scalar MLP over (t, x, v), i.e. a
# 4-dimensional input (1 time + 1 space + 2 velocity components).
key = jax.random.PRNGKey(0)
nn = MLP(in_size=4, out_size=1, hidden_sizes=[32, 32], key=key)
space = ApproximationSpace(
    {"x": 1, "v": 2}, [(nn, "scalar", None)], model_type="t_x_v_mu"
)
# Tensorized Monte-Carlo sampler over (t, x, v, mu) collocation points.
sampler = TensorizedSampler(
    [
        UniformTimeSampler(domain_t),
        DomainSampler(domain_x),
        UniformVelocitySampler(domain_v),
        UniformParametricSampler(domain_mu),
    ],
    model_type="t_x_v_mu",
)
key, sample_dict = sampler.sample(key, N_COLLOC)

print("\n\n")
print("@@@@@@@@@@@@@@@ test ENG @@@@@@@@@@@@@@@@@@@@@@")

# The model's right-hand side is the exact solution itself; pass the
# function directly instead of wrapping it in a redundant lambda.
model = FunctionApproximator(
    domain_x, size=1, model_type="t_x_v_mu", f_rhs=exact
)

# Project the exact solution onto the approximation space.
projector = Projector(model, space, sampler)
loss0 = projector.evaluate_loss(space, sample_dict)
print("initial loss: ", loss0)

# Run the projection and time it; keep the best space/loss found.
start = timeit.default_timer()
new_loss, nspace, key, loss_history = projector.project(
    space, key, N_EPOCHS, N_COLLOC
)
projector.losses.loss_history = loss_history
projector.best_loss = new_loss
projector.space = nspace
end = timeit.default_timer()

print("best loss: ", new_loss)
print("time for %d epochs: " % N_EPOCHS, end - start)

# Compare the learned space against the exact solution and plot a few
# derivatives for two representative velocity angles.
plot_abstract_approx_space(
    projector.space,
    domain_x,
    domain_mu,
    domain_t,
    domain_v,
    solution=exact,
    error=exact,
    derivatives=["ux", "uv0", "uv1"],
    velocity_values=[0.0, jnp.pi],
    velocity_strs=["+1", "-1"],
)

plt.show()