r"""Density-to-density optimal transport via a Monge-Ampère PINN.

Source: :math:`f(x) = 1` (uniform on :math:`[0,1]^2`).

Target (ring density centred at :math:`(0.5, 0.5)`, peaking on the circle of
radius 0.3):

.. math::

    g(y) = \frac{1}{Z_g}\Bigl(1 + 5\exp\!\bigl(-100\,\bigl|(y_0-0.5)^2+(y_1-0.5)^2-0.09\bigr|\bigr)\Bigr)

where :math:`Z_g` is a Monte Carlo estimate of the normalising constant, so that
:math:`\int_{[0,1]^2} g = 1 = \int_{[0,1]^2} f`.

Network: ICNN :math:`u: [0,1]^2\to\mathbb{R}` (convex by construction).
OT map (Brenier): :math:`T = \nabla u`.
PDE: :math:`g(\nabla u)\,\det(\nabla^2 u) = f`.

Option ``WITH_FLOW = True``: an invertible coupling network :math:`T_\mathrm{flow}`
is trained jointly with :math:`u`. The interior residual is extended with the
coupling terms :math:`T_\mathrm{flow} - \nabla u` (weighted by ``LAM_FLOW``).
The images of a uniform grid under :math:`\nabla u` and :math:`T_\mathrm{flow}`
are overlaid in the same panel.
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
    UniformParametricSampler,
)
from scimba_jax.nonlinear_approximation.networks.icnn import ICNN
from scimba_jax.nonlinear_approximation.networks.structure_preserving_nets.affine_ode_layers import (
    AffineFlowLayer,
)
from scimba_jax.nonlinear_approximation.networks.structure_preserving_nets.coupling_layers import (
    CouplingLayer,
)
from scimba_jax.nonlinear_approximation.networks.structure_preserving_nets.invertible_nn import (
    InvertibleNet,
)
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.elliptic_pde.monge_ampere import MongeAmpere2D
from scimba_jax.physical_models.function_approximator.function_approximator import (
    FunctionApproximator,
)
from scimba_jax.plots.plots_nd import plot_abstract_approx_space

# ── Configuration ──────────────────────────────────────────────────────────────
# True: also train an invertible T_flow jointly with u (coupled to ∇u through
# LAM_FLOW-weighted residual terms) and overlay both grid images.
WITH_FLOW = True
N_COLLOC = 8000  # number of interior collocation points passed to the samplers
N_BC_COLLOC = 5000  # number of boundary collocation points for the MA training
N_EPOCHS = 400  # Monge-Ampère training epochs
N_PRETRAIN_EPOCHS = 10  # epochs of the u ≈ ½|x|² pre-training
# NOTE(review): N_FLOW_EPOCHS is not referenced in the visible code. T_flow is
# trained jointly with u during the MA training below. Kept for compatibility.
N_FLOW_EPOCHS = 200


# ── Densities ──────────────────────────────────────────────────────────────────
def f(x):
    """Uniform source density on [0,1]² (returns an array of shape (1,))."""
    return jnp.array([1.0])


def _g_unnorm(y):
    """Unnormalised ring target density ``1 + 5·exp(-100·|r² - 0.09|)``.

    ``r²`` is the squared distance from ``y`` to (0.5, 0.5), so the bump peaks
    on the circle of radius 0.3. Returns an array of shape (1,).
    """
    r2 = (y[0] - 0.5) ** 2 + (y[1] - 0.5) ** 2
    return jnp.array([1.0 + 5.0 * jnp.exp(-100.0 * jnp.abs(r2 - 0.09))])


# Monte Carlo estimate of Z_g = ∫_{[0,1]²} g_unnorm. The unit square has area 1,
# so the integral equals the sample mean over uniform points.
_key_zg = jax.random.PRNGKey(42)
_pts_zg = jax.random.uniform(_key_zg, (200_000, 2))
Z_g = float(jnp.mean(jax.vmap(_g_unnorm)(_pts_zg)))


def g(y):
    """Normalised ring target density, so that ∫ g = 1 = ∫ f on [0,1]²."""
    return _g_unnorm(y) / Z_g


# ── Domain ─────────────────────────────────────────────────────────────────────
key = jax.random.PRNGKey(0)
domain_mu: list = []  # no physical parameters mu
dx = Square2D([(0.0, 1.0), (0.0, 1.0)], is_main_domain=True)  # transport domain
# Larger square used for the u ≈ ½|x|² pre-training below.
dx2 = Square2D([(-2.0, 2.0), (-2.0, 2.0)], is_main_domain=True)

# ── Physical model ─────────────────────────────────────────────────────────────
model = MongeAmpere2D(dx, f=f, g=g, model_type="x_mu", bc="weak", with_flow=WITH_FLOW)

# ── Networks ───────────────────────────────────────────────────────────────────
key, subkey = jax.random.split(key)
nn = ICNN(
    in_size=2, out_size=1, hidden_sizes=[16, 16, 16], activation="softplus", key=subkey
)

if WITH_FLOW:
    # Invertible map R² → R² made of 3 CouplingLayer blocks (AffineFlowLayer).
    # The key is split into 5 (4 subkeys, 3 used); the count is kept unchanged
    # so that the random initialisation is reproducible.
    key, *subkeys_flow = jax.random.split(key, 5)
    nn_flow = InvertibleNet(
        size=2,
        conditional_size=0,
        layers_list=[
            CouplingLayer(
                size=2,
                conditional_size=0,
                num_splits=2,
                ode_layer_type=AffineFlowLayer,
                hidden_sizes=[12],
                activation="tanh",
                key=subkeys_flow[i],
            )
            for i in range(3)
        ],
    )

# ── ICNN pre-training: u ≈ ½|x|² (∇u ≈ identity) on a u-only space ────────────
space_u = ApproximationSpace({"x": 2}, [(nn, "scalar", 1)], model_type="x_mu")

print("Pré-entraînement u ≈ ½|x|² ...")
_, space_u, key, _ = Projector(
    FunctionApproximator(
        main_domain=dx2,
        size=1,
        model_type="x_mu",
        f_rhs=lambda *args: jnp.array([0.5 * jnp.dot(args[0], args[0])]),
    ),
    space_u,
    TensorizedSampler(
        [DomainSampler(dx2), UniformParametricSampler(domain_mu)], bc=False
    ),
).project(space_u, key, N_PRETRAIN_EPOCHS, N_COLLOC)
print("Pré-entraînement terminé.")

plot_abstract_approx_space(
    space_u,
    dx2,
    parameters_domain=[],
    solution=lambda x, __mu: 0.5 * jnp.sum(x**2, axis=-1, keepdims=True),
    derivatives=["ux", "uy"],
    draw_contours=True,
    title=r"Pré-entraînement : $u \approx \frac{1}{2}|x|^2$",
)
plt.show()

# ── Main approximation space ───────────────────────────────────────────────────
trained_nn = space_u.models[0]  # the pre-trained ICNN
if WITH_FLOW:
    space = ApproximationSpace(
        {"x": 2}, [(trained_nn, "scalar", 1), (nn_flow, "field", 2)], model_type="x_mu"
    )
else:
    space = space_u

# ── Sampler ────────────────────────────────────────────────────────────────────
sampler = TensorizedSampler(
    [DomainSampler(dx), UniformParametricSampler(domain_mu)], bc=True
)

# ── Projector & MA training ────────────────────────────────────────────────────
LAM_FLOW = 0.5  # weight of the coupling term T = ∇u (only used if WITH_FLOW)
if WITH_FLOW:
    # Interior residual has size 3: [MA, T₀ - ∂u/∂x₀, T₁ - ∂u/∂x₁]
    weights = {"interior": [1.0, LAM_FLOW, LAM_FLOW], "boundary": [6.0]}
else:
    weights = {"interior": [1.0], "boundary": [6.0]}

pinn = Projector(model, space, sampler, weights=weights)

key, sample_dict = sampler.sample(key, N_COLLOC, N_BC_COLLOC)
print(f"initial loss: {pinn.evaluate_loss(space, sample_dict):.4e}")

start = timeit.default_timer()
new_loss, nspace, key, loss_history = pinn.project(
    space, key, N_EPOCHS, N_COLLOC, N_BC_COLLOC
)
end = timeit.default_timer()

# Store the training results on the projector; the post-processing reads them
# through pinn.space and pinn.losses.loss_history.
pinn.losses.loss_history = loss_history
pinn.best_loss = new_loss
pinn.space = nspace
print(f"best loss: {new_loss['total']:.4e} | training time: {end - start:.1f}s")

# ── Evaluate ∇u and T on a grid ────────────────────────────────────────────────
N_GRID = 100
xl = np.linspace(0.0, 1.0, N_GRID)
X0, X1 = np.meshgrid(xl, xl)
xy_flat = jnp.array(np.stack([X0.ravel(), X1.ravel()], axis=1))
mut = jnp.zeros((xy_flat.shape[0], 0))  # empty parameter array (no mu)
variables = pinn.space.create_variables()
u_var = variables[0]
grad_u = u_var.grad_x()
T_MA = grad_u.vmap_on_physical_variables()  # batched evaluator of x ↦ ∇u(x)
T_vals = jax.device_get(T_MA(pinn.space, xy_flat, mut))
TX = T_vals[:, 0].reshape(N_GRID, N_GRID)
TY = T_vals[:, 1].reshape(N_GRID, N_GRID)
# Evaluate g on all grid points in a single batched call (instead of one
# un-jitted JAX call per point). xy_flat was built from (X0.ravel(), X1.ravel()),
# so reshaping restores the meshgrid layout.
G_grid = np.asarray(jax.vmap(g)(xy_flat)[:, 0], dtype=float).reshape(N_GRID, N_GRID)

if WITH_FLOW:
    T_flow_func = variables[1]  # invertible coupling net trained jointly with u
    # Build the batched evaluator once and reuse it for the grid images below.
    T_flow_map = T_flow_func.vmap_on_physical_variables()
    T_flow_vals = jax.device_get(T_flow_map(pinn.space, xy_flat, mut))
    TX_flow = T_flow_vals[:, 0].reshape(N_GRID, N_GRID)
    TY_flow = T_flow_vals[:, 1].reshape(N_GRID, N_GRID)

# ── Plots ──────────────────────────────────────────────────────────────────────
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
fig.suptitle("Density transport [Monge-Ampère]", fontsize=13)

# 1. Source f(x) = 1
ax = axes[0, 0]
ax.contourf(X0, X1, np.ones_like(X0), levels=5, cmap="Blues")
ax.set_title(r"Source $f(x)=1$")
ax.set_aspect("equal")
ax.set_xlabel(r"$x_0$")
ax.set_ylabel(r"$x_1$")

# 2. Target g
ax = axes[0, 1]
cf = ax.contourf(X0, X1, G_grid, levels=30, cmap="Reds")
ax.contour(X0, X1, G_grid, levels=10, colors="k", linewidths=0.4)
fig.colorbar(cf, ax=ax)
ax.set_title(r"Target $g(y)$")
ax.set_aspect("equal")
ax.set_xlabel(r"$y_0$")
ax.set_ylabel(r"$y_1$")

# 3. Image of the grid under ∇u (and under T_flow if WITH_FLOW)
ax = axes[0, 2]
_n = 50  # number of grid lines per direction
_n_pts = 300  # points sampled along each grid line
_t = jnp.linspace(0.0, 1.0, _n_pts)
_mu_l = jnp.zeros((_n_pts, 0))
for _c in np.linspace(0.0, 1.0, _n):
    # Reference (undeformed) grid lines in light grey.
    ax.plot([_c, _c], [0.0, 1.0], color="0.82", lw=0.4, zorder=1)
    ax.plot([0.0, 1.0], [_c, _c], color="0.82", lw=0.4, zorder=1)
    _cst = jnp.full(_n_pts, float(_c))
    _vert = jnp.stack([_cst, _t], axis=1)  # points on the line x0 = c
    _horiz = jnp.stack([_t, _cst], axis=1)  # points on the line x1 = c
    # ∇u grid (blue / red)
    _Tl = jax.device_get(T_MA(pinn.space, _vert, _mu_l))
    ax.plot(_Tl[:, 0], _Tl[:, 1], "b-", lw=0.8, zorder=2)
    _Tl = jax.device_get(T_MA(pinn.space, _horiz, _mu_l))
    ax.plot(_Tl[:, 0], _Tl[:, 1], "r-", lw=0.8, zorder=2)
    if WITH_FLOW:
        # T_flow grid (green / orange)
        _Tl = jax.device_get(T_flow_map(pinn.space, _vert, _mu_l))
        ax.plot(_Tl[:, 0], _Tl[:, 1], color="green", lw=0.6, alpha=0.6, zorder=3)
        _Tl = jax.device_get(T_flow_map(pinn.space, _horiz, _mu_l))
        ax.plot(_Tl[:, 0], _Tl[:, 1], color="orange", lw=0.6, alpha=0.6, zorder=3)
ax.set_xlim(-0.02, 1.02)
ax.set_ylim(-0.02, 1.02)
_title = r"$\nabla u$ (bleu/rouge)" + (
    r" + $T_\mathrm{flow}$ (vert/orange)" if WITH_FLOW else ""
)
ax.set_title(_title)
ax.set_aspect("equal")
ax.set_xlabel(r"$y_0$")
ax.set_ylabel(r"$y_1$")

# 4. Displacement T₀(x) − x₀ (symmetric colour scale around 0)
D0 = TX - X0
ax = axes[1, 0]
vmax0 = float(np.abs(D0).max())
cf = ax.pcolormesh(
    X0, X1, D0, cmap="RdBu_r", vmin=-vmax0, vmax=vmax0, shading="gouraud"
)
fig.colorbar(cf, ax=ax)
ax.set_title(r"$T_0(x) - x_0$ ($= \partial u/\partial x_0 - x_0$)")
ax.set_aspect("equal")
ax.set_xlabel(r"$x_0$")
ax.set_ylabel(r"$x_1$")

# 5. Displacement T₁(x) − x₁ (symmetric colour scale around 0)
D1 = TY - X1
ax = axes[1, 1]
vmax1 = float(np.abs(D1).max())
cf = ax.pcolormesh(
    X0, X1, D1, cmap="RdBu_r", vmin=-vmax1, vmax=vmax1, shading="gouraud"
)
fig.colorbar(cf, ax=ax)
ax.set_title(r"$T_1(x) - x_1$ ($= \partial u/\partial x_1 - x_1$)")
ax.set_aspect("equal")
ax.set_xlabel(r"$x_0$")
ax.set_ylabel(r"$x_1$")

# 6. Training loss: one curve per loss component.
ax = axes[1, 2]
for label, hist in pinn.losses.loss_history.items():
    hist = np.asarray(hist)
    # 1-D histories have no component axis; guard before reading shape[1].
    if label == "total" or hist.ndim == 1 or hist.shape[1] == 1:
        ax.semilogy(hist, label=label)
    else:
        for i in range(hist.shape[1]):
            ax.semilogy(hist[:, i], label=f"{label}[{i}]")
ax.set_title("Training loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.legend(fontsize=7)
ax.grid(True, which="both", alpha=0.4)

plt.tight_layout()
plt.show()