"""Projection of a 1D time-dependent function using preconditioned methods.

The target function ``exact`` is projected onto two identically configured
neural approximation spaces, once with ``NaturalGradientProjector`` and once
with ``AnagramProjector``.  For each projector a previously saved state is
loaded when ``load`` reports success; otherwise the projector is trained and
its state is saved.  Both results are finally plotted side by side.
"""

# %%
import matplotlib.pyplot as plt
import torch
from scimba_torch.approximation_space.nn_space import NNxtSpace
from scimba_torch.domain.meshless_domain.domain_1d import Segment1D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.integration.monte_carlo_time import UniformTimeSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.collocation_projector import (
    AnagramProjector,
    NaturalGradientProjector,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces

# Horizontal rule framing the section titles printed on the console.
_BANNER_RULE = " ################################################# "


def exact(t, x, mu):
    """Evaluate the target function ``exp(-pi**2 * t / 4) * sin(pi * x)``.

    Args:
        t: Time points; must provide ``get_components()``.
        x: Space points; must provide ``get_components()``.
        mu: Parameters; accepted for the call signature but not used.

    Returns:
        The product of the time-decay factor and the spatial sine profile.
    """
    position = x.get_components()
    time = t.get_components()
    decay = torch.exp(-(time * torch.pi**2) / 4.0)
    return decay * torch.sin(torch.pi * position)


def _print_banner(title_line):
    """Print a blank separator followed by ``title_line`` framed by two rules."""
    for text in ("\n\n", _BANNER_RULE, title_line, _BANNER_RULE):
        print(text)


def _load_or_solve(projector, tag, force):
    """Restore ``projector`` from its saved state, or train it and save it.

    Args:
        projector: Object exposing ``load``, ``solve`` and ``save``.
        tag: Label passed to ``load``/``save`` together with ``__file__``.
        force: When true, skip the load attempt and always train.
    """
    # Short-circuit: ``load`` is only attempted when training is not forced.
    if force or not projector.load(__file__, tag):
        projector.solve(epochs=200, n_collocation=900, verbose=True)
        projector.save(__file__, tag)


# Spatial domain and time interval of the problem.
domain_x = Segment1D((0, 1), is_main_domain=True)
t_min, t_max = 0.0, 1.0

# Joint sampler over time, space and an empty parameter list.
sampler = TensorizedSampler(
    [
        UniformTimeSampler((t_min, t_max)),
        DomainSampler(domain_x),
        UniformParametricSampler([]),
    ]
)

# Set to True to retrain even when a saved state could be loaded.
new_solve = False

_print_banner(" # Energy natural gradient # ")

space2 = NNxtSpace(1, 0, GenericMLP, domain_x, sampler, layer_sizes=[64])
print("ndof", space2.ndof)

# %%
p2 = NaturalGradientProjector(
    space2,
    exact,
    type_linesearch="armijo",
    data_linesearch={"n_step_max": 10, "alpha": 0.01},
)
_load_or_solve(p2, "natural", new_solve)

_print_banner(" # Anagram # ")

space3 = NNxtSpace(1, 0, GenericMLP, domain_x, sampler, layer_sizes=[64])
p3 = AnagramProjector(
    space3,
    exact,
    type_linesearch="armijo",
    data_linesearch={"n_step_max": 10, "alpha": 0.01},
    svd_threshold=0.5e-4,
)
_load_or_solve(p3, "anagram", new_solve)

plot_abstract_approx_spaces(
    (p2.space, p3.space),  # approximation spaces to compare
    domain_x,  # spatial domain
    time_domains=((t_min, t_max),),
    loss=(p2.losses, p3.losses),  # loss histories of both projectors
    solution=exact,  # function plotted as the exact solution
    error=exact,  # function the error is measured against
)

plt.show()