r"""Solve a damped harmonic oscillator ODE with a PINN.

.. math::

    \frac{dx}{dt} = - \mu x - y, \\
    \frac{dy}{dt} = x - \mu y,

where :math:`(x, y): (0, T) \times (\mu_{\min}, \mu_{\max}) \to \mathbb{R}^2`
is the unknown function, with :math:`(0, T) \subset \mathbb{R}` the time domain
and :math:`(\mu_{\min}, \mu_{\max}) \subset \mathbb{R}` the parameter domain.
The initial condition is :math:`(x(0, \mu), y(0, \mu)) = (1, 0)` for all
:math:`\mu \in (\mu_{\min}, \mu_{\max})`.

The exact solution is

.. math::

    x(t, \mu) = e^{-\mu t} \cos(t), \\
    y(t, \mu) = e^{-\mu t} \sin(t).

Two training strategies are compared: energy natural gradient (ENG) with a
weakly imposed initial condition (a weighted initial-condition loss term) and
ENG with a strongly imposed initial condition (built into the network output
through a post-processing step).
"""

# %%

import matplotlib.pyplot as plt
import torch

from scimba_torch.approximation_space.nn_space import NNtSpace
from scimba_torch.integration.monte_carlo import TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.integration.monte_carlo_time import UniformTimeSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import (
    GenericMLP,
)
from scimba_torch.numerical_solvers.ode.pinns import (
    NaturalGradientODEPinns,
)
from scimba_torch.physical_models.ode.damped_harmonic_oscillator import (
    DampedHarmonicOscillator,
)
from scimba_torch.utils.scimba_tensors import LabelTensor

# %%

N_EQUATIONS = 2
N_PARAMETERS = 1
VAR_NAMES = ["x", "y"]


def exact_sol(t, mu):
    """Evaluate the exact solution on plain tensors.

    Args:
        t: Time values.
        mu: Damping parameter values, broadcastable against ``t``.

    Returns:
        ``x`` and ``y`` stacked along a new trailing dimension.
    """
    # Both components share the same exponential decay factor.
    damping = torch.exp(-mu * t)
    return torch.stack([damping * torch.cos(t), damping * torch.sin(t)], dim=-1)


def f_ini(mu):
    """Return the initial condition ``(x, y) = (1, 0)`` once per row of ``mu``.

    Args:
        mu: Parameter batch; only its leading size ``mu.shape[0]`` is used.

    Returns:
        An expanded (non-copied) view of shape ``(mu.shape[0], 2)``.
    """
    n_rows = mu.shape[0]
    return torch.tensor([[1.0, 0.0]]).expand(n_rows, -1)


def exact(t, mu):
    """Evaluate the exact solution on ``LabelTensor`` inputs.

    Args:
        t: Time coordinates (a ``LabelTensor``).
        mu: Parameter values (a ``LabelTensor``).

    Returns:
        The ``x`` and ``y`` components concatenated along the last dimension.
    """
    t_, mu_ = t.get_components(), mu.get_components()
    damping = torch.exp(-mu_ * t_)
    return torch.cat([damping * torch.cos(t_), damping * torch.sin(t_)], dim=-1)


t_min, t_max = 0.0, 3 * torch.pi
mu_min, mu_max = 0.0, 0.5

# Joint sampler over the time interval and the parameter interval.
sampler = TensorizedSampler(
    [
        UniformTimeSampler((t_min, t_max)),
        UniformParametricSampler([(mu_min, mu_max)]),
    ]
)
def post_processing(
    inputs: torch.Tensor, t: LabelTensor, mu: LabelTensor
) -> torch.Tensor:
    """Impose the initial condition strongly on the network output.

    Computes ``f_ini(mu) + t * inputs``. At ``t = 0`` this equals ``f_ini(mu)``
    whatever the raw network output is.

    Args:
        inputs: Raw network output.
        t: Time coordinates.
        mu: Parameter values.

    Returns:
        The post-processed approximation.
    """
    t_ = t.get_components()
    return f_ini(mu) + t_ * inputs


def functional_post_processing(
    func, t: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor
) -> torch.Tensor:
    """Functional counterpart of :func:`post_processing`.

    Computes ``f_ini(mu) + t * func(t, mu, theta)``. ``theta`` is forwarded
    to ``func`` unchanged.

    NOTE(review): the ``t[0]`` indexing and ``.squeeze()`` suggest this is
    called on a single sample point (1-D ``t`` and ``mu``), presumably by the
    natural-gradient solver; TODO confirm.
    """
    return f_ini(mu).squeeze() + t[0] * func(t, mu, theta)


def _make_space(**kwargs) -> NNtSpace:
    """Build the NNtSpace shared by both models; extra kwargs are forwarded."""
    return NNtSpace(
        N_EQUATIONS,
        N_PARAMETERS,
        GenericMLP,
        sampler,
        layer_sizes=[10, 10],
        **kwargs,
    )


def create_ode():
    """Create the oscillator model with a plain network (weak IC setup)."""
    return DampedHarmonicOscillator(_make_space(), init=f_ini)


def create_ode_pp():
    """Create the oscillator model whose network output is post-processed.

    The post-processing imposes the initial condition strongly.
    """
    return DampedHarmonicOscillator(
        _make_space(post_processing=post_processing), init=f_ini
    )


def plot_pinn(pinn, title: str, mu_val: float = (mu_min + mu_max) / 2):
    """Plot the results of the PINN against the exact solution.

    The model is evaluated on 1000 uniformly spaced times in
    ``[t_min, t_max]`` at the fixed parameter value ``mu_val``. The figure
    layout is:

    - column 0, row 0: loss history;
    - column 1, row i: approximation vs. exact solution of component i;
    - column 2, row i: pointwise error of component i, titled with that
      component's root-mean-square error.

    Args:
        pinn: Trained solver exposing ``evaluate`` and ``losses``.
        title: Title of the approximation subplots.
        mu_val: Parameter value at which the solution is plotted.
    """
    with torch.no_grad():
        _, ax = plt.subplots(N_EQUATIONS, 3, figsize=(15, 5 + 2 * N_EQUATIONS))

        t = LabelTensor(torch.linspace(t_min, t_max, 1000)[:, None])
        mu = LabelTensor(mu_val * torch.ones_like(t.x))
        u = pinn.evaluate(t, mu).w
        u_exact = exact(t, mu)
        error = u - u_exact

        pinn.losses.plot(ax[0, 0])

        for i in range(N_EQUATIONS):
            name = VAR_NAMES[i]
            component_error = error[:, i]

            ax[i, 1].plot(t.x, u[:, i], label=f"approximation, {name}(t)")
            ax[i, 1].plot(t.x, u_exact[:, i], label=f"exact, {name}(t)")
            ax[i, 1].set_title(title)
            ax[i, 1].legend()

            ax[i, 2].plot(t.x, component_error, label=f"error, {name}(t)")
            # Report the RMS error of *this* component. Previously every row
            # showed the error over all components.
            component_rms = torch.sqrt(torch.mean(component_error**2))
            ax[i, 2].set_title(f"Error, L2 = {component_rms:.2e}")
            ax[i, 2].legend()

        plt.show()


# %%

pinn_eng_weak = NaturalGradientODEPinns(
    create_ode(),
    ic_type="weak",
    ic_weight=250,
    matrix_regularization=1e-4,
    one_loss_by_equation=True,
)

# NOTE(review): with resume_solve = True the `or` short-circuits. `load` is
# never attempted, so the model is always trained from scratch and re-saved.
# Setting it to False reuses a saved model when `load` succeeds (presumably
# `load` returns a falsy value when nothing was saved).
resume_solve = True
if resume_solve or not pinn_eng_weak.load(__file__, "simple_ode_eng"):
    pinn_eng_weak.solve(epochs=200, n_collocation=2000, n_ic_collocation=500)
    pinn_eng_weak.save(__file__, "simple_ode_eng")

plot_pinn(pinn_eng_weak, "ENG with weak IC")

# %%

pinn_eng_strong = NaturalGradientODEPinns(
    create_ode_pp(),
    matrix_regularization=1e-4,
    functional_post_processing=functional_post_processing,
    one_loss_by_equation=True,
)

# Same save/load switch as above (see the note on the weak-IC run).
resume_solve = True
if resume_solve or not pinn_eng_strong.load(__file__, "simple_ode_eng_strong"):
    pinn_eng_strong.solve(epochs=200, n_collocation=1000)
    pinn_eng_strong.save(__file__, "simple_ode_eng_strong")

plot_pinn(pinn_eng_strong, "ENG with strong IC")

# %%