r"""Stationary Navier-Stokes flow around a cylinder in a rectangular channel.

Geometry: rectangle [0, 2.2] × [0, 0.41] with a circular obstacle centred at
(0.2, 0.2), radius 0.05.

Reference: https://arxiv.org/pdf/2402.03153

Boundary conditions are enforced as **hard constraints** via a post-processing
function applied to the network output:

* No-slip (south, north, cylinder): the raw velocity output is multiplied by
  (r−rc)·y·(H−y)·x, which vanishes on y=0, y=H, on the cylinder surface r=rc
  and at the inlet x=0.
* Inlet (west): the parabolic profile u_in(y) = 4·u0·y·(H−y)/H² is added to
  u_x, scaled by (r−rc)/(r_in−rc), where r_in is the distance from (0, y) to
  the cylinder centre. The scale equals 1 at x=0 and 0 on the cylinder, so
  u_x = u_in exactly at the inlet.
* Outlet (east): p multiplied by (L−x)

The interior Navier-Stokes residual is additionally multiplied by (r−rc)
(``WeightedNavierStokesResidual``), following the torch implementation.

A single ``MODE`` flag selects the approximation-space layout:

* ``"vec"`` — one network, output ``[u_x, u_y, p]``
* ``"field_scalar"`` — two networks: one field network for ``u``, one scalar
  network for ``p``

Uses :class:`NavierStokesND` from
``scimba_jax.physical_models.elliptic_pde.navier_stokes_stationary``.
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.domains_2d import Disk2D, Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.model_class.funcparam_vectorial import (
    ParamVecFunction,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.abstract_residuals import static
from scimba_jax.physical_models.elliptic_pde.navier_stokes_stationary import (
    NavierStokesND,
    NavierStokesResidual,
)
from scimba_jax.plots.plots_nd import plot_abstract_approx_space

# ── configuration ─────────────────────────────────────────────────────────────
# "vec"          → one network, output [u_x, u_y, p]
# "field_scalar" → two networks: field net for u, scalar net for p
# The same value is also passed as ``mode=`` to the physics model and residual.
MODE = "field_scalar"

# ── hyper-parameters ──────────────────────────────────────────────────────────
N_COLLOC = 7000  # number of interior collocation points per sample
N_EPOCHS = 1400

# channel geometry
H, L = 0.41, 2.2  # channel height, channel length
cx, cy = 0.2, 0.2  # cylinder centre
rc = 0.05  # cylinder radius
u0 = 0.3  # peak inflow velocity (reached at y = H/2)
# Kinematic viscosity. With mean inflow 2/3·u0 = 0.2 and diameter D = 0.1 this
# gives Re = 0.2·0.1/nu = 10.
# NOTE(review): the torch example is cited with param=[0.0002, 0.0002], i.e.
# 2e-4, which is ten times smaller than this value (not "the same order").
# Confirm which viscosity is intended.
nu = 2e-3

# Loss weights for the interior equations, ordered [NS_x, NS_y, div]; the
# divergence (incompressibility) equation is weighted 10x more.
WEIGHTS = {
    "interior": [1.0, 1.0, 10.0],
}


# ── weighted interior residual ────────────────────────────────────────────────
class WeightedNavierStokesResidual(NavierStokesResidual):
    """NS residual scaled by (r - rc) where r = dist(x, cylinder centre).

    Matches the torch implementation: each equation is returned as
    ``op * (r - rc)``, which is zero on the cylinder surface and grows away
    from it, improving the loss landscape near the obstacle.

    Args:
        domain: Domain the residual is attached to; forwarded to
            :class:`NavierStokesResidual`.
        cx: x-coordinate of the cylinder centre.
        cy: y-coordinate of the cylinder centre.
        rc_val: Cylinder radius.
        **kwargs: Forwarded unchanged to :class:`NavierStokesResidual`
            (this script passes ``dim``, ``nu`` and ``mode``).
    """

    # Declared through ``static``; the defaults mirror the module-level
    # geometry and are overwritten in ``__init__``.
    _wns_cx: float = static(cx)
    _wns_cy: float = static(cy)
    _wns_rc: float = static(rc)

    def __init__(self, domain, cx: float, cy: float, rc_val: float, **kwargs):
        super().__init__(domain=domain, **kwargs)
        self._wns_cx = cx
        self._wns_cy = cy
        self._wns_rc = rc_val

    def construct_residual(self, *args):
        """Build the parent residual and multiply each component by ``r - rc``.

        Args:
            *args: Forwarded unchanged to
                ``NavierStokesResidual.construct_residual``.

        Returns:
            A ``ParamVecFunction`` concatenating the ``self.size`` weighted
            components.
        """
        base_res = super().construct_residual(*args)
        # Copy to plain locals; presumably this keeps ``self`` out of the
        # closure below.
        cx_, cy_, rc_ = self._wns_cx, self._wns_cy, self._wns_rc

        def weight(x):
            # Distance to the cylinder surface: zero on r = rc. It is reshaped
            # to (1,), presumably to match one residual component's shape.
            r = jnp.sqrt((x[0] - cx_) ** 2 + (x[1] - cy_) ** 2)
            return jnp.reshape(r - rc_, (1,))

        return ParamVecFunction.cat(
            [weight * base_res.component(i) for i in range(self.size)]
        )


# ── hard-constraint post-processing ──────────────────────────────────────────
# Transforms raw network output so all BCs are satisfied exactly.
# Signature: post_processing(output, x) (JAX convention in scimba).
#
# u: zero on south/north (y=0, y=H), cylinder (r=rc); equals parabolic
#    profile uin at inlet (x=0).
# v: zero on south/north, cylinder, inlet.
# p: zero at outlet (x=L).
def post_processing_uv(output, x):
    """Hard-constraint post-processing for [u, v] (field or first-two components).

    Args:
        output: Raw network output; only ``output[0]`` and ``output[1]`` are read.
        x: A single point; ``x[0]`` is the abscissa and ``x[1]`` the ordinate.

    Returns:
        ``jnp.array([u, v])``: zero on y=0, y=H and the cylinder; ``u`` equals
        the parabolic inflow and ``v`` is zero at x=0.
    """
    x_, y_ = x[0], x[1]
    r = jnp.sqrt((x_ - cx) ** 2 + (y_ - cy) ** 2)
    r_inlet = jnp.sqrt((y_ - cy) ** 2 + cx**2)  # dist from (0, y) to center
    # Parabolic inflow profile: 0 at the walls, u0 at y = H/2.
    uin = (4.0 / H**2) * y_ * (H - y_) * u0
    # The first term vanishes on the walls, the cylinder and the inlet. The lift
    # term is scaled by (r - rc)/(r_inlet - rc): 1 at x=0 (where r == r_inlet),
    # 0 on the cylinder. r_inlet >= cx > rc, so the denominator never vanishes.
    u = output[0] * (r - rc) * y_ * (H - y_) * x_ + uin * ((r - rc) / (r_inlet - rc))
    v = output[1] * (r - rc) * y_ * (H - y_) * x_
    return jnp.array([u, v])


def post_processing_p(output, x):
    """Hard-constraint post-processing for p (scalar network).

    Multiplies the raw pressure by ``(L - x)`` so that p = 0 at the outlet.
    """
    return jnp.array([output[0] * (L - x[0])])


def post_processing_vec(output, x):
    """Hard-constraint post-processing for [u, v, p] (vec network)."""
    uv = post_processing_uv(output[:2], x)
    p = post_processing_p(output[2:3], x)
    return jnp.concatenate([uv, p])


# ── domain ────────────────────────────────────────────────────────────────────
key = jax.random.PRNGKey(0)
box = [(0.0, L), (0.0, H)]
dx = Square2D(box, is_main_domain=True)
hole = Disk2D((cx, cy), rc, is_main_domain=False)
dx.add_hole(hole)

# ── physics model ─────────────────────────────────────────────────────────────
# Fail fast on a typo: an unknown MODE would otherwise fall into the
# "field_scalar" branch below.
if MODE not in ("vec", "field_scalar"):
    raise ValueError(f"MODE must be 'vec' or 'field_scalar', got {MODE!r}")

model = NavierStokesND(dx, dim=2, nu=nu, mode=MODE, bc_residuals={})
# Replace the default interior residual with the (r - rc)-weighted one.
model.physical_residuals[dx.get_label()] = WeightedNavierStokesResidual(
    domain=dx, cx=cx, cy=cy, rc_val=rc, dim=2, nu=nu, mode=MODE
)

# ── approximation space (depends on MODE) ─────────────────────────────────────
# Every network gets its own PRNG subkey. Reusing one key would correlate the
# initialisations, and `key` is consumed again by the sampler below.
if MODE == "vec":
    # One network → output [u_x, u_y, p]
    key, init_key = jax.random.split(key)
    nn = MLP(in_size=2, out_size=3, hidden_sizes=[18, 18, 18], key=init_key)
    space = ApproximationSpace(
        {"x": 2},
        [(nn, "vec", 3)],
        model_type="x",
        post_processing=post_processing_vec,
    )
else:  # "field_scalar" (validated above)
    # Two networks: field net for u = (u_x, u_y), scalar net for p
    key, uv_key, p_key = jax.random.split(key, 3)
    nn_uv = MLP(in_size=2, out_size=2, hidden_sizes=[14, 14, 14, 14], key=uv_key)
    nn_p = MLP(in_size=2, out_size=1, hidden_sizes=[12, 12], key=p_key)
    space = ApproximationSpace(
        {"x": 2},
        [(nn_uv, "field", 2), (nn_p, "scalar", 1)],
        model_type="x",
        post_processing=[post_processing_uv, post_processing_p],
    )

# ── sampler ───────────────────────────────────────────────────────────────────
# Interior points only (bc=False): all BCs are imposed by post-processing.
sampler = TensorizedSampler(
    [DomainSampler(dx)],
    bc=False,
)

# ── training ──────────────────────────────────────────────────────────────────
pinn = Projector(model, space, sampler, weights=WEIGHTS, one_loss_by_equation=True)

key, sample_dict = sampler.sample(key, N_COLLOC, 0)
loss0 = pinn.evaluate_loss(space, sample_dict)
print("initial loss:", loss0)

start = timeit.default_timer()
key, pinn = pinn.project(
    key,
    space,
    N_EPOCHS,
    N_COLLOC,
    0,
)
end = timeit.default_timer()

print(f"best loss: {pinn.best_loss}")
print(f"time for {N_EPOCHS} epochs: {end - start:.1f}s")

# ── visualisation ─────────────────────────────────────────────────────────────
plot_abstract_approx_space(
    pinn.space,
    dx,
    components=[0, 1, 2],
    loss=pinn.losses,
    residual=pinn.model,
    draw_contours=True,
    n_drawn_contours=20,
)
plt.show()