r"""Solves a 1D Helmholtz PDE with Dirichlet boundary conditions.

.. math::

    -u''(x) - k^2 u(x) & = f(x) \text{ in } \Omega, \\
    u(x) & = 0 \text{ on } \partial \Omega,

where :math:`\Omega = (0, 1)` and :math:`k` is a parameter. The right-hand side
:math:`f(x)` is chosen such that the exact solution is
:math:`u(x) = \sin(2 \pi x)`.

The goal of this example is to compare the performance of different activation
functions for the neural network used in the PINN. We run multiple seeds for
each configuration and report the average errors and CPU times.
"""

# %%
import time

import matplotlib.pyplot as plt
import torch
from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_1d import Segment1D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.elliptic_pde.pinns import (
    NaturalGradientPinnsElliptic,
    PinnsElliptic,
)
from scimba_torch.physical_models.elliptic_pde.laplacians import (
    Laplacian2DDirichletStrongForm,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor

torch.manual_seed(0)

# Wave number of the manufactured solution u(x) = sin(K_STAR * x).
K_STAR = 2 * torch.pi


def f_rhs(x: LabelTensor, mu: LabelTensor, k: float):
    """Evaluate the right-hand side matching the manufactured solution.

    With ``u(x) = sin(K_STAR * x)`` one has ``-u'' - k^2 u =
    (K_STAR^2 - k^2) * sin(K_STAR * x)``.

    Args:
        x: Points at which to evaluate the right-hand side.
        mu: Parameters (unused; kept for the expected call signature).
        k: Helmholtz parameter.

    Returns:
        The values of ``f`` at ``x``.
    """
    x_ = x.get_components()
    return (K_STAR**2 - k**2) * torch.sin(K_STAR * x_)


def exact_sol(x: LabelTensor, mu: LabelTensor):
    """Evaluate the exact solution ``sin(K_STAR * x)``.

    Args:
        x: Points at which to evaluate the solution.
        mu: Parameters (unused; kept for the expected call signature).

    Returns:
        The values of the exact solution at ``x``.
    """
    x_ = x.get_components()
    return torch.sin(K_STAR * x_)


class Helmholtz1DDirichletStrongFormNoParam(Laplacian2DDirichletStrongForm):
    """Strong form of ``-u'' - k^2 u = f`` for a fixed (non-parametric) ``k``.

    Reuses the constructor of ``Laplacian2DDirichletStrongForm`` and overrides
    the differential operator in both its standard and functional variants.
    """

    def __init__(self, space, f, k):
        """Store the Helmholtz parameter on top of the parent initialisation.

        Args:
            space: Approximation space, forwarded to the parent class.
            f: Right-hand side callable ``f(x, mu)``, forwarded to the parent.
            k: Helmholtz parameter.
        """
        super().__init__(space, f)
        self.k = k

    def operator(self, w, x, mu):
        """Apply ``-u'' - k^2 u`` to the evaluated unknown ``w``.

        Args:
            w: Evaluated unknown; its components are read with
                ``get_components()``.
            x: Points with respect to which derivatives are taken.
            mu: Parameters (unused).

        Returns:
            The value of the Helmholtz operator applied to ``w``.
        """
        u = w.get_components()
        u_x = self.grad(u, x)
        u_xx = self.grad(u_x, x)
        return -u_xx - self.k**2 * u

    def functional_operator(self, func, x, mu, theta):
        """Apply ``-u'' - k^2 u`` to a functional ``func(x, mu, theta)``.

        The second derivative is obtained by differentiating ``func`` twice
        with respect to its first argument using ``torch.func.jacrev``.

        Args:
            func: Callable ``func(x, mu, theta)`` representing the unknown.
            x: Point at which the operator is evaluated.
            mu: Parameters passed through to ``func``.
            theta: Degrees of freedom passed through to ``func``.

        Returns:
            The value of the Helmholtz operator applied to ``func``.
        """
        grad_u = torch.func.jacrev(func, 0)
        hessian_u = torch.func.jacrev(grad_u, 0, chunk_size=None)(x, mu, theta)
        # In 1D the Laplacian is the single entry of the Hessian.
        laplacian_u = hessian_u[..., 0, 0]
        u = func(x, mu, theta)
        return -laplacian_u - self.k**2 * u


def post_processing(inputs: torch.Tensor, x: LabelTensor, mu: LabelTensor):
    """Multiply the network output by a function vanishing at 0, 0.5 and 1.

    The factor ``x * (0.5 - x) * (x - 1)`` is zero on both ends of ``(0, 1)``
    (and at the midpoint, where the exact solution also vanishes).

    Args:
        inputs: Raw network output.
        x: Points at which the output was evaluated.
        mu: Parameters (unused).

    Returns:
        The network output multiplied by the vanishing factor.
    """
    x_ = x.get_components()
    phi = x_ * (0.5 - x_) * (x_ - 1)
    return inputs * phi


def functional_post_processing(u, x, mu, theta):
    """Functional counterpart of :func:`post_processing`.

    Args:
        u: Callable ``u(x, mu, theta)`` giving the raw network output.
        x: Point at which to evaluate.
        mu: Parameters passed through to ``u``.
        theta: Degrees of freedom passed through to ``u``.

    Returns:
        ``u(x, mu, theta)`` multiplied by ``x * (0.5 - x) * (x - 1)``.
    """
    return u(x, mu, theta) * x * (0.5 - x) * (x - 1)


def create_pde(domain_x, sampler, activation_type, k):
    """Build the Helmholtz problem on a two-layer MLP approximation space.

    Args:
        domain_x: Spatial domain.
        sampler: Sampler used by the approximation space.
        activation_type: Name of the activation function of the MLP.
        k: Helmholtz parameter.

    Returns:
        A ``Helmholtz1DDirichletStrongFormNoParam`` instance.
    """
    space = NNxSpace(
        1,
        0,
        GenericMLP,
        domain_x,
        sampler,
        layer_sizes=[16, 16],
        post_processing=post_processing,
        activation_type=activation_type,
    )
    return Helmholtz1DDirichletStrongFormNoParam(
        space, f=lambda x, mu: f_rhs(x, mu, k), k=k
    )


def create_and_solve_pinn(activation_type, k, n_epochs=250, n_collocation=3000):
    """Train a PINN with ``PinnsElliptic`` then ``NaturalGradientPinnsElliptic``.

    Both solvers are built on the same PDE object and run one after the other,
    each for ``n_epochs`` epochs.

    Args:
        activation_type: Name of the activation function of the MLP.
        k: Helmholtz parameter.
        n_epochs: Number of epochs for each of the two solvers.
        n_collocation: Number of collocation points for each solver.

    Returns:
        A tuple ``(pinn, elapsed, loss_history)`` with the second (natural
        gradient) solver, the cumulated time spent in the two ``solve`` calls
        (seconds, measured with ``time.perf_counter``) and the concatenation of
        both loss histories.
    """
    domain_x = Segment1D([(0, 1)], is_main_domain=True)
    sampler_x = DomainSampler(domain_x)
    default_param_domain = []
    default_param_sampler = UniformParametricSampler(default_param_domain)
    sampler = TensorizedSampler([sampler_x, default_param_sampler])

    pde = create_pde(domain_x, sampler, activation_type, k)

    elapsed = 0.0
    # A fresh list: the solvers' own histories must not be modified in place.
    loss_history = []
    pinn = None
    for solver_cls in (PinnsElliptic, NaturalGradientPinnsElliptic):
        pinn = solver_cls(
            pde,
            bc_type="strong",
            functional_post_processing=functional_post_processing,
            adaptive_matrix_regularization=True,
        )
        start = time.perf_counter()
        pinn.solve(max_epochs=n_epochs, n_collocation=n_collocation)
        elapsed += time.perf_counter() - start
        loss_history.extend(pinn.losses.loss_history)

    return pinn, elapsed, loss_history


def compute_error(pinn, n_error=10_000) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute discrete L2 and Linf errors on a uniform grid of ``[0, 1]``.

    Args:
        pinn: Trained solver; its ``space`` is evaluated on the grid.
        n_error: Number of grid points.

    Returns:
        A tuple ``(l2_error, linf_error)`` where the L2 error is the root mean
        square of the pointwise error and the Linf error is its maximum
        absolute value.
    """
    x = LabelTensor(torch.linspace(0, 1, n_error).unsqueeze(-1))
    # There are no parameters in this problem: empty second dimension.
    mu = LabelTensor(torch.empty((n_error, 0)))

    with torch.no_grad():
        u_pred = pinn.space.evaluate(x, mu).w.squeeze()
        u_exact = exact_sol(x, mu).squeeze()

    error = u_pred - u_exact
    l2_error = torch.sqrt((error**2).mean())
    linf_error = error.abs().max()
    return l2_error, linf_error


def plot_pinn(pinn, activation_type, k):
    """Plot the approximation, the loss and the error of a trained PINN.

    Args:
        pinn: Trained solver.
        activation_type: Name of the activation function, shown in the title.
        k: Helmholtz parameter (number or 0-dim tensor), shown in the title.
    """
    plot_abstract_approx_spaces(
        (pinn.space,),
        pinn.space.spatial_domain,
        [[]],
        loss=(pinn.losses,),
        solution=exact_sol,
        error=exact_sol,
        # Use the function's own arguments rather than module-level loop
        # variables; int() accepts both Python numbers and 0-dim tensors.
        title=f"results for k={int(k)} and activation={activation_type}",
    )
    plt.show()


# %%
all_ks = 2 ** torch.arange(8, 9)
all_activations = ["tanh", "sine", "gabor", "polynomial_bessel_j0", "sinc"]
n_seeds = 10
n_epochs = 250
# Number of best-performing seeds kept in the error statistics.
n_best = 5

# %%
shape = (len(all_ks), len(all_activations), n_seeds)
l2_errors = torch.zeros(shape)
linf_errors = torch.zeros(shape)
cpu_times = torch.zeros(shape)
# Two solvers are run back to back, hence 2 * n_epochs loss values per run.
loss_histories = torch.zeros(shape + (n_epochs * 2,))

for i_k, k in enumerate(all_ks):
    for i_a, activation in enumerate(all_activations):
        print(f"Solving for k={int(k.item())} and activation={activation}...")
        for seed in range(n_seeds):
            torch.manual_seed(42 + seed)
            # Forward n_epochs so the history length matches loss_histories.
            pinn, elapsed, loss_history = create_and_solve_pinn(
                activation_type=activation, k=k.item(), n_epochs=n_epochs
            )
            l2_errors[i_k, i_a, seed], linf_errors[i_k, i_a, seed] = compute_error(
                pinn
            )
            cpu_times[i_k, i_a, seed] = elapsed
            loss_histories[i_k, i_a, seed] = torch.tensor(loss_history)
            if seed == 0:
                plot_pinn(pinn, activation, k)

# %%
save_errors = False
if save_errors:
    torch.save(l2_errors, "helmholtz_1d_l2_errors.pt")
    torch.save(linf_errors, "helmholtz_1d_linf_errors.pt")
    torch.save(cpu_times, "helmholtz_1d_cpu_times.pt")
    torch.save(loss_histories, "helmholtz_1d_loss_histories.pt")

# %%
load_errors = False
if load_errors:
    l2_errors = torch.load("helmholtz_1d_l2_errors.pt")
    linf_errors = torch.load("helmholtz_1d_linf_errors.pt")
    cpu_times = torch.load("helmholtz_1d_cpu_times.pt")
    loss_histories = torch.load("helmholtz_1d_loss_histories.pt")

fig, ax = plt.subplots(1, 2, figsize=(12, 5))

# Abscissa shared by every curve: k on a log10 scale.
ks_ = torch.log10(all_ks)
k_labels = [int(k.item()) for k in all_ks]

for i_a, activation in enumerate(all_activations):
    avg_cpu_time = cpu_times[:, i_a, :].mean(dim=-1)
    label = f"{activation} (avg time: {avg_cpu_time.mean():.2f}s)"

    # Sort over seeds so that the first n_best columns are the best runs.
    sorted_l2_errors = torch.log10(l2_errors[:, i_a, :].sort(dim=-1).values)
    sorted_linf_errors = torch.log10(linf_errors[:, i_a, :].sort(dim=-1).values)

    l2_mean = sorted_l2_errors[:, :n_best].mean(dim=-1)
    l2_std = sorted_l2_errors[:, :n_best].std(dim=-1)
    linf_mean = sorted_linf_errors[:, :n_best].mean(dim=-1)
    linf_std = sorted_linf_errors[:, :n_best].std(dim=-1)

    ax[0].plot(ks_, l2_mean, label=label, marker="x")
    ax[0].fill_between(ks_, l2_mean - l2_std, l2_mean + l2_std, alpha=0.2)

    ax[1].plot(ks_, linf_mean, label=label, marker="+")
    ax[1].fill_between(ks_, linf_mean - linf_std, linf_mean + linf_std, alpha=0.2)

ax[0].set_xlabel("k")
ax[0].set_xticks(ks_, labels=k_labels)
ax[0].set_ylabel("L2 error")
ax[0].set_title("L2 error vs k for different activations")
ax[0].legend()

ax[1].set_xlabel("k")
ax[1].set_xticks(ks_, labels=k_labels)
ax[1].set_ylabel("Linf error")
ax[1].set_title("Linf error vs k for different activations")
ax[1].legend()

plt.suptitle(
    "Comparison of activations for 1D Helmholtz PDE: "
    f"taking the best {n_best} out of {n_seeds} seeds"
)
plt.tight_layout()
plt.show()

# One panel per value of k, at most three panels per row.
n_panels = len(all_ks)
n_cols = max(1, min(3, n_panels))
n_rows = max(1, -(-n_panels // n_cols))  # ceiling division
fig, ax = plt.subplots(
    n_rows, n_cols, figsize=(6 * n_cols, 10 * n_rows / 3), squeeze=False
)
# Take the epoch axis from the stored data so it also fits loaded histories.
epochs = torch.arange(loss_histories.shape[-1])

for i_k, k in enumerate(all_ks):
    panel = ax[divmod(i_k, n_cols)]
    for i_a, activation in enumerate(all_activations):
        log_loss_history = torch.log10(loss_histories[i_k, i_a])
        # Statistics over the seeds (first dimension of the slice).
        mean = log_loss_history.mean(dim=0)
        std = log_loss_history.std(dim=0)
        mps = torch.clip(mean + std, min=-15, max=10)
        mms = torch.clip(mean - std, min=-15, max=10)
        mean = torch.clip(mean, max=5)
        panel.plot(mean, label=activation)
        panel.fill_between(epochs, mms, mps, alpha=0.2)
    panel.set_title(f"k={int(k.item())}")
    panel.set_xlabel("Epoch")
    panel.set_ylabel("Log10 Loss")
    panel.legend()

# Hide the panels of an incomplete last row.
for unused in ax.flat[n_panels:]:
    unused.set_visible(False)

plt.tight_layout()
plt.show()

# %%