"""An example illustrating rsdf nn to learn a torus from a parametric curve.""" import matplotlib.pyplot as plt import torch from scimba_torch.geometry.parametric_hypersurface import ( ParametricHyperSurface, ) from scimba_torch.geometry.regularized_sdf_projectors import ( learn_regularized_sdf, ) from scimba_torch.plots.plot_regularized_sdf_projector import ( plot_regularized_sdf_projector, ) from scimba_torch.utils import Mapping torch.manual_seed(0) # define the torus function def torus_3d_function(t: torch.Tensor) -> torch.Tensor: r"""Mapping from :math:`(0, 2\pi) \times (0, 2\pi)` to the surface of the torus. Maps :math:`(0, 2\pi) \times (0, 2\pi)` to the surface of a torus of major radius :math:`R` and minor radius :math:`r`. Args: radius: The major radius of the torus. tube_radius: The minor radius of the torus. center: The center of the torus. Returns: The mapping to the surface of the torus (non-invertible). """ center = torch.tensor([0.0, 0.0, 0.0]) radius = 1.0 tube_radius = 0.5 theta, phi = t[..., 0], t[..., 1] return center + torch.stack( [ (radius + tube_radius * torch.sin(theta)) * torch.cos(phi), (radius + tube_radius * torch.sin(theta)) * torch.sin(phi), tube_radius * torch.cos(theta), ], dim=-1, ) # define the torus 3d mapping torus_3d_mapping = Mapping(2, 3, torus_3d_function) # create a parametric hypersurface torus_3d = ParametricHyperSurface( [(0.0, 2 * torch.pi), (0.0, 2 * torch.pi)], torus_3d_mapping ) # a bounding box for the hypersurface torus_3d_bb = [(-2.0, 2.0), (-2.0, 2.0), (-2.0, 2.0)] # # define a pinn to learn the sdf of the sphere and train it ginn = learn_regularized_sdf( parametric_hyper_surface=torus_3d, bounding_domain=torus_3d_bb, preconditioner="ENG", mode="new", load_from="torus_ENG", save_to="torus_ENG", layer_sizes=[20 * 4], epochs=300, n_collocation=4000, n_bc_collocation=2000, verbose=False, ) # plot the result plot_regularized_sdf_projector( ginn, n_visu=64, # number of points for the visualization draw_contours=True, n_drawn_contours=20, tol_implicit_plot=1e-2, ) plt.show()