"""Learn a regularized SDF network of a sphere from a file of points.

This example samples points and normals on the unit sphere, writes them to a
file, trains a regularized SDF network from that file, plots the result, and
removes the file afterwards.
"""

from pathlib import Path

import matplotlib.pyplot as plt
import torch

from scimba_torch.geometry.parametric_hypersurface import ParametricHyperSurface
from scimba_torch.geometry.regularized_sdf_projectors import (
    learn_regularized_sdf,
)
from scimba_torch.geometry.utils import write_points_normals_to_file
from scimba_torch.plots.plot_regularized_sdf_projector import (
    plot_regularized_sdf_projector,
)
from scimba_torch.utils import Mapping

torch.manual_seed(0)


# define the 3d sphere parametrization
def sphere_3d_function(t: torch.Tensor) -> torch.Tensor:
    r"""Map :math:`(0, 1) \times (0, 2\pi)` onto the unit sphere at the origin.

    The first parameter ``u = t[..., 0]`` is turned into the polar angle
    :math:`\theta = \arccos(1 - 2u)` (measured from the x axis). Since
    :math:`\cos\theta = 1 - 2u`, a uniformly distributed ``u`` yields points
    uniformly distributed in area on the sphere. The second parameter
    ``t[..., 1]`` is the azimuthal angle around the x axis.

    Args:
        t: Parameter tensor whose last dimension holds ``(u, phi)``.

    Returns:
        The points on the sphere, with the same leading dimensions as ``t``
        and a last dimension of size 3. The mapping is non-invertible.
    """
    # Fixed center and radius; the center is created on t's device and dtype
    # so the mapping also works when t does not live on the default CPU device.
    center = t.new_tensor([0.0, 0.0, 0.0])
    radius = 1.0

    theta = torch.acos(1 - 2 * t[..., 0])
    phi = t[..., 1]
    sin_theta = torch.sin(theta)

    unit_points = torch.stack(
        [
            torch.cos(theta),
            sin_theta * torch.cos(phi),
            sin_theta * torch.sin(phi),
        ],
        dim=-1,
    )
    return center + unit_points * radius


def generate_sphere_filepoints(n: int, filename: str) -> None:
    """Sample ``n`` points with normals on the sphere and write them to a file.

    Args:
        n: The number of points (and normals) to sample.
        filename: Path of the file the points and normals are written to.
    """
    # create the sphere mapping from R^2 to R^3
    sphere_3d_mapping = Mapping(2, 3, sphere_3d_function)
    # create a parametric hypersurface over the parameter domain
    sphere_3d = ParametricHyperSurface(
        [(0.0, 1.0), (0.0, 2 * torch.pi)], sphere_3d_mapping
    )
    # generate a tuple of n points, n normals
    points, normals = sphere_3d.sample(n)
    # write the samples to the file
    write_points_normals_to_file(points, normals, filename)


filename = "sphere.xyz"
filepath = Path(filename)

# start from a clean state: remove a stale points file if it exists
if filepath.is_file():
    filepath.unlink()

try:
    generate_sphere_filepoints(2000, filename)

    # Create and train the network from the file of points
    ginn = learn_regularized_sdf(
        points_file=filename,
        mode="new",
        epochs=200,
        n_collocation=4000,
        n_bc_collocation=2000,
        verbose=False,
    )

    # plot the result
    plot_regularized_sdf_projector(
        ginn,
        n_visu=64,  # number of points for the visualization
        draw_contours=True,
        n_drawn_contours=20,
    )
    plt.show()
finally:
    # always remove the generated points file, even if training/plotting fails
    if filepath.is_file():
        filepath.unlink()