# %%
"""Rotating-transport PINN experiment.

Defines the interior / initial-condition residuals of a transport equation
driven by a rotating velocity field, a physical model that selects among them,
helpers to build the networks and plot the results, and a script section that
trains one projector per time interval.
"""

import copy
import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scimba_jax.domains.meshless_domains.base import VolumetricDomain
from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.densityflow_approximation_spaces import (
    DensityFlowApproximationSpace,
    DensityFlowInvertibleApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
    UniformTimeSampler,
)
from scimba_jax.nonlinear_approximation.model_class.funcparam_vectorial import (
    ParamScalarFunction,
    ParamVecFunction,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.networks.structure_preserving_nets.affine_ode_layers import (  # noqa: E501
    AffineFlowLayer,
)
from scimba_jax.nonlinear_approximation.networks.structure_preserving_nets.coupling_layers import (  # noqa: E501
    CouplingLayer,
)
from scimba_jax.nonlinear_approximation.networks.structure_preserving_nets.invertible_nn import (  # noqa: E501
    InvertibleNet,
)
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.physical_models.abstract_physical_model import (
    PHYSICAL_RESIDUALS_TYPE,
    AbstractPhysicalModel,
)
from scimba_jax.physical_models.abstract_residuals import (
    DOMAIN_TYPE,
    NDARRAYS_FUNC_TYPE,
    PARAM_FUNC_TYPE,
    InitialResidual,
    InteriorResidual,
)


def _as_time_tuple(time_domain):
    """Wrap a bare number into a 1-tuple; return any other value unchanged."""
    if isinstance(time_domain, (int, float)):
        return (time_domain,)
    return time_domain


class _RotatingTransportResidual(InteriorResidual):
    """Shared constructor and velocity field of the rotating-transport residuals.

    Subclasses set ``_residual_size`` (number of residual components) and
    implement ``construct_residual``.
    """

    _residual_size: int = 1

    def __init__(
        self,
        domain: VolumetricDomain,
        time_domain: tuple[float, float],
        f_rhs: NDARRAYS_FUNC_TYPE | None = None,
        Tf: float = 1.0,
    ):
        """Initialise the residual.

        Args:
            domain: Spatial domain, forwarded to ``InteriorResidual``.
            time_domain: Time interval, forwarded to ``InteriorResidual``.
            f_rhs: Optional right-hand side, forwarded to ``InteriorResidual``.
            Tf: Stored as ``self.Tf``; not read anywhere else in this file.
        """
        super().__init__(
            domain=domain,
            size=self._residual_size,
            model_type="x_t",
            f_rhs=f_rhs,
            time_domain=time_domain,
        )
        self.Tf = Tf

    def velocity(self, x):
        """Return the velocity ``(-2*pi*x[1], 2*pi*x[0])`` at the point ``x``."""
        v1 = -2.0 * jnp.pi * x[1]
        v2 = 2.0 * jnp.pi * x[0]
        return jnp.array([v1, v2])

    def _transport(self, func):
        """Return ``func.d_t() + func.gradient_x().dot(self.velocity)``."""
        # presumably d_t u + v . grad_x u, the transport operator — exact
        # semantics of `.dot` with a callable are defined by the library.
        return func.d_t() + func.gradient_x().dot(self.velocity)

    def _flow_residual(self, inv_phi):
        """Apply the transport operator to both components of ``inv_phi``."""
        phi1, phi2 = inv_phi.components()
        return ParamVecFunction.cat([self._transport(phi1), self._transport(phi2)])

    def _density_residual(self, rho):
        """Apply the transport operator to ``rho``."""
        return ParamVecFunction.cat([self._transport(rho)])


class RotatingTransportResidualInvertibleLossFlow(_RotatingTransportResidual):
    """Transport residual on ``inv_phi`` (2 components), invertible-net signature."""

    _residual_size = 2

    def construct_residual(
        self, phi: PARAM_FUNC_TYPE, inv_phi: PARAM_FUNC_TYPE, rho: PARAM_FUNC_TYPE
    ) -> PARAM_FUNC_TYPE:
        """Build the residual from ``inv_phi``; ``phi`` and ``rho`` are unused."""
        return self._flow_residual(inv_phi)


class RotatingTransportResidualMLPLossFlow(_RotatingTransportResidual):
    """Transport residual on ``inv_phi`` (2 components), MLP signature."""

    _residual_size = 2

    def construct_residual(
        self, inv_phi: PARAM_FUNC_TYPE, rho: PARAM_FUNC_TYPE
    ) -> PARAM_FUNC_TYPE:
        """Build the residual from ``inv_phi``; ``rho`` is unused."""
        return self._flow_residual(inv_phi)


class RotatingTransportResidualInvertibleLossDensity(_RotatingTransportResidual):
    """Transport residual on ``rho`` (1 component), invertible-net signature."""

    _residual_size = 1

    def construct_residual(
        self, phi: PARAM_FUNC_TYPE, inv_phi: PARAM_FUNC_TYPE, rho: PARAM_FUNC_TYPE
    ) -> PARAM_FUNC_TYPE:
        """Build the residual from ``rho``; ``phi`` and ``inv_phi`` are unused."""
        return self._density_residual(rho)


class RotatingTransportResidualMLPLossDensity(_RotatingTransportResidual):
    """Transport residual on ``rho`` (1 component), MLP signature."""

    _residual_size = 1

    def construct_residual(
        self, inv_phi: PARAM_FUNC_TYPE, rho: PARAM_FUNC_TYPE
    ) -> PARAM_FUNC_TYPE:
        """Build the residual from ``rho``; ``inv_phi`` is unused."""
        return self._density_residual(rho)


class _FlowProjectionResidual(InitialResidual):
    """Shared constructor of the initial residuals acting on ``inv_phi``."""

    def __init__(
        self,
        domain: DOMAIN_TYPE,
        time_domain: float | tuple[float] = (0.0,),
        size: int = 2,
        model_type: str = "x_t",
        f_rhs: NDARRAYS_FUNC_TYPE | None = None,
    ):
        """Forward all arguments to ``InitialResidual``.

        A bare number passed as ``time_domain`` is wrapped into a 1-tuple.
        """
        super().__init__(
            domain=domain,
            time_domain=_as_time_tuple(time_domain),
            size=size,
            model_type=model_type,
            f_rhs=f_rhs,
        )


class _DensityProjectionResidual(InitialResidual):
    """Shared constructor of the initial residuals acting on ``rho``."""

    def __init__(
        self,
        domain: DOMAIN_TYPE,
        time_domain: float | tuple[float] = (0.0,),
        size: int = 1,
        model_type: str = "x_t",
        f_rhs: NDARRAYS_FUNC_TYPE | None = None,
    ):
        """Forward all arguments to ``InitialResidual``.

        A bare number passed as ``time_domain`` is wrapped into a 1-tuple.
        """
        super().__init__(
            domain=domain,
            time_domain=_as_time_tuple(time_domain),
            size=size,
            model_type=model_type,
            f_rhs=f_rhs,
        )


class ProjectionResidualInvertibleLossFlow(_FlowProjectionResidual):
    """Initial residual on ``inv_phi``, invertible-net signature."""

    def construct_residual(
        self, phi: PARAM_FUNC_TYPE, inv_phi: PARAM_FUNC_TYPE, rho: PARAM_FUNC_TYPE
    ) -> PARAM_FUNC_TYPE:
        """Return ``inv_phi.set_t_0`` at the first entry of ``self.time_domain``."""
        assert isinstance(rho, (ParamScalarFunction, ParamVecFunction))
        return inv_phi.set_t_0(self.time_domain[0])


class ProjectionResidualMLPLossFlow(_FlowProjectionResidual):
    """Initial residual on ``inv_phi``, MLP signature."""

    def construct_residual(
        self, inv_phi: PARAM_FUNC_TYPE, rho: PARAM_FUNC_TYPE
    ) -> PARAM_FUNC_TYPE:
        """Return ``inv_phi.set_t_0`` at the first entry of ``self.time_domain``."""
        assert isinstance(rho, (ParamScalarFunction, ParamVecFunction))
        return inv_phi.set_t_0(self.time_domain[0])


class ProjectionResidualInvertibleLossDensity(_DensityProjectionResidual):
    """Initial residual on ``rho``, invertible-net signature."""

    def construct_residual(
        self, phi: PARAM_FUNC_TYPE, inv_phi: PARAM_FUNC_TYPE, rho: PARAM_FUNC_TYPE
    ) -> PARAM_FUNC_TYPE:
        """Return ``rho.set_t_0`` at the first entry of ``self.time_domain``."""
        assert isinstance(rho, (ParamScalarFunction, ParamVecFunction))
        return rho.set_t_0(self.time_domain[0])


class ProjectionResidualMLPLossDensity(_DensityProjectionResidual):
    """Initial residual on ``rho``, MLP signature."""

    def construct_residual(
        self, inv_phi: PARAM_FUNC_TYPE, rho: PARAM_FUNC_TYPE
    ) -> PARAM_FUNC_TYPE:
        """Return ``rho.set_t_0`` at the first entry of ``self.time_domain``."""
        assert isinstance(rho, (ParamScalarFunction, ParamVecFunction))
        return rho.set_t_0(self.time_domain[0])


class RotatingTransportTimeInterval(AbstractPhysicalModel):
    """Physical model bundling one interior residual and an optional IC residual."""

    def __init__(
        self,
        main_domain: VolumetricDomain,
        time_domain: tuple[float, float],
        f_rhs: NDARRAYS_FUNC_TYPE | None = None,
        bc: str = "strong",
        ic: str = "weak",
        f_ic_rhs: NDARRAYS_FUNC_TYPE | None = None,
        type_loss: str = "flow",
        type_net: str = "MLP",
        Tf: float = 8.0,
    ):
        """Select the residual classes and register them.

        Args:
            main_domain: Spatial domain; its label keys the interior residual.
            time_domain: Time interval passed to the base class and residuals.
            f_rhs: Right-hand side of the interior residual.
            bc: Accepted for interface compatibility; not read here.
            ic: When equal to ``"weak"`` an initial-condition residual is added
                under the label ``"ic " + <main domain label>``.
            f_ic_rhs: Right-hand side of the initial-condition residual.
            type_loss: ``"flow"`` selects the flow residuals (size 2); any
                other value selects the density residuals (size 1).
            type_net: ``"MLP"`` selects the two-argument residual variants; any
                other value selects the invertible (three-argument) variants.
            Tf: Forwarded to the interior residual.
        """
        super().__init__(main_domain=main_domain, time_domain=time_domain)

        is_flow = type_loss == "flow"
        is_mlp = type_net == "MLP"
        # (is_flow, is_mlp) -> (interior residual class, IC residual class)
        residual_table = {
            (True, True): (
                RotatingTransportResidualMLPLossFlow,
                ProjectionResidualMLPLossFlow,
            ),
            (True, False): (
                RotatingTransportResidualInvertibleLossFlow,
                ProjectionResidualInvertibleLossFlow,
            ),
            (False, True): (
                RotatingTransportResidualMLPLossDensity,
                ProjectionResidualMLPLossDensity,
            ),
            (False, False): (
                RotatingTransportResidualInvertibleLossDensity,
                ProjectionResidualInvertibleLossDensity,
            ),
        }
        residual, ic_residual = residual_table[(is_flow, is_mlp)]
        size = 2 if is_flow else 1

        self.physical_residuals: PHYSICAL_RESIDUALS_TYPE = {
            self.main_domain.get_label(): residual(
                domain=main_domain,
                time_domain=time_domain,
                f_rhs=f_rhs,
                Tf=Tf,
            ),
        }

        if ic == "weak":
            label = "ic " + self.main_domain.get_label()
            self.physical_residuals[label] = ic_residual(
                domain=main_domain,
                time_domain=time_domain,
                size=size,
                model_type="x_t",
                f_rhs=f_ic_rhs,
            )

    def renew_time_domain(self, new_time_domain: tuple[float, float]):
        """Renew the time domain of the physical model and of all its residuals."""
        self.time_domain = new_time_domain
        for residual in self.physical_residuals.values():
            residual.time_domain = new_time_domain


def create_net(type_net, size, conditional_size, key, hidden_size=14, num_layers=8):
    """Build the network used for one flow.

    Args:
        type_net: ``"invertible"`` builds an ``InvertibleNet`` made of
            ``CouplingLayer``s; any other value builds an ``MLP``.
        size: Dimension of the transformed variable (network output size).
        conditional_size: Number of conditioning inputs (time and parameters).
        key: JAX PRNG key; split so that each coupling layer gets its own key.
        hidden_size: Width of every hidden layer.
        num_layers: Number of coupling layers, or of hidden layers for the MLP.

    Returns:
        The constructed ``InvertibleNet`` or ``MLP``.
    """
    if type_net == "invertible":
        # One sub-key per layer: reusing a single key would hand every layer
        # the same random stream.
        layer_keys = jax.random.split(key, num_layers)
        layers_list = [
            CouplingLayer(
                size=size,
                conditional_size=conditional_size,
                num_splits=2,
                ode_layer_type=AffineFlowLayer,
                hidden_sizes=[hidden_size],
                activation="tanh",
                key=layer_key,
            )
            for layer_key in layer_keys
        ]
        return InvertibleNet(
            size=size,
            conditional_size=conditional_size,
            layers_list=layers_list,
        )

    # Sizes come from the arguments, not from module-level globals.
    return MLP(
        in_size=size + conditional_size,
        hidden_sizes=[hidden_size] * num_layers,
        out_size=size,
        activation="sine",
    )


def exact_density(t, x):
    """Gaussian bump of width 0.08 centred at ``0.3 * (cos(2*pi*t), sin(2*pi*t))``."""
    nu = 0.08
    c = 0.3
    x1, x2 = x[0], x[1]
    x1t, x2t = jnp.cos(2 * jnp.pi * t), jnp.sin(2 * jnp.pi * t)
    r2 = (x1 - c * x1t) ** 2 + (x2 - c * x2t) ** 2
    return jnp.exp(-r2 / (2 * nu**2))


def initial_density(x):
    """Return ``exact_density`` evaluated at time zero."""
    zeros = jnp.zeros_like(x[0:1])
    return exact_density(zeros, x)


def initial_solution_flow(x):
    """Initial condition of the flow: the identity map."""
    return x


def initial_solution_density(x):
    """Initial condition of the density: ``initial_density``."""
    return initial_density(x)


def plot_losses_and_density_single_pinn(
    domain_x, pinn, flow_number, type_network=None, n_visu=256
):
    """Plot the loss history of ``pinn`` and its density at six fixed times.

    Args:
        domain_x: Unused; the grid is always ``[-1, 1]^2`` masked to the unit disk.
        pinn: Trained projector exposing ``losses.losses_history``, ``space``
            and ``evaluate(x, t)``.
        flow_number: Index shown in the figure titles.
        type_network: ``"MLP"`` reads the density from output column 2, any
            other value from column 4.
        n_visu: Number of grid points per axis.
    """
    history = pinn.losses.losses_history
    for label in history:
        series = history[label]
        if (label == "total") or (series.shape[1] == 1):
            plt.semilogy(series, label=label + " loss")
        else:
            for i in range(series.shape[1]):
                plt.semilogy(series[:, i], label=f"{label}{i} loss")
    plt.title(f"loss history for flow {flow_number}")
    plt.legend()
    plt.show()

    id_density = 2 if type_network == "MLP" else 4

    sh = (n_visu, n_visu)
    x1_lin = jnp.linspace(-1, 1, n_visu)
    x2_lin = jnp.linspace(-1, 1, n_visu)
    x1, x2 = jnp.meshgrid(x1_lin, x2_lin)
    x = jnp.stack([x1.ravel(), x2.ravel()], axis=-1)
    mask = x1**2 + x2**2 > 1.0  # points outside the unit disk are hidden

    def plot_rho(axis, title, t_val):
        t = jnp.ones_like(x[:, 0:1]) * t_val
        print(
            f"in plot: pinn.space.idx_current_flow={pinn.space.idx_current_flow}, t_val={t_val}"
        )
        rho = pinn.evaluate(x, t)[:, id_density : id_density + 1].reshape(*sh)
        rho = jnp.where(mask, jnp.nan, rho)
        im = axis.contourf(x1, x2, rho, levels=64, cmap="turbo")
        plt.colorbar(im, ax=axis)
        axis.contour(x1, x2, rho, levels=10, colors="white", linewidths=0.5)
        axis.set_aspect("equal")
        # Masked points are NaN, so plain min/max would always report nan.
        min_rho = float(jnp.nanmin(rho))
        max_rho = float(jnp.nanmax(rho))
        axis.set_title(
            f"flow {flow_number}: {title} (min={min_rho:.2f}, max={max_rho:.2f})"
        )

    fig, ax = plt.subplots(3, 2, figsize=(12, 18))
    plot_rho(ax[0, 0], "t=0.0", 0.0)
    plot_rho(ax[0, 1], "t=0.25", 0.25)
    plot_rho(ax[1, 0], "t=0.45", 0.45)
    plot_rho(ax[1, 1], "t=0.49", 0.49)
    plot_rho(ax[2, 0], "t=0.75", 0.75)
    plot_rho(ax[2, 1], "t=1.0", 1.0)
    plt.tight_layout()
    plt.show()


def plot_losses_and_density(domain_x, list_of_pinns, type_network=None, n_visu=256):
    """Plot losses and densities for every projector in ``list_of_pinns``.

    Each space is switched to inference mode with ``idx_current_flow`` set to
    the projector's position; afterwards the index is restored and inference
    mode is set back to ``False``, even if plotting raises.
    """
    for i, pinn in enumerate(list_of_pinns):
        pinn.space.inference_mode = True
        old_idx = pinn.space.idx_current_flow
        pinn.space.idx_current_flow = i
        try:
            plot_losses_and_density_single_pinn(
                domain_x, pinn, flow_number=i, type_network=type_network, n_visu=n_visu
            )
        finally:
            pinn.space.idx_current_flow = old_idx
            pinn.space.inference_mode = False


def plot_single_flow(domain_x, pinn, flow_number, type_network=None, n_visu=256):
    """Plot norm, x and y components of the backward flow at five fixed times.

    Args:
        domain_x: Unused; the grid is always ``[-1, 1]^2`` masked to the unit disk.
        pinn: Trained projector exposing ``evaluate(x, t)``.
        flow_number: Index shown in the subplot titles.
        type_network: ``"MLP"`` reads the flow from output columns 0-1, any
            other value from columns 2-3.
        n_visu: Number of grid points per axis.
    """
    sh = (n_visu, n_visu, 2)
    x1_lin = jnp.linspace(-1, 1, n_visu)
    x2_lin = jnp.linspace(-1, 1, n_visu)
    x1, x2 = jnp.meshgrid(x1_lin, x2_lin)
    x = jnp.stack([x1.ravel(), x2.ravel()], axis=-1)
    mask = (x1**2 + x2**2 > 1.0)[:, :, None]

    id_bwd = 0 if type_network == "MLP" else 2

    def plot_flow(row, title, t_val):
        t = jnp.ones_like(x[:, 0:1]) * t_val
        flow_bwd = pinn.evaluate(x, t)[:, id_bwd : id_bwd + 2].reshape(*sh)
        flow_bwd = jnp.where(mask, jnp.nan, flow_bwd)
        panels = (
            (jnp.linalg.norm(flow_bwd, axis=-1), "(norm)"),
            (flow_bwd[..., 0], "(x)"),
            (flow_bwd[..., 1], "(y)"),
        )
        for axis, (values, suffix) in zip(row, panels):
            im = axis.contourf(x1, x2, values, levels=64, cmap="turbo")
            plt.colorbar(im, ax=axis)
            axis.set_aspect("equal")
            axis.set_title(f"flow {flow_number}, {title}: {suffix}")

    fig, ax = plt.subplots(5, 3, figsize=(9, 6))
    plot_flow(ax[0], "t=0.0", 0.0)
    plot_flow(ax[1], "t=0.25", 0.25)
    plot_flow(ax[2], "t=0.5", 0.5)
    plot_flow(ax[3], "t=0.75", 0.75)
    plot_flow(ax[4], "t=1.0", 1.0)
    plt.tight_layout()
    plt.show()


def plot_flows(domain_x, pinn, type_network=None, n_visu=256):
    """Plot the backward flow of every projector in ``pinn``.

    Despite its name, ``pinn`` is an iterable of projectors. Each space has
    ``idx_current_flow`` temporarily set to the projector's position and
    restored afterwards, even if plotting raises.
    """
    for i, single_pinn in enumerate(pinn):
        old_idx = single_pinn.space.idx_current_flow
        single_pinn.space.idx_current_flow = i
        try:
            plot_single_flow(
                domain_x,
                single_pinn,
                flow_number=i,
                type_network=type_network,
                n_visu=n_visu,
            )
        finally:
            single_pinn.space.idx_current_flow = old_idx


# %%
N_COLLOC = 8_000
N_IC_COLLOC = 4_000
N_EPOCHS = 100

# NOTE(review): training uses the square [0, 1]^2 while the plotting helpers
# sample [-1, 1]^2 masked to the unit disk — confirm the intended domain.
domain_x = [(0.0, 1.0), (0.0, 1.0)]
dx = Square2D(domain_x, is_main_domain=True)

key = jax.random.PRNGKey(0)
dim = 2
params_dim = 0
nb_models = 1
keys = jax.random.split(key, nb_models)
time_and_params_dim = 1 + params_dim

############# important
type_loss = "flow"
type_net = "MLP"

# One PRNG key per model, so several models do not share the same key.
models = [
    create_net(
        type_net=type_net,
        size=dim,
        conditional_size=time_and_params_dim,
        key=model_key,
        hidden_size=14,
        num_layers=5,
    )
    for model_key in keys
]

Tf = 1.0
Deltat = Tf / nb_models
time_intervals = [(Deltat * i, Deltat * (i + 1)) for i in range(nb_models)]

if type_net == "invertible":
    space_class = DensityFlowInvertibleApproximationSpace
else:
    space_class = DensityFlowApproximationSpace
space = space_class(
    dim=dim,
    params_dim=params_dim,
    models=models,
    model_type="x_t",
    initial_density=initial_density,
    time_intervals=time_intervals,
)

if type_loss == "flow":
    weights = {"interior": [0.3, 0.3], "ic interior": [5.0, 5.0]}
else:
    weights = {"interior": [0.3], "ic interior": [5.0]}

initial_domain_t = (0.0, time_intervals[0][1])
sampler = TensorizedSampler(
    [
        DomainSampler(dx),
        UniformTimeSampler(initial_domain_t),
    ],
    bc=False,
    ic=True,
    model_type="x_t",
)

# The initial-condition target must match the residual: the identity map for
# the flow loss, the initial density for the density loss.
f_ic_rhs = initial_solution_flow if type_loss == "flow" else initial_solution_density

# NOTE(review): Tf=8 differs from the module-level Tf = 1.0 above; the value is
# only stored on the residual in this file — confirm which one is intended.
model = RotatingTransportTimeInterval(
    main_domain=dx,
    time_domain=initial_domain_t,
    f_ic_rhs=f_ic_rhs,
    Tf=8,
    type_loss=type_loss,
    type_net=type_net,
)

list_of_pinns = []
for i in range(nb_models):
    space.idx_current_flow = i
    domain_t = time_intervals[i]

    print(
        f"Training flow {i + 1}/{nb_models} on time interval {domain_t} with {type_loss} loss and {type_net} network"
    )

    # Re-target the shared sampler and model to the current time interval.
    sampler.renew_sampler(UniformTimeSampler(domain_t), "t")
    model.renew_time_domain(domain_t)

    pinn = Projector(
        model,
        space,
        sampler,
        optimizer="ENG",
        weights=weights,
        one_loss_per_residual=True,
        matrix_regularization=1e-5 * (i + 1),
    )

    start = timeit.default_timer()
    key, pinn = pinn.project(
        key, space, n_epochs=N_EPOCHS, n_colloc=N_COLLOC, n_ic_colloc=N_IC_COLLOC
    )
    end = timeit.default_timer()

    # Deep copy so later iterations cannot mutate the stored projector.
    list_of_pinns.append(copy.deepcopy(pinn))

    print("best loss: ", pinn.best_loss)
    print("time for %d epochs: " % N_EPOCHS, end - start)

# %%
plot = True
if plot:
    plot_losses_and_density(domain_x, list_of_pinns, type_network=type_net, n_visu=256)
    plot_flows(domain_x, list_of_pinns, type_network=type_net, n_visu=256)

# %%