r"""Inverse problem: identify ν in -Δu + νu = f from PDE + data.

True solution: u(x,y) = sin(2π x) sin(2π y), domain [0,1]².
With -Δu = 8π² u the equation gives:

    f(x,y) = (8π² + ν_true) sin(2π x) sin(2π y),   ν_true = 2.

(For ν > 0 the operator -Δ + ν is the *modified* Helmholtz / screened-Poisson
operator; the "Helmholtz" naming used below is kept for continuity.)

ν is added as a second variable in the ApproximationSpace via a
``ScalarParam`` (single learnable bias, ignores spatial input). The standard
Projector (ENG) optimises both the MLP weights and ν jointly.

Three loss terms:
    1. Physics:  -Δu + ν u - f = 0 at collocation points.
    2. Boundary: a ``DirichletResidual`` on every boundary of the square
       (presumably homogeneous, u = 0, which is consistent with u_true
       vanishing on ∂[0,1]² since sin(0) = sin(2π) = 0).
    3. Data:     u_θ(xᵢ) = u_true(xᵢ) at N_DATA observation points.
"""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DataSampler,
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
    UniformParametricSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.networks.scalar_param import ScalarParam
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.abstract_physical_model import AbstractPhysicalModel
from scimba_jax.physical_models.abstract_residuals import (
    NDARRAYS_FUNC_TYPE,
    PARAM_FUNC_TYPE,
    InteriorResidual,
)
from scimba_jax.physical_models.boundary_residuals import DirichletResidual
from scimba_jax.physical_models.data_residuals import CollocDataResidual

# Enable float64 before any array is created in this script (JAX expects this
# flag to be set at startup). NOTE(review): arrays created at import time
# inside scimba_jax, if any, are not affected by this call — confirm.
jax.config.update("jax_enable_x64", True)

# ──────────────────────────────────────────────────────────────────────────────
# Problem parameters
# ──────────────────────────────────────────────────────────────────────────────

NU_TRUE = 2.0  # ground-truth ν, used only to build the right-hand side f
NU_INIT = 1.0  # initial guess for the learnable ν
N_COLLOC = 2000  # interior collocation points (passed to Projector.project)
N_DATA = 300  # number of observation points of u_true
N_EPOCHS = 200  # number of optimisation epochs (passed to Projector.project)

DOMAIN = Square2D([(0.0, 1.0), (0.0, 1.0)], is_main_domain=True)


def u_true(xy: jnp.ndarray) -> jnp.ndarray:
    """Exact solution u(x, y) = sin(2πx) sin(2πy) at a single point ``xy = (x, y)``."""
    x, y = xy[0], xy[1]
    return jnp.sin(2.0 * jnp.pi * x) * jnp.sin(2.0 * jnp.pi * y)


def f_rhs(xy: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    """Right-hand side f = (8π² + ν_true) · u_true(xy), as a length-1 array.

    ``mu`` is unused: this problem has no physical parameters (the parametric
    sampler below is built with an empty list). The argument is presumably
    required by the residual's ``f_rhs(x, mu)`` calling convention.
    """
    return jnp.array([(8.0 * jnp.pi**2 + NU_TRUE) * u_true(xy)])


# ──────────────────────────────────────────────────────────────────────────────
# Residual: -Δu + ν u = f
# vars[0] = u (MLP), vars[1] = ν (ScalarParam)
# ──────────────────────────────────────────────────────────────────────────────


class HelmholtzInverseResidual(InteriorResidual):
    """Interior residual for -Δu + νu = f, with ν learned jointly with u."""

    def construct_residual(self, *vars: PARAM_FUNC_TYPE) -> PARAM_FUNC_TYPE:
        """Return the operator applied to the approximation: -Δu + νu.

        ``vars`` follows the variable order registered in the
        ApproximationSpace: ``vars[0]`` is u (MLP), ``vars[1]`` is ν
        (ScalarParam). NOTE(review): f is presumably subtracted by
        ``InteriorResidual`` using the ``f_rhs`` passed at construction —
        confirm.
        """
        u = vars[0]
        nu = vars[1]
        return -u.laplacian_x() + nu * u


# ──────────────────────────────────────────────────────────────────────────────
# Physical model
# ──────────────────────────────────────────────────────────────────────────────


class HelmholtzInverse(AbstractPhysicalModel):
    """Interior residual on ``domain`` plus one Dirichlet residual per boundary."""

    def __init__(self, domain: Square2D, f_rhs: NDARRAYS_FUNC_TYPE):
        super().__init__(main_domain=domain)
        self.physical_residuals = {
            domain.get_label(): HelmholtzInverseResidual(
                domain=domain, size=1, f_rhs=f_rhs
            )
        }
        # ``self.boundaries`` is set up by the base class from ``main_domain``;
        # every boundary label gets its own scalar Dirichlet residual.
        for label in self.boundaries:
            self.physical_residuals[label] = DirichletResidual(
                domain=self.boundaries[label], size=1
            )


# ──────────────────────────────────────────────────────────────────────────────
# Setup
# ──────────────────────────────────────────────────────────────────────────────

key = jax.random.PRNGKey(0)

# Data observations: N_DATA points drawn uniformly in [0, 1)², with noise-free
# values of u_true. There are no physical parameters, so mu has zero columns.
key, key_data = jax.random.split(key)
x_data = jax.random.uniform(key_data, (N_DATA, 2))
mu_data = jnp.zeros((N_DATA, 0))
y_data = jax.vmap(u_true)(x_data)[:, None]  # shape (N_DATA, 1)
data_sampler = DataSampler((x_data, mu_data, y_data))

sampler = TensorizedSampler(
    [DomainSampler(DOMAIN), UniformParametricSampler([])],
    bc=True,
    data_samplers={"data": data_sampler},
)

# Space: vars[0] = u (MLP), vars[1] = ν (ScalarParam)
key, k_u = jax.random.split(key)
nn_u = MLP(in_size=2, out_size=1, hidden_sizes=[20, 20, 20], key=k_u)
nn_nu = ScalarParam(init=NU_INIT)
space = ApproximationSpace(
    {"x": 2},
    [(nn_u, "scalar", None), (nn_nu, "scalar", None)],
    model_type="x_mu",
)

model = HelmholtzInverse(DOMAIN, f_rhs=f_rhs)
# The label "data" must match the key used in ``data_samplers`` above.
model.add_data_residual("data", CollocDataResidual(size=1, model_type="x_mu"))

# ──────────────────────────────────────────────────────────────────────────────
# Training with standard Projector (ENG)
# ──────────────────────────────────────────────────────────────────────────────

pinn = Projector(
    model,
    space,
    sampler,
    # Per-term loss weights: interior physics, each boundary, observations.
    weights={
        DOMAIN.get_label(): [1.0],
        **{label: [10.0] for label in model.boundaries},
        "data": [50.0],
    },
    matrix_regularization=1e-6,
)

key, pinn = pinn.project(key, space, N_EPOCHS, N_COLLOC, n_bc_colloc=500, n_ic_colloc=0)
best_space = pinn.space
loss_history = pinn.losses.losses_history

# ──────────────────────────────────────────────────────────────────────────────
# Results
# ──────────────────────────────────────────────────────────────────────────────

vars_best = best_space.create_variables()
nu_fn = vars_best[1]
# ScalarParam ignores its spatial input, so any point gives the learned ν.
nu_recovered = float(nu_fn(best_space, jnp.array([0.5, 0.5]), jnp.zeros((0,)))[0])

print(f"\nν true = {NU_TRUE}")
print(f"ν initial = {NU_INIT}")
print(f"ν recovered = {nu_recovered:.4f}")
print(f"Error = {abs(nu_recovered - NU_TRUE):.2e}")

# ── Plots ─────────────────────────────────────────────────────────────────────
u_fn_best = vars_best[0]
N_PLOT = 80
x_lin = jnp.linspace(0.0, 1.0, N_PLOT)
# Default "xy" indexing: rows follow y, columns follow x, which matches
# imshow(origin="lower") below.
X, Y = jnp.meshgrid(x_lin, x_lin)
xy_grid = jnp.stack([X.ravel(), Y.ravel()], axis=1)
mu_grid = jnp.zeros((N_PLOT**2, 0))

u_pred = jax.vmap(u_fn_best, in_axes=(None, 0, 0))(best_space, xy_grid, mu_grid)
u_pred = u_pred.reshape(N_PLOT, N_PLOT)
u_exact = jax.vmap(u_true)(xy_grid).reshape(N_PLOT, N_PLOT)
u_err = jnp.abs(u_pred - u_exact)

fig, axes = plt.subplots(1, 3, figsize=(13, 4))
kw = dict(cmap="turbo", origin="lower", extent=[0, 1, 0, 1])
# Exact and predicted fields share one colour range. With independent
# autoscaling, an amplitude error in u_θ would be invisible side by side.
# ``kw`` itself is left unchanged for the error panel.
u_vmin = float(jnp.minimum(u_exact.min(), u_pred.min()))
u_vmax = float(jnp.maximum(u_exact.max(), u_pred.max()))
kw_u = dict(kw, vmin=u_vmin, vmax=u_vmax)

im = axes[0].imshow(u_exact, **kw_u)
axes[0].set_title("u exact")
plt.colorbar(im, ax=axes[0])

im = axes[1].imshow(u_pred, **kw_u)
axes[1].set_title(f"u PINN (ν={nu_recovered:.3f})")
plt.colorbar(im, ax=axes[1])

# Third panel: pointwise absolute error |u_PINN - u_exact|.
im = axes[2].imshow(u_err, **kw)
axes[2].set_title("|error|")
plt.colorbar(im, ax=axes[2])

plt.suptitle(f"Helmholtz inverse — ν={nu_recovered:.4f} (true={NU_TRUE})")
plt.tight_layout()
plt.savefig("helmholtz_inverse_nu.png", dpi=120)
plt.show()

# ── Loss history: the total first, then every individual term (log scale) ────
fig2, ax = plt.subplots(figsize=(7, 4))
ax.semilogy(loss_history["total"], label="total")
individual_terms = [name for name in loss_history if name != "total"]
for term in individual_terms:
    ax.semilogy(loss_history[term], label=term, alpha=0.6)
ax.set(xlabel="epoch", ylabel="loss", title="Loss history — Helmholtz inverse")
ax.legend(fontsize=8)
plt.tight_layout()
plt.savefig("helmholtz_inverse_loss.png", dpi=120)
plt.show()