"""Preconditioners for pinns."""
import warnings
from collections import OrderedDict
from typing import Any
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.integration.mesh_based_quadrature import RectangleMethod
from scimba_torch.integration.monte_carlo import DomainSampler
from scimba_torch.numerical_solvers.functional_operator import (
ACCEPTED_PDE_TYPES,
# _is_type_keys,
TYPE_DICT_OF_VMAPS,
FunctionalOperator,
# vectorize_dict_of_func,
check_functional_operator_against_keys,
)
from scimba_torch.numerical_solvers.preconditioner_solvers import (
MatrixPreconditionerSolver,
# _jvp,
_mjactheta,
# functional_operator_id
_transpose_i_j,
)
# def _element(i: int, func: TYPE_FUNC_ARGS) -> TYPE_FUNC_ARGS:
# """Extract a specific element from the output of a function.
#
# Args:
# i: Index of the element to extract.
# func: The function whose output element is to be extracted.
#
# Returns:
# A function that extracts the i-th element from the output of func.
# """
# return lambda *args: func(*args)[i, ...]
def _get_residual_size(pde: ACCEPTED_PDE_TYPES, bc: bool = False) -> int:
"""Utility function to get residual size.
Args:
pde: the pde.
bc: get bc_residual_size instead of residual_size.
Returns:
bc_residual_size instead or residual_size.
Raises:
AttributeError: If the `residual_size`, `bc_residual_size`, or
`ic_residual_size` attributes are not integers.
"""
name = "bc_residual_size" if bc else "residual_size"
warning_message = (
"input pde or pde.space_component does not have a %s attribute; 1 used instead"
) % name
error_message = (
"attribute %s of input pde or pde.space_component must be an integer"
) % name
res = 1
if hasattr(pde, name):
assert hasattr(pde, name)
if not isinstance(getattr(pde, name), int):
raise AttributeError(error_message)
else:
res = getattr(pde, name)
elif hasattr(pde, "space_component"):
assert hasattr(pde, "space_component")
if hasattr(pde.space_component, name):
assert hasattr(pde.space_component, name)
if not isinstance(getattr(pde.space_component, name), int):
raise AttributeError(error_message)
else:
res = getattr(pde.space_component, name)
else:
warnings.warn(warning_message, UserWarning)
else:
warnings.warn(warning_message, UserWarning)
return res
def _get_ic_residual_size(pde: ACCEPTED_PDE_TYPES) -> int:
"""Utility function to get ic residual size.
Args:
pde: the pde.
Returns:
ic_residual_size
Raises:
AttributeError: If the `residual_size`, `bc_residual_size`, or
`ic_residual_size` attributes are not integers.
"""
name = "ic_residual_size"
warning_message = ("input pde does not have a %s attribute; 1 used instead") % name
error_message = ("attribute %s of input pde must be an integer") % name
res = 1
if hasattr(pde, name):
assert hasattr(pde, name)
if not isinstance(getattr(pde, name), int):
raise AttributeError(error_message)
else:
res = getattr(pde, name)
else:
warnings.warn(warning_message, UserWarning)
return res
# Type aliases describing the accepted forms of a weight argument.
# A list of numeric (float or int) weights.
TYPE_LIST_FLOAT_OR_INT = list[float | int]
# A single numeric weight or a list of numeric weights.
TYPE_VALUES = float | int | TYPE_LIST_FLOAT_OR_INT
# Weights given per integer key (the flatten keys of a functional operator).
TYPE_DICT_OF_WEIGHTS = OrderedDict[int, TYPE_VALUES]
def _is_type_list_float_or_int(arg: Any):
"""Check if argument has type TYPE_LIST_FLOAT_OR_INT.
Args:
arg: argument to be type checked.
Returns:
True iff key has type TYPE_LIST_FLOAT_OR_INT
"""
return isinstance(arg, list) and all(isinstance(el, float | int) for el in arg)
def _is_type_value(arg: Any):
"""Check if argument has type TYPE_VALUES.
Args:
arg: argument to be type checked.
Returns:
True iff key has type TYPE_VALUES
"""
return (
isinstance(arg, int)
or isinstance(arg, float)
or _is_type_list_float_or_int(arg)
)
def _is_type_dict_of_weight(weight: Any):
"""Check if argument has type TYPE_DICT_OF_WEIGHTS.
Args:
weight: argument to be type checked.
Returns:
True iff key has type TYPE_DICT_OF_WEIGHTS
"""
return (
isinstance(weight, OrderedDict)
and all(isinstance(key, int) for key in weight)
and all(_is_type_value(weight[key]) for key in weight)
)
def _check_and_format_weight_argument(
weight: Any, keys: list[int], residual_size: int
) -> OrderedDict[int, torch.Tensor]:
"""Format weight argument.
Args:
weight: the weight argument.
keys: the keys (flatten) of coresponding functional operator.
residual_size: the total nb of (labeled) equations to weight.
Returns:
the formatted weight argument.
Raises:
ValueError: weight argument is a list that has incorrect length
KeyError: the weight argument has incorrect keys
TypeError: the weight argument has incorrect type
"""
assert residual_size % len(keys) == 0
nb_eq_per_labels: int = residual_size // len(keys)
if isinstance(weight, float | int):
res = OrderedDict(
[
(
key,
torch.sqrt(
torch.tensor(
[weight] * nb_eq_per_labels, dtype=torch.get_default_dtype()
)
),
)
for key in keys
]
)
elif isinstance(weight, list) and all(
isinstance(wk, float) or isinstance(wk, int) for wk in weight
):
if len(weight) == 1:
res = OrderedDict(
[
(
key,
torch.sqrt(
torch.tensor(
weight * nb_eq_per_labels,
dtype=torch.get_default_dtype(),
)
),
)
for key in keys
]
)
elif len(weight) == nb_eq_per_labels:
res = OrderedDict(
[
(
key,
torch.sqrt(
torch.tensor(weight, dtype=torch.get_default_dtype())
),
)
for key in keys
]
)
elif len(weight) == residual_size:
res = OrderedDict(
[
(key, torch.sqrt(torch.tensor(w, dtype=torch.get_default_dtype())))
for key, w in zip(keys, weight)
]
)
else:
raise ValueError(
"weight list must have length either 1, "
" %d (the nb of equations per label)"
" or %d (the total nb of equations)"
% (nb_eq_per_labels, nb_eq_per_labels)
)
elif _is_type_dict_of_weight(weight):
keys_w = [key for key in weight]
keys_equal = all(k in keys for k in keys_w) and all(k in keys_w for k in keys)
if not keys_equal:
raise KeyError("weight dict must have keys {keys}")
res = OrderedDict()
for key in weight:
if isinstance(weight[key], float):
res[key] = torch.sqrt(
torch.tensor(
[weight[key]] * nb_eq_per_labels,
dtype=torch.get_default_dtype(),
)
)
elif isinstance(weight[key], list):
if not (len(weight[key]) == nb_eq_per_labels):
raise ValueError(
"the values of weight dict must be either float"
" or list of %d floats" % nb_eq_per_labels
)
res[key] = torch.sqrt(
torch.tensor(weight[key], dtype=torch.get_default_dtype())
)
else:
raise TypeError(
"weight argument must be of type float,"
"list[float], or OrderedDict[int, float | list[float]]"
)
return res
[docs]
class MatrixPreconditionerPinn(MatrixPreconditionerSolver):
    """Matrix-based preconditioner for pinns.

    For the interior of the domain and, when present, the boundary and
    initial conditions, it builds the corresponding functional operator of
    the PDE and the vectorized ``Phi`` functions derived from it. It also
    formats the associated weights with ``_check_and_format_weight_argument``.

    Args:
        space: The approximation space.
        pde: The PDE to be solved.
        **kwargs: Additional keyword arguments:

            - in_lhs_name: Name of the operator to be used in the left-hand side
              assembly. (default: "functional_operator")

    Raises:
        ValueError: residual size (bc, ic) is not a multiple of the number of labels
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        """Build the interior/bc/ic functional operators and format the weights.

        Args:
            space: The approximation space.
            pde: The PDE to be solved.
            **kwargs: Forwarded to ``MatrixPreconditionerSolver.__init__``;
                ``in_lhs_name`` (default ``"functional_operator"``) selects the
                operator used for the interior of the domain.

        Raises:
            ValueError: If ``residual_size`` (resp. ``bc_residual_size``,
                ``ic_residual_size``) is not a multiple of the number of flatten
                keys of the corresponding functional operator; also raised,
                like KeyError and TypeError, by
                ``_check_and_format_weight_argument`` for malformed weights.
        """
        super().__init__(space, pde, **kwargs)
        assert self.pde is not None

        in_lhs_name = kwargs.get("in_lhs_name", "functional_operator")

        # operator for the interior of the domain
        self.operator = FunctionalOperator(self.pde, in_lhs_name)

        # check the labels of the functional operator to avoid cryptic error message
        # later; sampler index 1 is used for "time_space" spaces, 0 otherwise
        # (presumably index 0 is the time sampler in that case - confirm)
        igd: int = 1 if self.space.type_space in ["time_space"] else 0
        sampler = self.space.integrator.list_sampler[igd]
        assert isinstance(sampler, (DomainSampler, RectangleMethod))
        allkeys: list[int] = sampler.get_list_of_labels()
        check_functional_operator_against_keys(
            self.operator, "functional_operator", allkeys
        )

        # Read residual_size while recording warnings: the "1 used instead"
        # warning is re-emitted only if the size turns out to be invalid,
        # otherwise the recorded warnings are discarded.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            self.residual_size: int = _get_residual_size(self.pde)
        # checked outside the recording context so that the re-emitted warning
        # actually reaches the caller
        if self.residual_size % len(self.operator.flatten_keys):
            if w:  # release error message if any
                warnings.warn(w[-1].message, w[-1].category)
            # raise an error
            raise ValueError(
                "residual_size (%d) must be a multiple of"
                " the number of integer keys (%d) of functional_operator"
                % (
                    self.residual_size,
                    len(self.operator.flatten_keys),
                )
            )

        # Phi: _mjactheta, with its axes -1 and -2 transposed, applied to the
        # operator's functions built from eval_func (presumably the transposed
        # Jacobian w.r.t. the parameters - confirm)
        self.Phi = self.operator.apply_func_to_dict_of_func(
            _transpose_i_j(-1, -2, _mjactheta),
            self.operator.apply_to_func(self.eval_func),
        )
        self.vectorized_Phi = self.vectorize_along_physical_variables(self.Phi)

        # operator for the boundaries of the geometric domain
        self.operator_bc: None | FunctionalOperator = None
        self.vectorized_Phi_bc: None | TYPE_DICT_OF_VMAPS = None
        self.bc_residual_size: int = 1
        if self.has_bc:
            self.operator_bc = FunctionalOperator(self.pde, "functional_operator_bc")

            # check the functional operator size to avoid cryptic error message later
            # (same sampler selection as for the interior operator)
            igd_bc: int = 1 if self.space.type_space in ["time_space"] else 0
            sampler_bc = self.space.integrator.list_sampler[igd_bc]
            assert isinstance(sampler_bc, (DomainSampler, RectangleMethod))
            bc_allkeys: list[int] = sampler_bc.get_list_of_bc_labels()
            check_functional_operator_against_keys(
                self.operator_bc, "functional_operator_bc", bc_allkeys
            )

            # same warning-recording pattern as for residual_size
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                self.bc_residual_size = _get_residual_size(self.pde, bc=True)
            if self.bc_residual_size % len(self.operator_bc.flatten_keys):
                if w:  # release error message if any
                    warnings.warn(w[-1].message, w[-1].category)
                # raise an error
                raise ValueError(
                    "bc_residual_size (%d) must be a multiple of"
                    " the number of integer keys (%d) of functional_operator_bc"
                    % (
                        self.bc_residual_size,
                        len(self.operator_bc.flatten_keys),
                    )
                )

            self.Phi_bc = self.operator_bc.apply_func_to_dict_of_func(
                _transpose_i_j(-1, -2, _mjactheta),
                self.operator_bc.apply_to_func(self.eval_func),
            )
            self.vectorized_Phi_bc = self.vectorize_along_physical_variables_bc(
                self.Phi_bc
            )

        # operator for the initial condition
        self.operator_ic: None | FunctionalOperator = None
        self.vectorized_Phi_ic: None | TYPE_DICT_OF_VMAPS = None
        self.ic_residual_size: int = 1
        if self.has_ic:
            self.operator_ic = FunctionalOperator(self.pde, "functional_operator_ic")

            # check the functional operator size to avoid cryptic error message later
            # (checked against the same labels as the interior operator)
            check_functional_operator_against_keys(
                self.operator_ic, "functional_operator_ic", allkeys
            )

            # same warning-recording pattern as for residual_size
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                self.ic_residual_size = _get_ic_residual_size(self.pde)
            if self.ic_residual_size % len(self.operator_ic.flatten_keys):
                if w:  # release error message if any
                    warnings.warn(w[-1].message, w[-1].category)
                # raise an error
                raise ValueError(
                    "ic_residual_size (%d) must be a multiple of"
                    " the number of integer keys (%d) of functional_operator_ic"
                    % (
                        self.ic_residual_size,
                        len(self.operator_ic.flatten_keys),
                    )
                )

            self.Phi_ic = self.operator_ic.apply_func_to_dict_of_func(
                _transpose_i_j(-1, -2, _mjactheta),
                self.operator_ic.apply_to_func(self.eval_func),
            )
            self.vectorized_Phi_ic = self.vectorize_along_physical_variables_ic(
                self.Phi_ic
            )

        # format weights: in_weights / bc_weights / ic_weights are read here
        # before being assigned, so they are set outside this method
        # (presumably by MatrixPreconditionerSolver.__init__ - confirm); they
        # are replaced by per-key tensors of square-rooted weights
        self.in_weights = _check_and_format_weight_argument(
            self.in_weights, self.operator.flatten_keys, self.residual_size
        )
        if self.has_bc:
            assert isinstance(self.operator_bc, FunctionalOperator)  # for type checking
            self.bc_weights = _check_and_format_weight_argument(
                self.bc_weights, self.operator_bc.flatten_keys, self.bc_residual_size
            )
        if self.has_ic:
            assert isinstance(self.operator_ic, FunctionalOperator)  # for type checking
            self.ic_weights = _check_and_format_weight_argument(
                self.ic_weights, self.operator_ic.flatten_keys, self.ic_residual_size
            )