r"""Solve the heat equation in 1D using a PINN.

.. math::

    \partial_t u - \partial_{xx} u & = f \quad \text{in } \Omega \times (0, T) \\
    \partial_x u & = g \quad \text{on } \partial \Omega \times (0, T) \\
    u & = u_0 \quad \text{on } \Omega \times \{0\}

where :math:`u: \Omega \times (0, T) \to \mathbb{R}` is the unknown function,
:math:`\Omega \subset \mathbb{R}` is the spatial domain and
:math:`(0, T) \subset \mathbb{R}` is the time domain.
Neumann boundary conditions are prescribed, and the initial condition is a sine.

The equation is solved on a segment domain; weak boundary and initial
conditions are used. Two training strategies are compared, each starting from
its own freshly initialized network: PINNs with energy natural gradient (ENG)
preconditioning and PINNs with SS-BFGS optimization.
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
    UniformParametricSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.temporal_pde.heat_equations import HeatND
from scimba_jax.plots.plots_nd import (
    plot_abstract_approx_space,
    plot_abstract_approx_spaces,
)

N_COLLOC = 900
N_BC_COLLOC = 2000
N_IC_COLLOC = 2000
N_EPOCHS = 50
N_EPOCHS_SSBFGS = 500


def exact_sol(t: jnp.ndarray, x: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    r"""Return the reference solution :math:`e^{-\pi^2 t} \sin(\pi x)`.

    ``mu`` is unused (``domain_mu`` is empty below); it is accepted so the
    function can be called with ``(t, x, mu)`` arguments.
    """
    return jnp.exp(-(t * jnp.pi**2)) * jnp.sin(jnp.pi * x)


def f_init(x: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    """Return the initial condition u_0(x), i.e. ``exact_sol`` at t = 0."""
    t = jnp.zeros_like(x)
    return exact_sol(t, x, mu)


def _make_space(init_key: jnp.ndarray) -> ApproximationSpace:
    """Build a scalar MLP approximation space in (t, x) initialized from ``init_key``."""
    nn = MLP(in_size=2, out_size=1, hidden_sizes=[16, 16], key=init_key)
    return ApproximationSpace(
        {"x": 2},
        [(nn, "scalar", None)],
        model_type="t_x_mu",
    )


def _train(
    projector: Projector,
    space: ApproximationSpace,
    key: jnp.ndarray,
    n_epochs: int,
) -> jnp.ndarray:
    """Train ``projector`` starting from ``space`` and store the results on it.

    Prints the best loss and the wall-clock training time, then writes the loss
    history, best loss and trained space back onto ``projector`` so they can be
    plotted.

    Returns:
        The PRNG key handed back by ``projector.project``.
    """
    start = timeit.default_timer()
    new_loss, nspace, key, loss_history = projector.project(
        space, key, n_epochs, N_COLLOC, N_BC_COLLOC, N_IC_COLLOC
    )
    # JAX dispatches work asynchronously; wait for the result so the timer
    # measures the actual computation rather than just the dispatch.
    jax.block_until_ready(new_loss)
    end = timeit.default_timer()

    projector.losses.loss_history = loss_history
    projector.best_loss = new_loss
    projector.space = nspace

    print("best loss: ", new_loss)
    print(f"time for {n_epochs} epochs: ", end - start)
    return key


domain_t = (0.0, 1.0)
domain_x = [(0.0, 1.0)]
domain_mu = []

dx = Segment1D(domain_x[0], is_main_domain=True)

sampler = TensorizedSampler(
    [
        UniformTimeSampler(domain_t),
        DomainSampler(dx),
        UniformParametricSampler(domain_mu),
    ],
    model_type="t_x_mu",
    bc=True,
    ic=True,
)

# create the model
model = HeatND(
    main_domain=dx,
    time_domain=domain_t,
    bc="weak",
    ic="weak",
    f_ic_rhs=lambda *args: f_init(*args),
)

# create the approximation space; split the key so network initialization and
# collocation sampling do not consume the same randomness
key = jax.random.PRNGKey(0)
key, init_key = jax.random.split(key)
space = _make_space(init_key)

# ---------------------------------------------------------------- ENG training
print("\n\n")
print("@@@@@@@@@@@@@@@ train with ENG @@@@@@@@@@@@@@@@@@@@@@")

pinn = Projector(model, space, sampler)

key, sample_dict = sampler.sample(key, N_COLLOC)
loss0 = pinn.evaluate_loss(space, sample_dict)
print("initial loss: ", loss0)

key = _train(pinn, space, key, N_EPOCHS)

plot_abstract_approx_space(
    pinn.space,
    dx,
    domain_mu,
    domain_t,
    time_values=[0.0, 1e-3, 0.1],
    loss=pinn.losses,
    residual=pinn.model,
    solution=exact_sol,
    error=exact_sol,
    derivatives=["ux", "ut", "utx"],
    title="learning sol of 1D heat equation with TemporalPinns",
)
plt.show()

# ------------------------------------------------------------ SS-BFGS training
print("\n\n")
print("@@@@@@@@@@@@@@@ train with SS-BFGS @@@@@@@@@@@@@@@@@@@@@@")

# fresh network with its own (non-reused) initialization key
key, init_key = jax.random.split(key)
space2 = _make_space(init_key)

pinn2 = Projector(model, space2, sampler, optimizer="SS-BFGS")

# evaluate the new network on the same collocation points as the ENG run so
# the two initial losses are comparable
loss0_2 = pinn2.evaluate_loss(space2, sample_dict)
print("initial loss: ", loss0_2)

# train space2 (the network pinn2 was built with), not the ENG space
key = _train(pinn2, space2, key, N_EPOCHS_SSBFGS)

plot_abstract_approx_spaces(
    (pinn.space, pinn2.space),
    dx,
    domain_mu,
    domain_t,
    loss=(pinn.losses, pinn2.losses),
    residual=(pinn.model, pinn2.model),
    solution=exact_sol,
    error=exact_sol,
    derivatives=["ux", "ut", "utx"],
    title="learning sol of 1D heat equation with TemporalPinns",
    titles=("with ENG", "with SS-BFGS"),
)
plt.show()