Source code for scimba_torch.geometry.parametric_hypersurface

"""A module for parametric hypersurfaces."""

from __future__ import annotations

from typing import cast

import torch

from scimba_torch.domain.meshless_domain.base import SurfacicDomain, VolumetricDomain
from scimba_torch.domain.meshless_domain.domain_nd import HypercubeND
from scimba_torch.geometry.utils import compute_bounding_box
from scimba_torch.integration.monte_carlo import SurfacicSampler
from scimba_torch.utils.mapping import Mapping
from scimba_torch.utils.typing_protocols import FuncTypeCallable


class ParametricHyperSurface(SurfacicDomain):
    r"""Base class for representing a parametric hypersurface.

    .. math::

        \{ y = \text{surface}(t) \mid t \in D \}

    where :math:`D` is the parametric domain.

    Args:
        parametric_domain: The parametric domain :math:`D`. Either a
            ``VolumetricDomain``, or raw bounds given as a list of
            ``(low, high)`` pairs or as a tensor, which are wrapped into a
            ``HypercubeND``.
        surface: Mapping from the parametric domain into the ambient space.
    """

    def __init__(
        self,
        parametric_domain: VolumetricDomain | list[tuple[float, float]] | torch.Tensor,
        surface: Mapping,
    ):
        # Raw bounds (list or tensor) are wrapped into a HypercubeND so the
        # base class always receives a domain object.
        if isinstance(parametric_domain, (list, torch.Tensor)):
            parametric_domain = HypercubeND(parametric_domain)
        super().__init__("hypersurface", parametric_domain, surface)
        # Created after the base-class initialisation so the sampler is handed
        # a fully initialised domain.
        self.sampler = SurfacicSampler(self)

    def sample(self, n: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Sample points on the hypersurface together with their normals.

        Args:
            n: The number of points to sample.

        Returns:
            A tuple ``(points, normals)`` of tensors.
        """
        res = self.sampler.sample(n, compute_normals=True)
        # The sampler's declared return type is presumably broader; with
        # compute_normals=True it is treated as a (points, normals) pair.
        return cast(tuple[torch.Tensor, torch.Tensor], res)

    def estimate_bounding_box(
        self, nb_samples: int = 2000, inflation: float = 0.1
    ) -> torch.Tensor:
        """Estimate a bounding box for the hypersurface by sampling points on it.

        Args:
            nb_samples: The number of points to sample.
            inflation: The inflation factor used to over-estimate the box.

        Returns:
            A bounding box of shape ``(d, 2)`` containing all the sampled points.
        """
        points, _ = self.sample(nb_samples)
        return compute_bounding_box(points, inflation)

    @staticmethod
    def bean_2d(
        a: int = 3, b: int = 5, theta: float = -torch.pi / 2
    ) -> ParametricHyperSurface:
        r"""Bean 2D curve.

        The unrotated curve is the polar graph

        .. math::

            r(t) = \sin(t)^a + \cos(t)^b, \qquad
            c(t) = r(t)\,(\cos t, \sin t), \qquad t \in [0, 2\pi],

        which is then composed with ``Mapping.rot_2d(theta)``.

        Args:
            a: Exponent applied to :math:`\sin(t)`.
            b: Exponent applied to :math:`\cos(t)`.
            theta: The rotation angle.

        Returns:
            The Bean 2D as a parametric hypersurface.
        """

        def bean_2d_function(t: torch.Tensor) -> torch.Tensor:
            """The bean 2d function, before rotation.

            Args:
                t: The parameter values. Presumably of shape ``(n, 1)``, since
                    the mapping is declared from 1D to 2D (TODO confirm).

            Returns:
                c(t), with both coordinates concatenated along the last dimension.
            """
            sin = torch.sin(t)
            cos = torch.cos(t)
            # Polar radius, computed once and shared by both coordinates.
            radius = sin**a + cos**b
            return torch.cat((radius * cos, radius * sin), dim=-1)

        bean_2d_mapping = Mapping(1, 2, cast(FuncTypeCallable, bean_2d_function))
        bean_2d_mapping = Mapping.compose(bean_2d_mapping, Mapping.rot_2d(theta))
        return ParametricHyperSurface([(0.0, 2 * torch.pi)], bean_2d_mapping)
if __name__ == "__main__":  # pragma: no cover
    from pathlib import Path

    import matplotlib.pyplot as plt

    from scimba_torch.geometry.utils import (
        read_points_normals_from_file,
        write_points_normals_to_file,
    )

    # Sample the bean curve and visualise the points with their normals.
    bean_2d = ParametricHyperSurface.bean_2d()
    points, normals = bean_2d.sample(1000)

    points_np = points.cpu().detach().numpy()
    normals_np = normals.cpu().detach().numpy()

    plt.figure(figsize=(7, 7))
    plt.scatter(points_np[:, 0], points_np[:, 1], s=1, label="BC")
    # Only every 20th normal is drawn to keep the quiver plot readable.
    plt.quiver(
        points_np[::20, 0],
        points_np[::20, 1],
        normals_np[::20, 0],
        normals_np[::20, 1],
        color="red",
        label="normals",
        alpha=0.5,
    )
    plt.legend()
    plt.show()

    # Round-trip the samples through a scratch file and check they survive.
    filename = "test.xy"
    filepath = Path(filename)
    try:
        write_points_normals_to_file(points, normals, filename)
        points2, normals2 = read_points_normals_from_file(filename)

        assert points2.shape == points.shape
        assert normals2.shape == normals.shape
        assert torch.all(points == points2)
        assert torch.all(normals == normals2)
    finally:
        # Remove the scratch file even when a check above fails.
        filepath.unlink(missing_ok=True)