r"""Solve the 1D Malthus ODE with a physics-informed neural network.

The problem reads

.. math::

    u'(t) = u(t), \quad t \in \Omega,

where :math:`\Omega = [t_0, t_1]` and :math:`u(t_0) = a` (the condition may
also be imposed at another instant of :math:`\Omega`, see ``__main__``).

The initial condition is enforced weakly, i.e. as an additional residual term
of the loss. The neural network is a simple MLP (multilayer perceptron). The
optimization is done using either natural gradient descent
(``optimizer="ENG"``) or SS-BFGS (``optimizer="SS-BFGS"``).
"""

import timeit
from collections.abc import Callable, Sequence

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.model_class.funcparam_vectorial import (
    ParamScalarFunction,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.abstract_physical_model import (
    AbstractPhysicalModel,
)
from scimba_jax.physical_models.abstract_residuals import (
    NDARRAYS_FUNC_TYPE,
    PARAM_FUNC_TYPE,
    InitialResidual,
    InteriorResidual,
)

# Reference problem sizes for this example.
N_COLLOC = 100
N_IC_COLLOC = 1  # number of collocation points for the initial condition
N_EPOCHS_ENG = 150
N_EPOCHS_SSBFGS = 1000


class InitialValueResidual(InitialResidual):
    """Weak residual enforcing ``u(time_value) = value``.

    Args:
        time_value: instant at which the condition is imposed.
        value: target value of the solution at ``time_value``.
    """

    def __init__(
        self,
        time_value: float = 0.0,
        value: float = 0.0,
    ):
        super().__init__(
            domain=None,
            # A one-element "time domain": the single instant at which the
            # condition is imposed.
            time_domain=(time_value,),
            size=1,
            model_type="t",
            f_rhs=None,
        )
        self.value = value

    def construct_residual(self, *funcs: PARAM_FUNC_TYPE) -> PARAM_FUNC_TYPE:
        """Return ``u(time_value) - value`` for the first unknown ``u``."""
        u = funcs[0]
        assert isinstance(u, ParamScalarFunction)
        # NOTE(review): set_t_0 presumably pins the time argument of u to the
        # given instant; its semantics are defined in scimba_jax.
        return u.set_t_0(self.time_domain[0]) - self.value


class MalthusResidual(InteriorResidual):
    """Interior residual ``u' - u`` of the Malthus equation.

    Args:
        time_domain: time interval on which the residual is sampled.
        f_rhs: optional right-hand side, forwarded to ``InteriorResidual``
            (how it enters the loss is decided by the base class).
    """

    def __init__(
        self, time_domain: tuple[float, ...], f_rhs: NDARRAYS_FUNC_TYPE | None = None
    ):
        super().__init__(
            domain=None,
            size=1,
            model_type="t",
            f_rhs=f_rhs,
            time_domain=time_domain,
        )

    def construct_residual(self, *funcs: PARAM_FUNC_TYPE) -> PARAM_FUNC_TYPE:
        """Return ``du/dt - u`` for the first unknown ``u``."""
        u = funcs[0]
        assert isinstance(u, ParamScalarFunction)
        return u.d_t() - u


class MalthusEquation(AbstractPhysicalModel):
    """Physical model for ``u'(t) = u(t)`` with a condition ``u(ic_t) = ic_v``.

    Args:
        time_domain: time interval ``(t_0, t_1)``.
        f_rhs: optional right-hand side forwarded to the interior residual.
        ic: ``"weak"`` adds an :class:`InitialValueResidual` to the loss; any
            other value adds no condition residual, so the condition is not
            enforced by this model.
        ic_t: instant at which the condition is imposed.
        ic_v: value imposed at ``ic_t``.
    """

    def __init__(
        self,
        time_domain: tuple[float, float],
        f_rhs: NDARRAYS_FUNC_TYPE | None = None,
        ic: str = "weak",
        ic_t: float = 0.0,
        ic_v: float = 1.0,
    ):
        super().__init__(main_domain=None, time_domain=time_domain)
        self.physical_residuals = {
            "interior": MalthusResidual(time_domain=self.time_domain, f_rhs=f_rhs),
        }
        if ic == "weak":
            self.physical_residuals["ic interior"] = InitialValueResidual(
                time_value=ic_t, value=ic_v
            )


def exact_solution(t: jnp.ndarray, tv: float, v: float) -> jnp.ndarray:
    """Exact solution of ``u' = u`` with ``u(tv) = v``: ``v * exp(t - tv)``."""
    return v * jnp.exp(t - tv)


def plot(
    pinn,
    exact_solution: Callable[[jnp.ndarray], jnp.ndarray],
    time_domain: tuple[float, float] = (0.0, 1.0),
    n_visu: int = 128,
) -> None:
    """Plot the loss history, approximate vs. exact solution, and the error.

    Args:
        pinn: trained projector; must provide ``evaluate`` and ``losses.plot``.
        exact_solution: callable mapping a ``(n_visu, 1)`` time array to the
            exact solution values.
        time_domain: interval on which the solutions are displayed.
        n_visu: number of uniformly spaced visualization points.
    """
    t = jnp.linspace(time_domain[0], time_domain[1], n_visu)[..., None]
    # NOTE(review): assumes pinn.evaluate returns the same shape as u_exact,
    # i.e. (n_visu, 1); a (n_visu,) result would broadcast to (n_visu, n_visu).
    u_pred = pinn.evaluate(t)
    u_exact = exact_solution(t)

    abs_error = jnp.abs(u_pred - u_exact)
    # Discrete L2 error normalised by the number of points (RMS error):
    # sqrt(mean(e^2)). Dividing sqrt(sum(e^2)) by n_visu instead would
    # under-report the error by a factor sqrt(n_visu).
    l2_error = jnp.sqrt(jnp.mean(abs_error**2))

    _, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[1].plot(t, u_pred, label="approximate sol.")
    axes[1].plot(t, u_exact, linestyle=":", label="exact sol.")
    axes[1].set_title("Exact and approximated solutions.")

    axes[2].plot(t, abs_error, label="absolute error")
    axes[2].set_title(f"L2 error: {float(l2_error):.2e}")

    axes[1].legend()
    axes[2].legend()

    pinn.losses.plot(axes[0])

    plt.show()


def solve_malthus_model(
    domain_t: tuple[float, float] = (0.0, 1.0),
    t0: float = 0.0,
    t0_val: float = 1.0,
    optimizer: str = "ENG",
    n_epochs: int = 150,
    n_colloc: int = 150,
    layers_sizes: Sequence[int] = (8, 8),
    verbose: bool = False,
    plot_res: bool = True,
):
    """Train a PINN for ``u' = u`` on ``domain_t`` with ``u(t0) = t0_val``.

    Args:
        domain_t: time interval ``(t_0, t_1)``.
        t0: instant at which the condition is imposed.
        t0_val: value imposed at ``t0``.
        optimizer: optimizer name passed to :class:`Projector`
            (``"ENG"`` or ``"SS-BFGS"`` in this script).
        n_epochs: number of training epochs.
        n_colloc: number of interior collocation points per epoch.
        layers_sizes: hidden layer widths of the MLP.
        verbose: print the initial loss, best loss and training time.
        plot_res: plot the result against the exact solution after training.

    Returns:
        The trained projector returned by ``Projector.project``.
    """
    sampler = TensorizedSampler(
        [UniformTimeSampler(domain_t)],
        model_type="t",
        bc=False,
        ic=True,
    )

    key = jax.random.PRNGKey(0)
    # Pass a fresh list so the network never aliases the caller's sequence.
    nn = MLP(in_size=1, out_size=1, hidden_sizes=list(layers_sizes), key=key)
    space = ApproximationSpace({"t": 1}, [(nn, "scalar", None)], model_type="t")

    model = MalthusEquation(domain_t, ic_t=t0, ic_v=t0_val)
    pinn = Projector(model, space, sampler, optimizer=optimizer)

    if verbose:
        # Diagnostic: loss of the untrained network, evaluated on as many
        # collocation points as are used for training.
        key, sample_dict = sampler.sample(key, n_colloc)
        loss = pinn.evaluate_loss(space, sample_dict)
        print("initial loss: ", loss)
        losses = pinn.evaluate_losses(space, sample_dict)
        print("initial losses: ", losses)
        print("\n\n")

    print(f"@@@@@@@@@@@@@@@ train with {optimizer} @@@@@@@@@@@@@@@@@@@@@")

    start = timeit.default_timer()
    _, pinn = pinn.project(key, space, n_epochs, n_colloc, n_ic=N_IC_COLLOC)
    end = timeit.default_timer()

    if verbose:
        print("best loss: ", pinn.best_loss)
        print(f"time for {n_epochs} epochs: ", end - start)

    if plot_res:
        plot(pinn, lambda t: exact_solution(t, t0, t0_val), domain_t)

    return pinn


if __name__ == "__main__":
    # Condition imposed at the left end of the time interval, trained with ENG.
    solve_malthus_model(
        domain_t=(0.0, 1.0),
        t0=0.0,
        t0_val=1.0,
        optimizer="ENG",
        n_epochs=150,
        n_colloc=150,
        layers_sizes=[8, 8],
        verbose=False,
        plot_res=True,
    )

    # Condition imposed at the right end (t = 1), trained with SS-BFGS.
    solve_malthus_model(
        domain_t=(0.0, 1.0),
        t0=1.0,
        t0_val=2.0,
        optimizer="SS-BFGS",
        n_epochs=500,
        n_colloc=150,
        layers_sizes=[8, 8],
        verbose=False,
        plot_res=True,
    )