"""Learns a low- (or high-) rank function using a neural network.""" # %% import matplotlib.pyplot as plt import torch from scimba_torch.approximation_space.spectral_space import SeparatedSpectralxSpace from scimba_torch.domain.meshless_domain.domain_2d import Square2D from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler from scimba_torch.numerical_solvers.collocation_projector import ( NaturalGradientProjector, ) from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces from scimba_torch.utils.scimba_tensors import LabelTensor def func_test_rank_1(x: LabelTensor, mu: LabelTensor): x1, x2 = x.get_components() alpha = mu.get_components() return ( 1.0 * torch.sin(4.7 * torch.pi * x1) * torch.sin(1.2 * torch.pi * x2) * torch.sin(9.1 * torch.pi * alpha) ) def func_test_rank_2(x: LabelTensor, mu: LabelTensor): x1, x2 = x.get_components() alpha = mu.get_components() return ( 1.0 * torch.sin(4.7 * torch.pi * x1) * torch.sin(1.2 * torch.pi * x2) * torch.sin(9.1 * torch.pi * alpha) ) + ( 2.0 * torch.sin(2.4 * torch.pi * x1) * torch.sin(5.6 * torch.pi * x2) * torch.sin(2.3 * torch.pi * alpha) ) func_test = func_test_rank_2 # %% # torch.manual_seed(1) torch.manual_seed(12) domain_x = Square2D([(0.0, 1.0), (0.0, 1.0)], is_main_domain=True) sampler = TensorizedSampler( [DomainSampler(domain_x), UniformParametricSampler([(0.0, 1.0)])] ) space = SeparatedSpectralxSpace( 1, "sine", 6, bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)], integrator=sampler, rank=2, tensor_structure=[[2, 1], [2, 1]], ) p = NaturalGradientProjector( space, func_test, type_linesearch="armijo", data_linesearch={"n_step_max": 20, "alpha": 0.005}, ) # print(p.space.get_dof()) p.solve(epochs=150, n_collocation=5000, verbose=True) # print(p.space.get_dof()) plot_abstract_approx_spaces( (p.space), # the approximation spaces (domain_x), # the spatial domain ([(0.0, 1.0)]), # the parameter's domain loss=(p.losses), # for plot of the loss: the losses solution=(func_test), # for plot of the exact sol: sol error=(func_test), # for plot of the error with respect to a func: the func draw_contours=True, n_drawn_contours=20, parameters_values=[0.25], ) plt.show() # %%