r"""Solve a 2D Poisson PDE with Dirichlet boundary conditions using PINNs.

.. math::

    -\Delta u & = f \quad \text{in } \Omega, \\
    u & = g \quad \text{on } \partial \Omega,

where :math:`x = (x_1, x_2) \in \Omega = (-1, 1) \times (-1, 1)`, :math:`f` is
chosen such that :math:`u(x_1, x_2) = \sin(\pi x_1) \sin(\pi x_2)`, and
:math:`g = 0`.

The boundary conditions are homogeneous Dirichlet conditions; the PDE model is
built with ``bc="weak"``. The neural network is a simple MLP (Multilayer
Perceptron).

The training is done 4 times, each time with every scimba optimizer used here
(Adam, L-BFGS, SS-BFGS, SS-Broyden, ENG with and without Armijo line search,
and ANaGRAM):

1. with nothing frozen,
2. with only the last layer trained (all other layers frozen),
3. with the last layer frozen (all other layers trained),
4. alternating 2. and 3.
"""

import sys
import timeit

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tqdm import tqdm

from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
    ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
    DomainSampler,
    TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.nonlinear_approximation.optimizers.losses import LOSS_VALUE_TYPE
from scimba_jax.physical_models.elliptic_pde.laplacians import LaplacianDirichletND
from scimba_jax.plots.plots_nd import (
    plot_abstract_approx_spaces,
)
from scimba_torch.utils.environment import (
    get_static_terminal_width,
    is_static_width_environment,
)

N_COLLOC = 1000
N_BC_COLLOC = 2000  # NOTE(review): not used anywhere in this script
N_EPOCHS = 50
N_EPOCHS_ADAM = 1000
N_EPOCHS_SSBFGS = 1000
N_LAYERS = [32, 32]
N_LAYERS_ENG = [16, 16]  # NOTE(review): not used anywhere in this script

# Accepted values of the ``frozen`` argument of ``solve``, in the order in
# which ``main`` runs them.
FROZEN_MODES = ("none", "last_active", "last_frozen", "alternate")

key = jax.random.PRNGKey(0)


def f_rhs(xy: jnp.ndarray) -> jnp.ndarray:
    """Right-hand side f = 2 pi^2 sin(pi x) sin(pi y), so that -Delta u = f.

    ``xy`` is indexed as a single point (``xy[0:1]`` is x, ``xy[1:2]`` is y);
    presumably the library maps this over the collocation points — TODO
    confirm.
    """
    x, y = xy[0:1], xy[1:2]
    return 2 * jnp.pi**2 * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)


def exact_sol(xy: jnp.ndarray) -> jnp.ndarray:
    """Exact solution u = sin(pi x) sin(pi y) on a batch of points.

    ``xy`` is indexed as a 2D array whose columns 0 and 1 hold x and y; the
    result keeps a trailing axis of size 1.
    """
    x, y = xy[:, 0:1], xy[:, 1:2]
    return jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)


def post_processing(approx: jnp.ndarray, xy: jnp.ndarray) -> jnp.ndarray:
    """Multiply ``approx`` by a factor that vanishes on the boundary of (-1, 1)^2.

    This would impose the homogeneous Dirichlet condition strongly; note that
    it is not passed to the approximation space in this script.
    """
    x, y = xy[0:1], xy[1:2]
    return approx * (x + 1.0) * (1.0 - x) * (y + 1.0) * (1.0 - y)


def _space_with_model(model, space: ApproximationSpace) -> ApproximationSpace:
    """Return a copy of ``space`` whose single model is replaced by ``model``."""
    # __new__ bypasses ApproximationSpace.__init__; every attribute is copied
    # by hand from the original space.
    res = ApproximationSpace.__new__(ApproximationSpace)
    res.models = [model]
    res.dims = space.dims
    res.model_type = space.model_type
    res.types_models = space.types_models
    res.size_models = space.size_models
    res.pre_processings = space.pre_processings
    res.post_processings = space.post_processings
    res.ndof = space.ndof
    return res


def partition_last_active(space: ApproximationSpace):
    """Split ``space`` into (last layer of its MLP, whole space)."""
    dyn = space.models[0].get_layer(-1)
    sta = space
    return dyn, sta


def combine_last_active(layer, space: ApproximationSpace) -> ApproximationSpace:
    """Inverse of ``partition_last_active``: put ``layer`` back as the last layer."""
    return _space_with_model(MLP.set_layer(layer, -1, space.models[0]), space)


def partition_last_frozen(space: ApproximationSpace):
    """Split ``space`` into (all but the last layer of its MLP, whole space)."""
    dyn = space.models[0].get_layers(slice(0, -1))
    sta = space
    return dyn, sta


def combine_last_frozen(layer, space: ApproximationSpace) -> ApproximationSpace:
    """Inverse of ``partition_last_frozen``: put back all but the last layer."""
    return _space_with_model(
        MLP.set_layers(layer, slice(0, -1), space.models[0]), space
    )


# Extra ``Projector`` keyword arguments for each non-alternating freezing mode.
_PARTITIONS = {
    "none": {},
    "last_active": {
        "partition": partition_last_active,
        "combine": combine_last_active,
    },
    "last_frozen": {
        "partition": partition_last_frozen,
        "combine": combine_last_frozen,
    },
}


def get_key_model_space_sampler(n_layers: list[int]):
    """Build the PRNG key, PDE model, approximation space and sampler.

    Args:
        n_layers: hidden layer sizes of the MLP.

    Returns:
        A tuple ``(key, model, space, sampler)``.
    """
    domain_x = [(-1.0, 1.0), (-1.0, 1.0)]
    dx = Square2D(domain_x, is_main_domain=True)
    sampler = TensorizedSampler([DomainSampler(dx)], bc=True)

    # Split the seed so the network initialisation and the later sampling
    # calls (which consume the returned key) do not reuse the same PRNG key.
    key, init_key = jax.random.split(jax.random.PRNGKey(0))
    nn = MLP(in_size=2, out_size=1, hidden_sizes=n_layers, key=init_key)
    space = ApproximationSpace(
        {"x": 2},
        [(nn, "scalar", None)],
        model_type="x",
    )
    # NOTE(review): the Dirichlet condition is requested as bc="weak" and
    # post_processing is not given to the space — confirm that weak
    # enforcement is the intended setup.
    model = LaplacianDirichletND(dx, f_rhs, bc="weak", model_type="x")
    return key, model, space, sampler


def _make_progress_bar(total: int, kwargs: dict, verbose: bool) -> tqdm:
    """Create the tqdm bar used by the alternating training loop.

    The bar is disabled in verbose mode (per-epoch prints are used instead) or
    when the static terminal width is 0.
    """
    tqdm_ncols = get_static_terminal_width()
    tqdm_dynamic = not is_static_width_environment()
    return tqdm(
        total=total,
        desc=kwargs.get("tqdm_desc", "Training"),
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}"
        "[{elapsed}<{remaining}] {postfix}",
        disable=verbose or (tqdm_ncols == 0),
        leave=kwargs.get("tqdm_leave", True),
        position=kwargs.get("tqdm_position", 0),
        ascii=" |",
        file=sys.stdout,
        dynamic_ncols=tqdm_dynamic,
        ncols=tqdm_ncols,
    )


def _solve_alternate(model, space, sampler, key, optimizer, n_epochs, kwargs, verbose):
    """Alternate one "last_active" step and one "last_frozen" step.

    ``n_epochs`` is the total number of optimization steps; they are taken two
    at a time, so ``n_epochs // 2`` alternating pairs are run. Returns the
    "last_active" projector, updated with the full loss history and the best
    space found.
    """
    pinn1 = Projector(
        model, space, sampler, optimizer=optimizer,
        **(kwargs | _PARTITIONS["last_active"]),
    )
    one_step_pinn1 = pinn1.build_one_step_optim(N_COLLOC, 1, 1)
    pinn2 = Projector(
        model, space, sampler, optimizer=optimizer,
        **(kwargs | _PARTITIONS["last_frozen"]),
    )
    one_step_pinn2 = pinn2.build_one_step_optim(N_COLLOC, 1, 1)

    print("@@@@@@@@@@@@@@@ train with %s @@@@@@@@@@@@@@@@@@@@@" % optimizer)
    key, sample_dict = sampler.sample(key, N_COLLOC)
    loss = pinn1.evaluate_loss(space, sample_dict)
    print("initial loss: ", loss)

    start = timeit.default_timer()
    best_loss: LOSS_VALUE_TYPE = pinn1.losses.get_infinity_losses()
    losses = pinn1.losses.set_initial_losses_history(n_epochs)
    new_space = best_space = space
    optimizer1, optimizer2 = pinn1.optimizer, pinn2.optimizer

    loop = _make_progress_bar(n_epochs // 2, kwargs, verbose)
    for i in range(n_epochs // 2):
        # Each iteration takes two steps, stored at history slots 2i and
        # 2i + 1 so that no entry is overwritten.
        new_loss, new_space, key, optimizer1 = one_step_pinn1(
            new_space, key, optimizer1
        )
        losses = losses.update_losses_history(2 * i, new_loss)
        if new_loss["total"] < best_loss["total"]:
            best_loss, best_space = new_loss, new_space
            if verbose:
                print("Epoch %d: New best loss: %.2e" % (i, new_loss["total"]))

        new_loss, new_space, key, optimizer2 = one_step_pinn2(
            new_space, key, optimizer2
        )
        losses = losses.update_losses_history(2 * i + 1, new_loss)
        if new_loss["total"] < best_loss["total"]:
            best_loss, best_space = new_loss, new_space
            if verbose:
                print("Epoch %d: New best loss: %.2e" % (i, new_loss["total"]))

        # Slot 1 holds the loss of the first alternating pair's second step.
        postfix_str = "loss: %.1e -> %.1e" % (
            losses.losses_history["total"][1],
            best_loss["total"],
        )
        loop.set_postfix_str(postfix_str)
        loop.update(1)
        loop.refresh()
    loop.close()

    pinn1.optimizer = optimizer1
    pinn2.optimizer = optimizer2
    pinn1.losses = losses
    pinn1.best_loss = best_loss
    pinn1.space = best_space
    end = timeit.default_timer()

    if verbose:
        print("best loss: ", best_loss)
        print("time for %d epochs: " % n_epochs, end - start)
    return pinn1


def solve(
    optimizer: str,
    n_epochs: int,
    n_layers: list[int],
    linesearch=None,
    frozen: str = "none",
    verbose=False,
):
    """Build a PINN for the Poisson problem and train it.

    Args:
        optimizer: name of the scimba optimizer passed to ``Projector``.
        n_epochs: number of training epochs. In "alternate" mode this is the
            number of individual optimization steps, taken two at a time.
        n_layers: hidden layer sizes of the MLP.
        linesearch: optional line-search name forwarded to ``Projector``.
        frozen: one of ``FROZEN_MODES``.
            - "none": no partition.
            - "last_active": only the MLP's last layer is the trainable part.
            - "last_frozen": every layer but the last is the trainable part.
            - "alternate": alternate the two previous modes step by step.
        verbose: print losses and timings (this disables the progress bar in
            "alternate" mode).

    Returns:
        The trained ``Projector``.

    Raises:
        ValueError: if ``frozen`` is not one of ``FROZEN_MODES``.
    """
    if frozen not in FROZEN_MODES:
        raise ValueError(f"frozen must be one of {FROZEN_MODES}, got {frozen!r}")

    key, model, space, sampler = get_key_model_space_sampler(n_layers)

    kwargs = {}
    if linesearch is not None:
        kwargs = kwargs | {"linesearch": linesearch}

    if frozen == "alternate":
        return _solve_alternate(
            model, space, sampler, key, optimizer, n_epochs, kwargs, verbose
        )

    kwargs = kwargs | _PARTITIONS[frozen]
    pinn = Projector(model, space, sampler, optimizer=optimizer, **kwargs)

    if verbose:
        print("\n\n")
        print("@@@@@@@@@@@@@@@ create a pinn with strong bc @@@@@@@@@@@@@@@@@@@@@")
        # Sampling here consumes the key, so verbose runs train from a
        # different key than silent runs.
        key, sample_dict = sampler.sample(key, N_COLLOC)
        loss = pinn.evaluate_loss(space, sample_dict)
        print("initial loss: ", loss)
        losses_value = pinn.evaluate_losses(space, sample_dict)
        print("initial losses: ", losses_value)
        print("\n\n")
        print("@@@@@@@@@@@@@@@ train with %s @@@@@@@@@@@@@@@@@@@@@" % optimizer)

    start = timeit.default_timer()
    key, pinn = pinn.project(key, space, n_epochs, N_COLLOC)
    end = timeit.default_timer()

    if verbose:
        print("best loss: ", pinn.best_loss)
        print("time for %d epochs: " % n_epochs, end - start)
    return pinn


def plot(pinns: tuple, title: str):
    """Plot the approximations, losses, residuals and errors of ``pinns``."""
    domain_x = [(-1.0, 1.0), (-1.0, 1.0)]
    dx = Square2D(domain_x, is_main_domain=True)

    spaces = tuple(pinn.space for pinn in pinns)
    losses = tuple(pinn.losses for pinn in pinns)
    residuals = tuple(pinn.model for pinn in pinns)
    titles = tuple("optimizer: " + str(pinn.optimizer.__class__) for pinn in pinns)

    plot_abstract_approx_spaces(
        spaces,
        dx,  # the spatial domain
        loss=losses,  # the losses
        residual=residuals,  # the models
        solution=exact_sol,  # for plot of the exact sol: sol
        error=exact_sol,  # for plot of the error with respect to a func: the func
        draw_contours=True,
        n_drawn_contours=20,
        title=title,
        titles=titles,
    )
    plt.show()


def main() -> None:
    """Train with every optimizer for every freezing mode and plot results."""
    optimizers_and_epochs = [
        ("Adam", N_EPOCHS_ADAM),
        ("L-BFGS", N_EPOCHS_ADAM),
        ("SS-BFGS", N_EPOCHS_SSBFGS),
        ("SS-Broyden", N_EPOCHS_SSBFGS),
        ("ENG", N_EPOCHS),
        ("ANaGRAM", N_EPOCHS),
    ]
    for frozen in FROZEN_MODES:
        print("\n\n")
        print("@@@@@@@@@@@@@@@ model with %s @@@@@@@@@@@@@@@@@@@@@@" % frozen)
        pinns = []
        for opt, n in optimizers_and_epochs:
            pinns.append(solve(opt, n, N_LAYERS, frozen=frozen))
            if opt == "ENG":
                # ENG is also run with an Armijo line search.
                pinns.append(
                    solve(opt, n, N_LAYERS, frozen=frozen, linesearch="armijo")
                )
        # Only the last three runs (ENG, ENG + Armijo, ANaGRAM) are plotted.
        plot(tuple(pinns[-3:]), frozen)


if __name__ == "__main__":
    main()