Source code for scimba_torch.utils.scimba_tensors

"""Utility classes for handling tensors with associated labels in PyTorch."""

from __future__ import annotations

from collections.abc import Sequence

import torch


class MultiLabelTensor:
    """A class to manage tensors with space coordinates and associated labels.

    Args:
        w: The main tensor of coordinates, expected to have shape
            `(batch_size, dim)`.
        labels: A list of tensors containing labels for filtering operations.
            Defaults to an empty list.

    Raises:
        ValueError: If the input tensor has dimension <= 1, or if w and
            labels have different shape[0].
    """

    def __init__(
        self,
        w: torch.Tensor,
        labels: list[torch.Tensor] | None = None,
    ):
        self.w: torch.Tensor = w  #: The main tensor representing coordinates
        if w.dim() <= 1:
            raise ValueError(
                "cannot create MultiLabelTensor from tensors of dim <= 1"
            )
        #: Number of dimensions in the coordinates
        self.size: int = w.shape[1]
        #: A list of label tensors, where each tensor contains integer labels
        #: associated with the corresponding batch entries.
        self.labels: list[torch.Tensor] = [] if labels is None else labels
        if not all(label.shape[0] == w.shape[0] for label in self.labels):
            raise ValueError("w and labels must have the same shape[0]")
        #: The shape of the tensor `w`, useful for validation and debugging.
        self.shape: torch.Size = w.shape
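
    # Usage sketch (a minimal illustration; the shapes and label values below
    # are assumptions for the example, not part of the API):
    #
    #     w = torch.rand(8, 2)
    #     labels = [torch.tensor([0, 0, 1, 1, 0, 1, 0, 1])]
    #     mlt = MultiLabelTensor(w, labels)
    #     mlt.size   # 2 (number of coordinate dimensions)
    #     mlt.shape  # torch.Size([8, 2])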

    def get_components(
        self, index: int | None = None
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Retrieve specific components of the tensor `w`.

        Args:
            index: The specific dimension to extract from the tensor. If
                `None`, all dimensions are extracted as a tuple of tensors.

        Returns:
            - If `index` is specified, a single tensor corresponding to the
              selected dimension.
            - If `index` is `None`, a tuple of tensors for all dimensions.
        """
        if self.size == 1:
            return self.w[:, 0, None]
        if index is None:
            # Return all dimensions as a tuple of column tensors
            return tuple(self.w[:, i, None] for i in range(self.size))
        return self.w[:, index, None]
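
    # Usage sketch for `get_components` (continuing the hypothetical `mlt`
    # above; shapes are illustrative):
    #
    #     x, y = mlt.get_components()   # tuple of two (8, 1) column tensors
    #     x = mlt.get_components(0)     # first component only, shape (8, 1)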

    def restrict_to_labels(
        self, component: torch.Tensor | None = None, labels: list[int] | None = None
    ) -> torch.Tensor:
        """Filter tensor `w` (or one of its components) by a list of reference labels.

        Args:
            component: The specific component to be filtered. If `None`,
                `self.w` is used.
            labels: A list of integers specifying the reference labels used to
                filter rows. If `None`, no filtering is applied.

        Returns:
            Filtered tensor based on the following logic:

            - If `component` and `labels` are specified, `component` filtered
              by the input list of labels.
            - If `component` is `None` and `labels` are specified, `self.w`
              filtered by the input list of labels.
            - If `component` is provided and `labels` is `None`, a copy of
              `component`.
            - Otherwise, a copy of `self.w`.

        Raises:
            ValueError: If the provided reference labels do not match the
                structure of the label tensors.
        """
        nlabels = [] if labels is None else labels
        # Validate the input
        if len(self.labels) < len(nlabels):
            raise ValueError(
                "Provided reference labels do not match the structure of the "
                "label tensors."
            )
        # Build a boolean mask; with no reference labels, every row is kept
        batch_size = self.w.shape[0]
        mask = torch.ones(batch_size, dtype=torch.bool)
        for label_tensor, ref_label in zip(self.labels, nlabels):
            mask &= label_tensor == ref_label
        # Filter the tensor `w` (or the given 1D component) based on the mask
        if component is not None:
            return component[mask, 0, None]
        return self.w[mask, :]
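
    # Usage sketch for `restrict_to_labels` (same hypothetical `mlt`; rows
    # are kept where the first label tensor equals 1):
    #
    #     rows = mlt.restrict_to_labels(labels=[1])             # shape (4, 2)
    #     x = mlt.get_components(0)
    #     x1 = mlt.restrict_to_labels(component=x, labels=[1])  # shape (4, 1)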


class LabelTensor:
    """Class for tensors representing space coordinates.

    Args:
        x: Coordinates tensor.
        labels: Labels for the coordinates (e.g. labels for boundary
            conditions, etc.). If None, creates zero labels.

    Raises:
        ValueError: If x has dimension <= 1, or if x and labels have
            different shape[0].
    """

    def __init__(
        self,
        x: torch.Tensor,
        labels: torch.Tensor | None = None,
    ):
        self.x: torch.Tensor = x  #: Coordinate tensor
        if x.dim() <= 1:
            raise ValueError("cannot create LabelTensor from tensors of dim <= 1")
        self.dim: int = x.shape[1]  #: Space dimension
        self.labels: torch.Tensor  #: Labels for the coordinates
        if labels is not None:
            if labels.shape[0] != x.shape[0]:
                raise ValueError(
                    "x and labels must have the same shape[0]: "
                    f"got {x.shape[0]} and {labels.shape[0]}"
                )
            self.labels = labels
        else:
            self.labels = torch.zeros(x.shape[0], dtype=torch.int32)
        self.shape = x.shape

    def __getitem__(self, key: int | slice) -> LabelTensor:
        """Overload the getitem [] operator.

        Args:
            key: Index of the data you want.

        Returns:
            The space tensor with the data only for the key.
        """
        if isinstance(key, int):
            return LabelTensor(self.x[key, None], self.labels[key, None])
        return LabelTensor(self.x[key], self.labels[key])

    def __setitem__(self, key: int | slice, value: LabelTensor) -> None:
        """Overload the setitem [] operator.

        Args:
            key: Index where you want to set the data.
            value: New values for the LabelTensor associated to the given key.
        """
        self.x[key] = value.x
        self.labels[key] = value.labels
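
    # Usage sketch (illustrative shapes and labels, assumed for the example):
    #
    #     x = torch.rand(8, 2)
    #     labels = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1])
    #     lt = LabelTensor(x, labels)
    #     first = lt[0]    # LabelTensor with x of shape (1, 2)
    #     head = lt[:4]    # LabelTensor with the first four rows
    #     lt[:4] = head    # __setitem__ writes x and labels back in place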

    def repeat(self, repeats: int | torch.Tensor) -> LabelTensor:
        """Overload the repeat function.

        Args:
            repeats: The number of repetitions for each row.

        Returns:
            New LabelTensor with repeated coordinates and labels.
        """
        return LabelTensor(
            torch.repeat_interleave(self.x, repeats, dim=0),
            torch.repeat_interleave(self.labels, repeats, dim=0),
        )
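
    # Usage sketch for `repeat` (continuing the hypothetical `lt` above):
    #
    #     lt3 = lt.repeat(3)  # each row and its label repeated 3 times -> 24 rows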

    def get_components(
        self, label: int | None = None
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Return the components of the current LabelTensor.

        Args:
            label: The label selecting the rows of x to return. If None,
                all rows are returned.

        Returns:
            The coordinates. If dim=1, a single tensor; otherwise a tuple of
            tensors, one per dimension.

        Raises:
            ValueError: If no coordinates with the specified label are found.
        """
        if label is None:
            if self.dim == 1:
                return self.x[:, 0, None]
            return tuple(self.x[:, i, None] for i in range(self.dim))
        mask = self.labels == label
        if mask.sum() == 0:
            raise ValueError(f"No coordinates with label {label}")
        if self.dim == 1:
            return self.x[mask, 0, None]
        return tuple(self.x[mask, i, None] for i in range(self.dim))
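
    # Usage sketch for `get_components` (same hypothetical `lt`):
    #
    #     cx, cy = lt.get_components()   # all rows, as two (8, 1) columns
    #     bx, by = lt.get_components(1)  # only rows with label 1 -> (4, 1) each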

    @staticmethod
    def cat(inputs: Sequence[LabelTensor]) -> LabelTensor:
        """Concatenate a list of LabelTensors.

        Args:
            inputs: The list of LabelTensors to concatenate.

        Returns:
            A LabelTensor containing all the input LabelTensors.
        """
        return LabelTensor(
            torch.cat([data.x for data in inputs], dim=0),
            torch.cat([data.labels for data in inputs], dim=0),
        )
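
    # Usage sketch for `cat` (hypothetical `lt` from above):
    #
    #     both = LabelTensor.cat([lt, lt])  # 16 rows, labels concatenated too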

    def __str__(self) -> str:
        """String representation of the LabelTensor.

        Returns:
            A string describing the LabelTensor.
        """
        return f"LabelTensor:\n x = {self.x}\n labels = {self.labels}"

    def detach(self):
        """Detach the space tensor.

        Returns:
            A LabelTensor whose x is detached from the computation graph.
        """
        return LabelTensor(self.x.detach(), self.labels)

    def __add__(self, other: int | float | torch.Tensor | LabelTensor) -> LabelTensor:
        """Overload the + operator.

        Args:
            other: A value or tensor to add.

        Returns:
            The LabelTensor resulting from the addition.

        Raises:
            TypeError: If other is not a valid type.
            ValueError: If labels do not match when adding two LabelTensors.
        """
        if not isinstance(other, (int, float, torch.Tensor, LabelTensor)):
            raise TypeError(
                f"Invalid type {type(other)} for element added to LabelTensor"
            )
        if isinstance(other, LabelTensor):
            if not torch.all(self.labels == other.labels):
                raise ValueError("Labels do not match")
            return LabelTensor(self.x + other.x, self.labels)
        return LabelTensor(self.x + other, self.labels)

    def __sub__(self, other: int | float | torch.Tensor | LabelTensor) -> LabelTensor:
        """Overload the - operator.

        Args:
            other: A value or tensor to subtract.

        Returns:
            The LabelTensor resulting from the subtraction.

        Raises:
            TypeError: If other is not a valid type.
            ValueError: If labels do not match when subtracting two LabelTensors.
        """
        if not isinstance(other, (int, float, torch.Tensor, LabelTensor)):
            raise TypeError(
                f"Invalid type {type(other)} for element subtracted from LabelTensor"
            )
        if isinstance(other, LabelTensor):
            if not torch.all(self.labels == other.labels):
                raise ValueError("Labels do not match")
            return LabelTensor(self.x - other.x, self.labels)
        return self + (-other)

    def __mul__(self, other: int | float | torch.Tensor | LabelTensor) -> LabelTensor:
        """Overload the * operator.

        Args:
            other: A value or tensor to multiply.

        Returns:
            The LabelTensor resulting from the multiplication.

        Raises:
            TypeError: If other is not a valid type.
            ValueError: If labels do not match when multiplying two LabelTensors.
        """
        if not isinstance(other, (int, float, torch.Tensor, LabelTensor)):
            raise TypeError(
                f"Invalid type {type(other)} for element multiplied with LabelTensor"
            )
        if isinstance(other, LabelTensor):
            if not torch.all(self.labels == other.labels):
                raise ValueError("Labels do not match")
            return LabelTensor(self.x * other.x, self.labels)
        return LabelTensor(self.x * other, self.labels)

    def __rmul__(self, other: int | float | torch.Tensor | LabelTensor) -> LabelTensor:
        """Overload the right * operator.

        Args:
            other: A value or tensor to multiply.

        Returns:
            The LabelTensor resulting from the multiplication.
        """
        return self * other
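
    # Usage sketch for the arithmetic operators (same hypothetical `lt`;
    # operands that are both LabelTensors must carry identical labels):
    #
    #     shifted = lt + 1.0   # adds 1 to every coordinate, labels kept
    #     scaled = 2.0 * lt    # __rmul__ delegates to __mul__
    #     diff = shifted - lt  # allowed: both sides have the same labels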

    def no_grad(self) -> LabelTensor:
        """Return a LabelTensor with no grad on x.

        Returns:
            A LabelTensor with x detached from the computation graph.
        """
        # detach() already returns a tensor with requires_grad=False
        x_no_grad = self.x.detach().clone()
        return LabelTensor(x_no_grad, self.labels)

    def unsqueeze(self, dim: int) -> LabelTensor:
        """Unsqueeze the space tensor.

        Args:
            dim: Dimension to unsqueeze.

        Returns:
            A LabelTensor with the specified dimension unsqueezed.
        """
        return LabelTensor(self.x.unsqueeze(dim), self.labels)

    def concatenate(self, other: LabelTensor, dim: int) -> LabelTensor:
        """Concatenate two LabelTensors along a specified dimension.

        Args:
            other: The LabelTensor to concatenate with the current instance.
            dim: The dimension along which to concatenate.

        Returns:
            The LabelTensor which contains the concatenation of the two
            LabelTensors.
        """
        try:
            return LabelTensor(
                torch.cat((self.x, other.x), dim=dim),
                torch.cat((self.labels, other.labels), dim=dim),
            )
        except IndexError:
            # 1D label tensors have no dim=1 to cat along; stack them instead
            return LabelTensor(
                torch.cat((self.x, other.x), dim=dim),
                torch.stack((self.labels, other.labels), dim=dim),
            )
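
    # Usage sketch for `concatenate` (same hypothetical `lt`; dim=1 exercises
    # the stack fallback because the labels are 1D):
    #
    #     tall = lt.concatenate(lt, dim=0)  # x: (16, 2), labels: (16,)
    #     wide = lt.concatenate(lt, dim=1)  # x: (8, 4),  labels stacked to (8, 2)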