"""Learns a low- (or high-) rank function using a spectral approximation."""

# %%

import matplotlib.pyplot as plt
import torch

from scimba_torch.approximation_space.nn_space import NNxSpace, SeparatedNNxSpace
from scimba_torch.domain.meshless_domain.domain_2d import Square2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.numerical_solvers.collocation_projector import (
    NaturalGradientProjector,
)
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces
from scimba_torch.utils.scimba_tensors import LabelTensor

torch.manual_seed(42)

# Widths of the Gaussian envelopes in the angular and radial directions.
SIGMA_THETA = 0.53
SIGMA_R = 0.008


def _perturbed_polar(x: LabelTensor):
    """Return ``(r, theta)``: the perturbed radius and polar angle of ``x``.

    ``theta`` uses ``torch.atan2`` rather than ``atan(x2 / x1)``: the two agree
    on the first quadrant wherever ``x1 > 0`` or ``x2 > 0``, but ``atan2`` does
    not divide by zero and returns 0 instead of NaN at the origin, which is a
    corner of the unit-square domain used below.
    """
    x1, x2 = x.get_components()
    r = torch.sqrt(x1**2 + x2**2)
    theta = torch.atan2(x2, x1)
    # Angular wobble of the radius; the target and the features share it.
    r = r + 0.05 * torch.cos(4 * torch.pi * theta)
    return r, theta


def _rank_1_term(r, theta):
    """Product of a function of ``theta`` and a function of ``r``."""
    return (
        torch.sin(theta)
        * torch.cos(theta)
        * torch.exp(-((theta - torch.pi / 4) ** 2) / (2 * SIGMA_THETA**2.0))
        * torch.exp(-((r - 0.5) ** 2) / (2 * SIGMA_R**2.0))
    )


def _second_term(r, theta):
    """Second separable term, with wider Gaussian envelopes than the first."""
    return (
        torch.sin(1.5 * theta)
        * torch.cos(0.5 * theta)
        * torch.exp(-0.5 * ((theta - torch.pi / 4) ** 2) / (2 * SIGMA_THETA**2.0))
        * torch.exp(-1.25 * ((r - 0.5) ** 2) / (2 * SIGMA_R**2.0))
    )


def func_test_rank_1(x: LabelTensor, mu: LabelTensor):
    """Target made of a single term separable in (perturbed r, theta).

    Args:
        x: spatial points; ``x.get_components()`` must yield ``(x1, x2)``.
        mu: parameters; unused, kept for the projector's ``f(x, mu)`` signature.
    """
    r, theta = _perturbed_polar(x)
    return _rank_1_term(r, theta)


def func_test(x: LabelTensor, mu: LabelTensor):
    """Target made of two terms, each separable in (perturbed r, theta).

    Args:
        x: spatial points; ``x.get_components()`` must yield ``(x1, x2)``.
        mu: parameters; unused, kept for the projector's ``f(x, mu)`` signature.
    """
    r, theta = _perturbed_polar(x)
    return _rank_1_term(r, theta) + 1.1 * _second_term(r, theta)


def pre_processing(x: LabelTensor, mu: LabelTensor):
    """Map points to the feature pair ``(perturbed r, theta)``.

    The two features are concatenated along ``dim=1``; the network then sees
    the same coordinates in which the targets above are separable.
    """
    r, theta = _perturbed_polar(x)
    return torch.cat([r, theta], dim=1)


def _project(space):
    """Project ``func_test`` onto ``space`` and return the solved projector.

    Uses a NaturalGradientProjector with a logarithmic-grid line search; the
    settings are shared by all three experiments below.
    """
    projector = NaturalGradientProjector(
        space,
        func_test,
        type_linesearch="logarithmic_grid",
        data_linesearch={"M": 10, "interval": [0.0, 2.0]},
    )
    projector.solve(epochs=200, n_collocation=10000, verbose=True)
    return projector


domain_x = Square2D([(0, 1), (0, 1)], is_main_domain=True)
sampler = TensorizedSampler([DomainSampler(domain_x), UniformParametricSampler([])])

# %%

space1 = NNxSpace(1, 0, GenericMLP, domain_x, sampler, layer_sizes=[18, 18])
print(space1.ndof)
p1 = _project(space1)

# %%

space2 = NNxSpace(
    1,
    0,
    GenericMLP,
    domain_x,
    sampler,
    layer_sizes=[18, 18],
    pre_processing=pre_processing,
)
print(space2.ndof)
p2 = _project(space2)

# %%

space3 = SeparatedNNxSpace(
    1,
    0,
    2,
    GenericMLP,
    domain_x,
    sampler,
    layer_sizes=[12, 12],
    pre_processing=pre_processing,
)
print(space3.ndof)
p3 = _project(space3)

# %%

plot_abstract_approx_spaces(
    (p1.space, p2.space, p3.space),  # the approximation spaces
    (domain_x,),  # the spatial domain
    loss=(p1.losses, p2.losses, p3.losses),  # for plot of the loss: the losses
    solution=(func_test,),  # for plot of the exact solution: the solution
    error=(func_test,),  # for plot of the error with respect to a func: the func
    derivatives=(["ux"],),
    draw_contours=True,
    n_drawn_contours=20,
    title="Projecting with Energy Natural Gradient and logarithmic linesearch",
    titles=(
        "with NNxSpace",
        "with NNxSpace and preprocessing",
        "with SeparatedNNxSpace and preprocessing",
    ),
)

plt.show()

# %%