r"""Solve the 2D incompressible Navier-Stokes equations with a PINN.

The setting is described in https://arxiv.org/pdf/2402.03153.

As configured below:

- Domain: the rectangle [-2.5, 7.5] x [-2.5, 2.5] with the unit disk centered at
  the origin removed, for t in [0, 10].
- Viscosity: treated as a parameter mu in [0.09, 0.11].
- Unknowns: the velocity (u, v) and the pressure p.
- Boundary conditions: velocity is prescribed on the west boundary, the
  north/south walls and the disk; pressure is prescribed on the east boundary.

The problem is solved twice, first with the Projector's default optimizer
(reported as Energy Natural Gradient) and then with SS-BFGS. Each solution is
plotted.
"""

import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

from scimba_jax.domains.meshless_domains.base import SurfacicDomain, VolumetricDomain
from scimba_jax.domains.meshless_domains.domains_2d import Disk2D, Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
    UniformParametricSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.model_class.funcparam_vectorial import (
    # ParamFieldFunction,
    ParamScalarFunction,
    ParamVecFunction,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.abstract_physical_model import (
    AbstractPhysicalModel,
)
from scimba_jax.physical_models.abstract_residuals import (
    NDARRAYS_FUNC_TYPE,
    BoundaryResidual,
    # DataResidual,
    InteriorResidual,
)
from scimba_jax.physical_models.initial_residuals import (
    ProjectionResidual,
)
from scimba_jax.plots.plots_nd import (
    plot_abstract_approx_space,
)

# Sampling and training budgets.
N_COLLOC = 8000
N_BC_COLLOC = 8000
N_IC_COLLOC = 8000
N_EPOCHS = 150
N_EPOCHS_SSBFGS = 1000


# The interior residual
class NavierStokes2DResidual(InteriorResidual):
    """Interior residual of the parametric, time-dependent 2D Navier-Stokes system.

    For the unknown (u, v, p) and viscosity nu = mu, the three components are::

        u_x + v_y                                      (incompressibility)
        u_t + u u_x + v u_y + p_x - nu (u_xx + u_yy)   (x-momentum)
        v_t + u v_x + v v_y + p_y - nu (v_xx + v_yy)   (y-momentum)

    ``f_rhs`` is forwarded to ``InteriorResidual`` as the right-hand side.
    """

    def __init__(self, domain, time_domain, f_rhs):
        super().__init__(
            domain=domain,
            size=3,
            model_type="t_x_mu",
            f_rhs=f_rhs,
            time_domain=time_domain,
        )
        assert self.dim == 3, "This class is designed for 2D Navier-Stokes equations."

    def construct_residual(self, *funcs):
        """Return the residual operator applied to the unknown ``funcs[0] = (u, v, p)``."""
        uvp = funcs[0]
        u = uvp.component(0)
        v = uvp.component(1)
        p = uvp.component(2)
        assert isinstance(u, ParamScalarFunction)
        assert isinstance(v, ParamScalarFunction)
        assert isinstance(p, ParamScalarFunction)

        # Time operator. The incompressibility row has no time derivative,
        # so it gets an identically-zero entry.
        u_t = u.d_t()
        v_t = v.d_t()
        zero = ParamScalarFunction(
            p.dims, lambda model, *args: jnp.zeros((1,)), p.f_type
        )
        time_op = ParamVecFunction.cat((zero, u_t, v_t))

        # Space operator. The components of gradient_x() are used below as
        # (d/dx, d/dy).
        grad_u = u.gradient_x()
        grad_v = v.gradient_x()

        # Incompressibility condition: du/dx + dv/dy.
        op1 = grad_u.component(0) + grad_v.component(1)

        # Momentum equations: convection + pressure gradient - viscous diffusion.
        grad_p = p.gradient_x()
        lapl_u = u.laplacian_x()
        lapl_v = v.laplacian_x()

        # The viscosity is the physical parameter mu itself.
        # NOTE(review): this plain (t, x, mu) callable is multiplied with a
        # ParamScalarFunction below; presumably the library lifts it
        # pointwise. Confirm.
        def nu(t, x, mu):
            return mu

        op2a = (
            u * grad_u.component(0)
            + v * grad_u.component(1)
            + grad_p.component(0)
            - nu * lapl_u
        )
        op2b = (
            u * grad_v.component(0)
            + v * grad_v.component(1)
            + grad_p.component(1)
            - nu * lapl_v
        )
        space_op = ParamVecFunction.cat((op1, op2a, op2b))

        return time_op + space_op


# The boundary residual: Dirichlet on a tuple of components
class RestrainedDirichletResidual(BoundaryResidual):
    """Dirichlet boundary residual acting on a subset of the unknown's components.

    Args:
        domain: boundary domain(s) on which the condition is imposed (annotated
            as a mapping of labels to surfacic domains).
        components: indices of the prescribed components of the unknown; the
            residual has ``len(components)`` entries.
        model_type: input layout of the model, e.g. ``"t_x_mu"``.
        f_rhs: prescribed values of the selected components.
        time_domain: time interval, forwarded to ``BoundaryResidual``.
    """

    def __init__(
        self,
        domain: dict[str, SurfacicDomain],
        components: tuple[int, ...],
        model_type: str = "t_x_mu",
        f_rhs: NDARRAYS_FUNC_TYPE | None = None,
        time_domain: tuple[float, ...] | None = None,
    ):
        super().__init__(
            domain=domain,
            size=len(components),
            model_type=model_type,
            f_rhs=f_rhs,
            time_domain=time_domain,
        )
        self.components = components

    def construct_residual(self, *funcs):
        """Return the selected components of the unknown ``funcs[0]``, concatenated."""
        # Here the variables are assumed to also be functions of the boundary
        # normals.
        uvp = funcs[0]
        return ParamVecFunction.cat(
            tuple(uvp.component(comp) for comp in self.components)
        )


# The model
class NavierStokes2DEquation(AbstractPhysicalModel):
    """All residuals of the Navier-Stokes problem.

    Residuals are registered in this order:

    1. the interior Navier-Stokes residual, keyed by the main domain's label;
    2. the Dirichlet conditions ``"bc in"``, ``"bc res"`` and ``"bc out"``;
    3. the initial-condition projection ``"ic <main domain label>"``, imposed
       at ``t = time_domain[0]``.

    Args:
        main_domain: the spatial domain. ``self.boundaries``, which is set up by
            the base class, must contain the labels ``"bc in"``, ``"bc res"``
            and ``"bc out"``.
        time_domain: ``(t_min, t_max)``.
        f_bc_rhs: prescribed boundary values, keyed by the same labels.
        f_rhs: right-hand side of the interior equations.
        f_ic_rhs: target of the initial-condition projection.
    """

    def __init__(
        self,
        main_domain: VolumetricDomain,
        time_domain: tuple[float, float],
        f_bc_rhs: dict[str, NDARRAYS_FUNC_TYPE],
        f_rhs: NDARRAYS_FUNC_TYPE | None = None,
        f_ic_rhs: NDARRAYS_FUNC_TYPE | None = None,
    ):
        super().__init__(main_domain=main_domain, time_domain=time_domain)
        assert self.main_domain is not None

        self.physical_residuals = {
            self.main_domain.get_label(): NavierStokes2DResidual(
                domain=main_domain,
                time_domain=time_domain,
                f_rhs=f_rhs,
            ),
        }

        # Prescribed components per boundary label: (u, v) on "bc in" and
        # "bc res", and p on "bc out".
        bc_components = {
            "bc in": (0, 1),
            "bc res": (0, 1),
            "bc out": (2,),
        }
        for label, components in bc_components.items():
            self.physical_residuals[label] = RestrainedDirichletResidual(
                domain=self.boundaries[label],
                components=components,
                model_type="t_x_mu",
                f_rhs=f_bc_rhs[label],
                time_domain=time_domain,
            )

        ic_label = "ic " + self.main_domain.get_label()
        self.physical_residuals[ic_label] = ProjectionResidual(
            domain=main_domain,
            time_domain=(self.time_domain[0],),
            size=3,
            model_type="t_x_mu",
            f_rhs=f_ic_rhs,
        )


# Boundary data: unit inflow velocity on "bc in", zero velocity (no-slip) on
# "bc res", and zero pressure on "bc out".
f_bc_rhs = {
    "bc in": lambda *args: jnp.array([1.0, 0.0]),
    "bc res": lambda *args: jnp.array([0.0, 0.0]),
    "bc out": lambda *args: jnp.array([0.0]),
}

t_min, t_max = 0.0, 10.0
domain_t = (t_min, t_max)
box = [(-2.5, 7.5), (-2.5, 2.5)]
center = (0.0, 0.0)
radius = 1.0
domain_mu = [(0.09, 0.11)]  # range of the viscosity parameter mu

# Rectangle with a disk-shaped obstacle removed.
domain_x = Square2D(box, is_main_domain=True)
hole = Disk2D(center, radius, is_main_domain=False, label_str="hole")
domain_x.add_hole(hole)
domain_x.set_boundaries_dict(
    {
        "bc in": ["bc west"],
        "bc res": ["bc south", "bc north", "bc hole circle"],
        "bc out": ["bc east"],
    }
)

sampler = TensorizedSampler(
    [
        UniformTimeSampler(domain_t),
        DomainSampler(domain_x),
        UniformParametricSampler(domain_mu),
    ],
    model_type="t_x_mu",
    bc=True,
    ic=True,
)

# JAX keys must not be reused for independent random draws. Derive separate
# keys for network initialisation and training, and keep `key` for sampling.
key = jax.random.PRNGKey(0)
key, init_key, train_key = jax.random.split(key, 3)

# sample the domain
key, sample_dict = sampler.sample(key, N_COLLOC)


def make_space(nn_key: jax.Array) -> ApproximationSpace:
    """Return a (t, x, mu) -> (u, v, p) approximation space backed by a new MLP."""
    nn = MLP(in_size=4, out_size=3, hidden_sizes=[16, 32, 16], key=nn_key)
    return ApproximationSpace(
        {"t": 1, "x": 2, "mu": 1}, [(nn, "vec", 3)], model_type="t_x_mu"
    )


def construct_vorticity(*funcs):
    """Build the scalar vorticity dv/dx - du/dy of the velocity (u, v)."""
    uvp = funcs[0]
    u = uvp.component(0)
    v = uvp.component(1)
    assert isinstance(u, ParamScalarFunction)
    assert isinstance(v, ParamScalarFunction)
    grad_u = u.gradient_x()
    grad_v = v.gradient_x()
    # The components of gradient_x() are (d/dx, d/dy), matching their use in
    # the incompressibility term du/dx + dv/dy.
    return grad_v.component(0) - grad_u.component(1)


def construct_norm_of_speed(*funcs):
    """Build the speed sqrt(u**2 + v**2) of the velocity (u, v)."""
    uvp = funcs[0]
    u = uvp.component(0)
    v = uvp.component(1)
    assert isinstance(u, ParamScalarFunction)
    assert isinstance(v, ParamScalarFunction)
    sqr_norm = u**2 + v**2
    return ParamScalarFunction(
        sqr_norm.dims,
        lambda model, *args: jnp.sqrt(sqr_norm(model, *args)),
        sqr_norm.f_type,
    )


# Derived scalar fields drawn next to u, v and p (keys are the plot labels).
additional_scalar_functions = {
    "$\\partial_x v - \\partial_y u$": construct_vorticity,
    "$\\|(u,v)\\|_2$": construct_norm_of_speed,
}


def plot_solution(projector):
    """Plot u, v, p, the derived scalar fields and the losses, then show the figure."""
    plot_abstract_approx_space(
        projector.space,  # the approximation space
        domain_x,  # the spatial domain
        domain_mu,  # the parameter's domain
        domain_t,
        components=[0, 1, 2],
        loss=projector.losses,  # for plot of the loss: the losses
        residual=projector.model,
        draw_contours=True,
        n_drawn_contours=20,
        parameters_values="mean",
        loss_groups=["bc", "interior"],
        additional_scalar_functions=additional_scalar_functions,
    )
    plt.show()


print("\n\n")
print("@@@@@@@@@@@@@@@ create a pinn with weak bc @@@@@@@@@@@@@@@@@@@@@")

space = make_space(init_key)
model = NavierStokes2DEquation(domain_x, domain_t, f_bc_rhs=f_bc_rhs)

# Loss weights per residual; each list has one entry per residual component
# (3 interior, 2 + 2 velocity BCs, 1 pressure BC, 3 initial).
custom_weights = {
    "interior": [10.0, 10.0, 1.0],
    "bc in": [10.0, 10.0],
    "bc res": [10.0, 10.0],
    "bc out": [10.0],
    "ic interior": [10.0, 10.0, 1.0],
}

pinn = Projector(model, space, sampler, weights=custom_weights)

loss = pinn.evaluate_loss(space, sample_dict)
print("initial loss: ", loss)
losses = pinn.evaluate_losses(space, sample_dict)
print("initial losses: ", losses)

print("\n\n")
print("@@@@@@@@@@@@@@@ train with Energy Natural Gradient @@@@@@@@@@@@@@@@@@@@@")

start = timeit.default_timer()
key, pinn = pinn.project(
    train_key, space, N_EPOCHS, N_COLLOC, N_BC_COLLOC, N_IC_COLLOC
)
end = timeit.default_timer()

print("best loss: ", pinn.best_loss)
print(f"time for {N_EPOCHS} epochs: ", end - start)

plot_solution(pinn)

print("\n\n")
print("@@@@@@@@@@@@@@@ train with ss-bfgs @@@@@@@@@@@@@@@@@@@@@@")

# Deliberately reuse the first run's initialisation and training keys, so that
# both optimizers start from the same network weights and the same training key.
space2 = make_space(init_key)
pinn2 = Projector(model, space2, sampler, weights=custom_weights, optimizer="SS-BFGS")

start = timeit.default_timer()
key, pinn2 = pinn2.project(
    train_key, space2, N_EPOCHS_SSBFGS, N_COLLOC, N_BC_COLLOC, N_IC_COLLOC
)
end = timeit.default_timer()

print("best loss: ", pinn2.best_loss)
print(f"time for {N_EPOCHS_SSBFGS} epochs: ", end - start)

plot_solution(pinn2)