# Source code for scimba_torch.numerical_solvers.preconditioner_pinns

"""Preconditioners for pinns."""

import warnings
from collections import OrderedDict
from typing import Any

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.integration.mesh_based_quadrature import RectangleMethod
from scimba_torch.integration.monte_carlo import DomainSampler
from scimba_torch.numerical_solvers.functional_operator import (
    ACCEPTED_PDE_TYPES,
    # _is_type_keys,
    TYPE_DICT_OF_VMAPS,
    FunctionalOperator,
    # vectorize_dict_of_func,
    check_functional_operator_against_keys,
)
from scimba_torch.numerical_solvers.preconditioner_solvers import (
    MatrixPreconditionerSolver,
    # _jvp,
    _mjactheta,
    # functional_operator_id
    _transpose_i_j,
)

# def _element(i: int, func: TYPE_FUNC_ARGS) -> TYPE_FUNC_ARGS:
#     """Extract a specific element from the output of a function.
#
#     Args:
#         i: Index of the element to extract.
#         func: The function whose output element is to be extracted.
#
#     Returns:
#         A function that extracts the i-th element from the output of func.
#     """
#     return lambda *args: func(*args)[i, ...]


def _get_residual_size(pde: ACCEPTED_PDE_TYPES, bc: bool = False) -> int:
    """Utility function to get residual size.

    The attribute is looked up on ``pde`` first and then on
    ``pde.space_component`` (if present). When neither defines it, a warning
    is emitted and 1 is returned.

    Args:
        pde: the pde.
        bc: get bc_residual_size instead of residual_size.

    Returns:
        bc_residual_size if ``bc`` is True, residual_size otherwise.

    Raises:
        AttributeError: If the `residual_size` or `bc_residual_size`
            attribute found on the pde (or its space component) is not an
            integer.
    """
    name = "bc_residual_size" if bc else "residual_size"

    # Search order: the pde itself, then its space component (time-space
    # pdes delegate spatial metadata there). getattr with a None default
    # keeps the loop uniform when there is no space_component.
    for holder in (pde, getattr(pde, "space_component", None)):
        if holder is not None and hasattr(holder, name):
            value = getattr(holder, name)
            if not isinstance(value, int):
                raise AttributeError(
                    (
                        "attribute %s of input pde or pde.space_component"
                        " must be an integer"
                    )
                    % name
                )
            return value

    # Neither object defines the attribute: warn and fall back to 1.
    warnings.warn(
        (
            "input pde or pde.space_component does not have a %s attribute;"
            " 1 used instead"
        )
        % name,
        UserWarning,
    )
    return 1


def _get_ic_residual_size(pde: ACCEPTED_PDE_TYPES) -> int:
    """Utility function to get ic residual size.

    When ``pde`` does not define ``ic_residual_size``, a warning is emitted
    and 1 is returned.

    Args:
        pde: the pde.

    Returns:
        ic_residual_size (or 1 when the attribute is missing).

    Raises:
        AttributeError: If the `ic_residual_size` attribute is not an
            integer.
    """
    name = "ic_residual_size"

    if hasattr(pde, name):
        value = getattr(pde, name)
        if not isinstance(value, int):
            raise AttributeError(
                ("attribute %s of input pde must be an integer") % name
            )
        return value

    # Attribute missing: warn and fall back to 1.
    warnings.warn(
        ("input pde does not have a %s attribute; 1 used instead") % name,
        UserWarning,
    )
    return 1


# Type aliases for user-supplied equation weights.
# A list of per-equation scalar weights.
TYPE_LIST_FLOAT_OR_INT = list[float | int]
# A weight value: either a single scalar or one scalar per equation.
TYPE_VALUES = float | int | TYPE_LIST_FLOAT_OR_INT
# Mapping from an integer key (an operator/sampler label) to its weight
# value(s), in insertion order.
TYPE_DICT_OF_WEIGHTS = OrderedDict[int, TYPE_VALUES]


def _is_type_list_float_or_int(arg: Any):
    """Check if argument has type TYPE_LIST_FLOAT_OR_INT.

    Args:
        arg: argument to be type checked.

    Returns:
        True iff key has type TYPE_LIST_FLOAT_OR_INT
    """
    return isinstance(arg, list) and all(isinstance(el, float | int) for el in arg)


def _is_type_value(arg: Any):
    """Check if argument has type TYPE_VALUES.

    Args:
        arg: argument to be type checked.

    Returns:
        True iff key has type TYPE_VALUES
    """
    return (
        isinstance(arg, int)
        or isinstance(arg, float)
        or _is_type_list_float_or_int(arg)
    )


def _is_type_dict_of_weight(weight: Any):
    """Check if argument has type TYPE_DICT_OF_WEIGHTS.

    Args:
        weight: argument to be type checked.

    Returns:
        True iff weight is an OrderedDict mapping int keys to TYPE_VALUES.
    """
    if not isinstance(weight, OrderedDict):
        return False
    # Validate keys and values in a single pass over the items.
    return all(
        isinstance(key, int) and _is_type_value(value)
        for key, value in weight.items()
    )


def _check_and_format_weight_argument(
    weight: Any, keys: list[int], residual_size: int
) -> OrderedDict[int, torch.Tensor]:
    """Format weight argument.

    Args:
        weight: the weight argument.
        keys: the keys (flatten) of coresponding functional operator.
        residual_size: the total nb of (labeled) equations to weight.

    Returns:
        the formatted weight argument.

    Raises:
        ValueError: weight argument is a list that has incorrect length
        KeyError: the weight argument has incorrect keys
        TypeError: the weight argument has incorrect type
    """
    assert residual_size % len(keys) == 0
    nb_eq_per_labels: int = residual_size // len(keys)
    if isinstance(weight, float | int):
        res = OrderedDict(
            [
                (
                    key,
                    torch.sqrt(
                        torch.tensor(
                            [weight] * nb_eq_per_labels, dtype=torch.get_default_dtype()
                        )
                    ),
                )
                for key in keys
            ]
        )
    elif isinstance(weight, list) and all(
        isinstance(wk, float) or isinstance(wk, int) for wk in weight
    ):
        if len(weight) == 1:
            res = OrderedDict(
                [
                    (
                        key,
                        torch.sqrt(
                            torch.tensor(
                                weight * nb_eq_per_labels,
                                dtype=torch.get_default_dtype(),
                            )
                        ),
                    )
                    for key in keys
                ]
            )
        elif len(weight) == nb_eq_per_labels:
            res = OrderedDict(
                [
                    (
                        key,
                        torch.sqrt(
                            torch.tensor(weight, dtype=torch.get_default_dtype())
                        ),
                    )
                    for key in keys
                ]
            )
        elif len(weight) == residual_size:
            res = OrderedDict(
                [
                    (key, torch.sqrt(torch.tensor(w, dtype=torch.get_default_dtype())))
                    for key, w in zip(keys, weight)
                ]
            )
        else:
            raise ValueError(
                "weight list must have length either 1, "
                " %d (the nb of equations per label)"
                " or %d (the total nb of equations)"
                % (nb_eq_per_labels, nb_eq_per_labels)
            )
    elif _is_type_dict_of_weight(weight):
        keys_w = [key for key in weight]
        keys_equal = all(k in keys for k in keys_w) and all(k in keys_w for k in keys)
        if not keys_equal:
            raise KeyError("weight dict must have keys {keys}")
        res = OrderedDict()
        for key in weight:
            if isinstance(weight[key], float):
                res[key] = torch.sqrt(
                    torch.tensor(
                        [weight[key]] * nb_eq_per_labels,
                        dtype=torch.get_default_dtype(),
                    )
                )
            elif isinstance(weight[key], list):
                if not (len(weight[key]) == nb_eq_per_labels):
                    raise ValueError(
                        "the values of weight dict must be either float"
                        " or list of %d floats" % nb_eq_per_labels
                    )
                res[key] = torch.sqrt(
                    torch.tensor(weight[key], dtype=torch.get_default_dtype())
                )
    else:
        raise TypeError(
            "weight argument must be of type float,"
            "list[float], or OrderedDict[int, float | list[float]]"
        )

    return res


class MatrixPreconditionerPinn(MatrixPreconditionerSolver):
    """Matrix-based preconditioner for pinns.

    Args:
        space: The approximation space.
        pde: The PDE to be solved.
        **kwargs: Additional keyword arguments:

            - in_lhs_name: Name of the operator to be used in the left-hand
              side assembly. (default: "functional_operator")

    Raises:
        ValueError: residual size (bc, ic) is not a multiple of the number
            of labels
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        assert self.pde is not None

        in_lhs_name = kwargs.get("in_lhs_name", "functional_operator")

        # operator for the interior of the domain
        self.operator = FunctionalOperator(self.pde, in_lhs_name)

        # check the labels of the functional operator to avoid cryptic error
        # message later
        igd: int = 1 if self.space.type_space in ["time_space"] else 0
        sampler = self.space.integrator.list_sampler[igd]
        assert isinstance(sampler, (DomainSampler, RectangleMethod))
        allkeys: list[int] = sampler.get_list_of_labels()
        check_functional_operator_against_keys(
            self.operator, "functional_operator", allkeys
        )

        # record any warning raised while reading residual_size, so it can be
        # re-emitted alongside the error below
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.residual_size: int = _get_residual_size(self.pde)
        if self.residual_size % len(self.operator.flatten_keys):
            if w:  # release error message if any
                warnings.warn(w[-1].message, w[-1].category)
            # raise an error
            raise ValueError(
                "residual_size (%d) must be a multiple of"
                " the number of integer keys (%d) of functional_operator"
                % (
                    self.residual_size,
                    len(self.operator.flatten_keys),
                )
            )

        self.Phi = self.operator.apply_func_to_dict_of_func(
            _transpose_i_j(-1, -2, _mjactheta),
            self.operator.apply_to_func(self.eval_func),
        )
        self.vectorized_Phi = self.vectorize_along_physical_variables(self.Phi)

        # operator for the boundaries of the geometric domain
        self.operator_bc: None | FunctionalOperator = None
        self.vectorized_Phi_bc: None | TYPE_DICT_OF_VMAPS = None
        self.bc_residual_size: int = 1
        if self.has_bc:
            self.operator_bc = FunctionalOperator(
                self.pde, "functional_operator_bc"
            )
            # check the functional operator size to avoid cryptic error
            # message later
            igd_bc: int = 1 if self.space.type_space in ["time_space"] else 0
            sampler_bc = self.space.integrator.list_sampler[igd_bc]
            assert isinstance(sampler_bc, (DomainSampler, RectangleMethod))
            bc_allkeys: list[int] = sampler_bc.get_list_of_bc_labels()
            check_functional_operator_against_keys(
                self.operator_bc, "functional_operator_bc", bc_allkeys
            )
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                self.bc_residual_size = _get_residual_size(self.pde, bc=True)
            if self.bc_residual_size % len(self.operator_bc.flatten_keys):
                if w:  # release error message if any
                    warnings.warn(w[-1].message, w[-1].category)
                # raise an error
                raise ValueError(
                    "bc_residual_size (%d) must be a multiple of"
                    " the number of integer keys (%d) of functional_operator_bc"
                    % (
                        self.bc_residual_size,
                        len(self.operator_bc.flatten_keys),
                    )
                )
            self.Phi_bc = self.operator_bc.apply_func_to_dict_of_func(
                _transpose_i_j(-1, -2, _mjactheta),
                self.operator_bc.apply_to_func(self.eval_func),
            )
            self.vectorized_Phi_bc = self.vectorize_along_physical_variables_bc(
                self.Phi_bc
            )

        # operator for the initial condition
        self.operator_ic: None | FunctionalOperator = None
        self.vectorized_Phi_ic: None | TYPE_DICT_OF_VMAPS = None
        self.ic_residual_size: int = 1
        if self.has_ic:
            self.operator_ic = FunctionalOperator(
                self.pde, "functional_operator_ic"
            )
            # check the functional operator size to avoid cryptic error
            # message later
            check_functional_operator_against_keys(
                self.operator_ic, "functional_operator_ic", allkeys
            )
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                self.ic_residual_size = _get_ic_residual_size(self.pde)
            if self.ic_residual_size % len(self.operator_ic.flatten_keys):
                if w:  # release error message if any
                    warnings.warn(w[-1].message, w[-1].category)
                # raise an error
                raise ValueError(
                    "ic_residual_size (%d) must be a multiple of"
                    " the number of integer keys (%d) of functional_operator_ic"
                    % (
                        self.ic_residual_size,
                        len(self.operator_ic.flatten_keys),
                    )
                )
            self.Phi_ic = self.operator_ic.apply_func_to_dict_of_func(
                _transpose_i_j(-1, -2, _mjactheta),
                self.operator_ic.apply_to_func(self.eval_func),
            )
            self.vectorized_Phi_ic = self.vectorize_along_physical_variables_ic(
                self.Phi_ic
            )

        # format weights
        self.in_weights = _check_and_format_weight_argument(
            self.in_weights, self.operator.flatten_keys, self.residual_size
        )
        if self.has_bc:
            assert isinstance(
                self.operator_bc, FunctionalOperator
            )  # for type checking
            self.bc_weights = _check_and_format_weight_argument(
                self.bc_weights, self.operator_bc.flatten_keys, self.bc_residual_size
            )
        if self.has_ic:
            assert isinstance(
                self.operator_ic, FunctionalOperator
            )  # for type checking
            self.ic_weights = _check_and_format_weight_argument(
                self.ic_weights, self.operator_ic.flatten_keys, self.ic_residual_size
            )