"""Functional operators for PDEs with multiple labels."""

from collections import OrderedDict
from typing import Callable, Sequence

import torch

from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    EllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
    LinearOrder2PDE,
)
from scimba_torch.physical_models.kinetic_pde.abstract_kinetic_pde import KineticPDE
from scimba_torch.physical_models.temporal_pde.abstract_temporal_pde import TemporalPDE
from scimba_torch.utils.typing_protocols import VarArgAnyCallable

# for python >= 3.12
# type func_args = VarArgCallable
# type func_func_args = Callable[[func_args, *Any], torch.Tensor]
# for python < 3.12
TYPE_THETA = dict[str, torch.nn.parameter.Parameter]
TYPE_ARGS = torch.Tensor | TYPE_THETA
TYPE_FUNC_ARGS = VarArgAnyCallable
TYPE_FUNC_FUNC_ARGS = Callable[[TYPE_FUNC_ARGS, VarArgAnyCallable], torch.Tensor]

TYPE_KEYS = int | tuple[int, ...]

TYPE_SPLIT = OrderedDict[TYPE_KEYS, tuple[Sequence[torch.Tensor], torch.Tensor]]
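# Illustrative sketch of a TYPE_SPLIT value (contents are hypothetical): each key of a
# splitting scheme maps to the rows of every argument carrying that label, together
# with their original row indexes, e.g.
#
#     OrderedDict(
#         [
#             (0, ([x_label_0, mu_label_0], torch.tensor([0, 2]))),
#             ((1, 2), ([x_label_1_or_2, mu_label_1_or_2], torch.tensor([1, 3]))),
#         ]
#     )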

TYPE_DICT_OF_FUNC_ARGS = OrderedDict[TYPE_KEYS, TYPE_FUNC_ARGS]
TYPE_DICT_OF_FUNC_FUNC_ARGS = OrderedDict[TYPE_KEYS, TYPE_FUNC_FUNC_ARGS]

TYPE_VMAPS = VarArgAnyCallable
TYPE_DICT_OF_VMAPS = OrderedDict[TYPE_KEYS, TYPE_VMAPS]


# assume all args and labels have the same length
def split_label_tensors_to_dict_of_tensors_with_indexes(
    splitting_scheme: Sequence[TYPE_KEYS],
    labels: torch.Tensor,
    *args: torch.Tensor,
) -> TYPE_SPLIT:
    """Split LabelTensors to a dict of tensors with indexes.

    Args:
        splitting_scheme: Sequence of keys to split the labels.
        labels: Tensor of labels to be split.
        *args: Tensors to be split according to the labels.

    Returns:
        A dictionary where each key from the splitting scheme maps to a tuple.
        The first element of the tuple is a list of tensors corresponding to
        the key, and the second element is a tensor of indexes.
    """
    tensor_of_indexes = torch.arange(0, len(labels), dtype=torch.long)
    masks = tuple(
        (
            torch.isin(labels, torch.tensor(el, dtype=torch.long))
            if isinstance(el, tuple)
            else labels == el
        )
        for el in splitting_scheme
    )
    split_indexes = tuple(tensor_of_indexes[mask] for mask in masks)
    res: TYPE_SPLIT = OrderedDict()
    for i, el in enumerate(splitting_scheme):
        res[el] = (
            [torch.index_select(arg, 0, split_indexes[i]) for arg in args],
            split_indexes[i],
        )
    return res

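
# A minimal illustrative sketch (the tensors below are hypothetical): splitting two
# feature tensors according to a label tensor with values 0 and 1.
def _example_split_by_label() -> TYPE_SPLIT:
    labels = torch.tensor([0, 1, 0, 1], dtype=torch.long)
    x = torch.rand(4, 3)
    mu = torch.rand(4, 2)
    parts = split_label_tensors_to_dict_of_tensors_with_indexes((0, 1), labels, x, mu)
    # parts[0] -> ([x[[0, 2]], mu[[0, 2]]], tensor([0, 2]))
    # parts[1] -> ([x[[1, 3]], mu[[1, 3]]], tensor([1, 3]))
    return parts
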
# assume all tensors have the same shape except along dimension dim
def cat_dict_of_tensors_with_indexes(input: TYPE_SPLIT) -> list[torch.Tensor]:
    """Concatenate a dict of tensors with indexes.

    Args:
        input: A dictionary where each key maps to a tuple containing a list of
            tensors and a tensor of indexes.

    Returns:
        A list of concatenated tensors.
    """
    dim = 0
    # total number of rows along dim accumulated over all keys
    ndim = sum(input[key][0][0].shape[dim] for key in input)
    fkey = list(input.keys())[0]
    nshapes = [(ndim,) + t.shape[1:] for t in input[fkey][0]]
    res = [torch.zeros(nshape, dtype=torch.get_default_dtype()) for nshape in nshapes]
    for key in input:
        for i, t in enumerate(input[key][0]):
            # scatter each block back to its original row positions
            res[i][input[key][1], ...] = t.transpose(0, dim)
    return res

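
# A sketch of the inverse operation (reusing the hypothetical split above):
# re-assembling the per-label pieces back into their original row order.
def _example_round_trip() -> list[torch.Tensor]:
    parts = _example_split_by_label()
    merged = cat_dict_of_tensors_with_indexes(parts)
    # merged[0] has shape (4, 3) and merged[1] has shape (4, 2): the rows of the
    # original x and mu, scattered back to their original positions.
    return merged
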
def compose_funcs(func1: TYPE_FUNC_FUNC_ARGS, func2: TYPE_FUNC_ARGS) -> TYPE_FUNC_ARGS:
    """Compose two functions.

    Args:
        func1: A function that takes another function and additional arguments.
        func2: A function to be passed as an argument to func1.

    Returns:
        A new function that represents the composition of func1 and func2.
    """

    def composed_func(*args):
        return func1(func2, *args)

    return composed_func

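
# A small sketch of compose_funcs (hypothetical names): func1 consumes a function plus
# the remaining arguments, func2 is the function being plugged in.
def _example_compose_funcs() -> TYPE_FUNC_ARGS:
    def residual(u, x):  # operator acting on a function u at points x
        return u(x) ** 2

    def candidate(x):  # the function being plugged in
        return 2.0 * x

    squared = compose_funcs(residual, candidate)
    # squared(x) == residual(candidate, x) == (2.0 * x) ** 2
    return squared
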
def apply_dict_of_func_to_func(
    dict_of_func: TYPE_DICT_OF_FUNC_FUNC_ARGS, func: TYPE_FUNC_ARGS
) -> TYPE_DICT_OF_FUNC_ARGS:
    """Apply a dict of functions to a function.

    Args:
        dict_of_func: A dictionary where each key maps to a function that takes
            another function and additional arguments.
        func: A function to be passed as an argument to each function in the
            dictionary.

    Returns:
        A dictionary where each key maps to the result of composing the
        corresponding function in dict_of_func with func.
    """
    res: TYPE_DICT_OF_FUNC_ARGS = OrderedDict()
    for key in dict_of_func:
        res[key] = compose_funcs(dict_of_func[key], func)
    return res

def apply_func_to_dict_of_func(
    func: TYPE_FUNC_FUNC_ARGS, dict_of_func: TYPE_DICT_OF_FUNC_ARGS
) -> TYPE_DICT_OF_FUNC_ARGS:
    """Apply a function to a dict of functions.

    Args:
        func: A function that takes another function and additional arguments.
        dict_of_func: A dictionary where each key maps to a function to be passed
            as an argument to func.

    Returns:
        A dictionary where each key maps to the result of composing func with the
        corresponding function in dict_of_func.
    """
    res: TYPE_DICT_OF_FUNC_ARGS = OrderedDict()
    for key in dict_of_func:
        res[key] = compose_funcs(func, dict_of_func[key])
    return res

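
# Illustrative sketch contrasting the two helpers above (all names are hypothetical):
# in apply_dict_of_func_to_func every dictionary entry acts as the outer function,
# while in apply_func_to_dict_of_func the single function is the outer one.
def _example_apply_helpers() -> None:
    def op_square(f, x):  # an operator consuming a function and a point
        return f(x) ** 2

    def op_negate(f, x):
        return -f(x)

    def u(x):  # the function being plugged in
        return 2.0 * x

    ops = OrderedDict([(0, op_square), (1, op_negate)])
    bound = apply_dict_of_func_to_func(ops, u)
    # bound[0](x) == op_square(u, x); bound[1](x) == op_negate(u, x)

    fns = OrderedDict([(0, u), (1, lambda x: x + 1.0)])
    wrapped = apply_func_to_dict_of_func(op_square, fns)
    # wrapped[k](x) == op_square(fns[k], x)
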
def vectorize_dict_of_func(
    vectorizing_scheme: tuple[int | None, ...],
    dict_of_func: TYPE_DICT_OF_FUNC_ARGS,
) -> TYPE_DICT_OF_VMAPS:
    """Vectorize a dict of functions.

    Args:
        vectorizing_scheme: A tuple specifying the vectorization scheme, forwarded
            to torch.func.vmap as its in_dims argument.
        dict_of_func: A dictionary where each key maps to a function to be
            vectorized.

    Returns:
        A dictionary where each key maps to the vectorized version of the
        corresponding function in dict_of_func.
    """
    res: TYPE_DICT_OF_VMAPS = OrderedDict()
    for key in dict_of_func:
        res[key] = torch.func.vmap(dict_of_func[key], vectorizing_scheme)
    return res

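
# A sketch of vectorization (hypothetical pointwise function): the vectorizing scheme
# is forwarded to torch.func.vmap as in_dims, so None marks arguments (such as the
# parameter dict theta) that should not be batched.
def _example_vectorize() -> TYPE_DICT_OF_VMAPS:
    pointwise = OrderedDict([(0, lambda x, theta: x @ theta["w"])])
    batched = vectorize_dict_of_func((0, None), pointwise)
    # batched[0](x_batch, theta) applies the function row by row over x_batch
    return batched
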
def apply_dict_of_vmap_to_label_tensors(
    splitting_scheme: Sequence[TYPE_KEYS],
    vmdict: TYPE_DICT_OF_VMAPS,
    theta: TYPE_THETA,
    labels: torch.Tensor,
    *args: torch.Tensor,
) -> torch.Tensor:
    """Apply a dict of vmaps to LabelTensors.

    Args:
        splitting_scheme: Sequence of keys to split the labels.
        vmdict: A dictionary where each key maps to a vectorized function.
        theta: A dictionary of parameters to be passed to the functions.
        labels: Tensor of labels to be split.
        *args: Tensors to be passed as arguments to the functions.

    Returns:
        The result of applying the appropriate vectorized function to the split
        tensors.
    """
    if len(vmdict) == 1:
        # single operator: no splitting needed, apply it to the full batch
        for key in vmdict:
            return vmdict[key](*args, theta)
    mdict: TYPE_SPLIT = split_label_tensors_to_dict_of_tensors_with_indexes(
        splitting_scheme, labels, *args
    )
    evals: TYPE_SPLIT = OrderedDict()
    for key in mdict:
        evals[key] = (
            [vmdict[key](*mdict[key][0], theta)],
            mdict[key][1],
        )
    cat_eval = cat_dict_of_tensors_with_indexes(evals)
    return cat_eval[0]

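
# A sketch of dispatch by label (hypothetical operators): each row of the inputs is
# routed to the vmap associated with its label, and the per-label outputs are
# scattered back into the original row order.
def _example_dispatch_by_label() -> torch.Tensor:
    theta = {"w": torch.ones(3)}  # stand-in for a dict of nn.Parameter objects
    vmdict = OrderedDict(
        [
            (0, torch.func.vmap(lambda x, th: x @ th["w"], (0, None))),
            (1, torch.func.vmap(lambda x, th: -(x @ th["w"]), (0, None))),
        ]
    )
    labels = torch.tensor([0, 1, 0, 1], dtype=torch.long)
    x = torch.rand(4, 3)
    out = apply_dict_of_vmap_to_label_tensors((0, 1), vmdict, theta, labels, x)
    # out[i] comes from the label-0 operator when labels[i] == 0, from the
    # label-1 operator otherwise
    return out
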
class FunctionalOperator:
    """Handle functional operators for PDEs with multiple labels.

    Args:
        pde: An instance of a PDE class.
        name: The name of the operator method in the PDE class.
        **kwargs: Additional keyword arguments.

    Raises:
        AttributeError: If the specified operator method does not exist in the
            PDE class or is not callable.
    """

    def __init__(
        self,
        pde: EllipticPDE | TemporalPDE | KineticPDE | LinearOrder2PDE,
        name: str,
        **kwargs,
    ):
        if not hasattr(pde, name):
            raise AttributeError("input PDE must have an attribute %s" % name)
        self.operator = getattr(pde, name)
        if not callable(self.operator):
            raise AttributeError("attribute %s of input PDE must be a method" % name)
        self.dict_of_operators: TYPE_DICT_OF_FUNC_FUNC_ARGS = OrderedDict()
        try:
            self.dict_of_operators = self.operator()
        except TypeError:
            # if a single function is given, the key does not matter:
            # the same operator is applied whatever the label
            self.dict_of_operators = OrderedDict([(0, self.operator)])
        self.keys = [key for key in self.dict_of_operators]

    def split_label_tensors(
        self, labels: torch.Tensor, *args: torch.Tensor
    ) -> TYPE_SPLIT:
        """Split LabelTensors to a dict of tensors with indexes.

        Args:
            labels: Tensor of labels to be split.
            *args: Tensors to be split according to the labels.

        Returns:
            A dictionary where each key from the operator keys maps to a tuple.
            The first element of the tuple is a list of tensors corresponding to
            the key, and the second element is a tensor of indexes.
        """
        return split_label_tensors_to_dict_of_tensors_with_indexes(
            self.keys, labels, *args
        )

    def cat_dict_of_tensors(self, input: TYPE_SPLIT) -> list[torch.Tensor]:
        """Concatenate a dict of tensors with indexes.

        Args:
            input: A dictionary where each key maps to a tuple containing a list
                of tensors and a tensor of indexes.

        Returns:
            A list of concatenated tensors.
        """
        return cat_dict_of_tensors_with_indexes(input)

    def apply_dict_of_vmap_to_label_tensors(
        self,
        vmdict: TYPE_DICT_OF_VMAPS,
        theta: TYPE_THETA,
        labels: torch.Tensor,
        *args: torch.Tensor,
    ) -> torch.Tensor:
        """Apply a dict of vmaps to LabelTensors.

        Args:
            vmdict: A dictionary where each key maps to a vectorized function.
            theta: A dictionary of parameters to be passed to the functions.
            labels: Tensor of labels to be split.
            *args: Tensors to be passed as arguments to the functions.

        Returns:
            The result of applying the appropriate vectorized function to the
            split tensors.
        """
        return apply_dict_of_vmap_to_label_tensors(
            self.keys, vmdict, theta, labels, *args
        )

    def apply_to_func(self, func: TYPE_FUNC_ARGS) -> TYPE_DICT_OF_FUNC_ARGS:
        """Apply the functional operator to a function.

        Args:
            func: A function to be passed as an argument to each function in the
                operator dictionary.

        Returns:
            A dictionary where each key maps to the result of composing the
            corresponding function in the operator dictionary with func.
        """
        return apply_dict_of_func_to_func(self.dict_of_operators, func)

    def apply_func_to_dict_of_func(
        self, func: TYPE_FUNC_FUNC_ARGS, dict_of_func: TYPE_DICT_OF_FUNC_ARGS
    ) -> TYPE_DICT_OF_FUNC_ARGS:
        """Apply a function to a dict of functions.

        Args:
            func: A function that takes another function and additional arguments.
            dict_of_func: A dictionary where each key maps to a function to be
                passed as an argument to func.

        Returns:
            A dictionary where each key maps to the result of composing func with
            the corresponding function in dict_of_func.
        """
        return apply_func_to_dict_of_func(func, dict_of_func)

    def cat_tuple_of_tensors(
        self,
        input: tuple[torch.Tensor, ...],
        labels: torch.Tensor,
        *args: torch.Tensor,
    ) -> torch.Tensor:
        """Concatenate a tuple of tensors with indexes.

        Args:
            input: A tuple of tensors to be concatenated.
            labels: Tensor of labels to be split.
            *args: Tensors to be split according to the labels.

        Returns:
            A tensor resulting from concatenating the input tensors according to
            the split indexes derived from the labels.
        """
        mdict = self.split_label_tensors(labels, *args)
        for i in range(len(input)):
            key = self.keys[i]
            mdict[key] = ([input[i]], mdict[key][1])
        mdict_cat = self.cat_dict_of_tensors(mdict)
        return mdict_cat[0]

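
# A minimal usage sketch of FunctionalOperator. The method name "space_operator" and
# the argument layout (x, mu, theta) of the per-label operators are hypothetical:
# the named method must either return an OrderedDict of per-label operators or be the
# single operator itself.
def _example_functional_operator(
    pde: LinearOrder2PDE,
    u: TYPE_FUNC_ARGS,
    theta: TYPE_THETA,
    labels: torch.Tensor,
    x: torch.Tensor,
    mu: torch.Tensor,
) -> torch.Tensor:
    op = FunctionalOperator(pde, "space_operator")  # hypothetical method name
    bound = op.apply_to_func(u)  # one composed function per label
    vmaps = vectorize_dict_of_func((0, 0, None), bound)  # batch over x and mu, not theta
    return op.apply_dict_of_vmap_to_label_tensors(vmaps, theta, labels, x, mu)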