Source code for scimba_torch.numerical_solvers.preconditioner_pinns

"""Preconditioners for pinns."""

import warnings
from collections import OrderedDict
from typing import Any, cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    # _is_type_keys,
    TYPE_DICT_OF_VMAPS,
    FunctionalOperator,
    # vectorize_dict_of_func,
    check_functional_operator_against_keys,
)
from scimba_torch.numerical_solvers.preconditioner_solvers import (
    MatrixPreconditionerSolver,
    _mjactheta,
    # functional_operator_id
    _transpose_i_j,
)
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    EllipticPDE,
    # StrongForm_EllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
    LinearOrder2PDE,
)
from scimba_torch.physical_models.kinetic_pde.abstract_kinetic_pde import KineticPDE
from scimba_torch.physical_models.temporal_pde.abstract_temporal_pde import TemporalPDE

ACCEPTED_PDE_TYPES = EllipticPDE | TemporalPDE | KineticPDE | LinearOrder2PDE


# def _element(i: int, func: TYPE_FUNC_ARGS) -> TYPE_FUNC_ARGS:
#     """Extract a specific element from the output of a function.
#
#     Args:
#         i: Index of the element to extract.
#         func: The function whose output element is to be extracted.
#
#     Returns:
#         A function that extracts the i-th element from the output of func.
#     """
#     return lambda *args: func(*args)[i, ...]


def _get_residual_size(pde: Any, bc: bool = False) -> int:
    """Look up the number of residual equations declared by a pde.

    The attribute (``residual_size``, or ``bc_residual_size`` when ``bc`` is
    set) is searched on ``pde`` first and, only if ``pde`` does not have it,
    on ``pde.space_component``. When it is found on neither, a ``UserWarning``
    is issued and 1 is returned.

    Args:
        pde: the pde.
        bc: look up ``bc_residual_size`` instead of ``residual_size``.

    Returns:
        The value of the attribute, or 1 when the attribute is missing.

    Raises:
        AttributeError: If the attribute exists but is not an integer.
    """
    attr = "bc_residual_size" if bc else "residual_size"

    # Find the object carrying the attribute: the pde itself has priority
    # over its space component.
    if hasattr(pde, attr):
        holder = pde
    elif hasattr(pde, "space_component") and hasattr(pde.space_component, attr):
        holder = pde.space_component
    else:
        warnings.warn(
            (
                "input pde or pde.space_component does not have a %s attribute; 1 used instead"
            )
            % attr,
            UserWarning,
        )
        return 1

    size = getattr(holder, attr)
    if not isinstance(size, int):
        raise AttributeError(
            ("attribute %s of input pde or pde.space_component must be an integer")
            % attr
        )
    return size


def _get_ic_residual_size(pde: Any) -> int:
    """Look up the number of initial-condition residuals declared by a pde.

    Reads the ``ic_residual_size`` attribute of ``pde``. When the attribute is
    missing, a ``UserWarning`` is issued and 1 is returned.

    Args:
        pde: the pde.

    Returns:
        The value of ``pde.ic_residual_size``, or 1 when it is missing.

    Raises:
        AttributeError: If ``ic_residual_size`` exists but is not an integer.
    """
    attr = "ic_residual_size"
    missing = object()  # sentinel: distinguishes "absent" from any real value

    size = getattr(pde, attr, missing)
    if size is missing:
        warnings.warn(
            ("input pde does not have a %s attribute; 1 used instead") % attr,
            UserWarning,
        )
        return 1
    if isinstance(size, int):
        return size
    raise AttributeError(("attribute %s of input pde must be an integer") % attr)


class MatrixPreconditionerPinn(MatrixPreconditionerSolver):
    """Matrix-based preconditioner for pinns.

    For the interior of the domain, and for the boundary and initial
    conditions when ``self.has_bc`` / ``self.has_ic`` are set, the constructor:

    - builds the corresponding ``FunctionalOperator``;
    - checks it against the labels returned by the sampler;
    - checks that the matching residual size is a multiple of the number of
      flattened keys of the operator;
    - builds ``Phi`` (and its vectorized version) from ``self.eval_func``.

    Args:
        space: The approximation space.
        pde: The PDE to be solved.
        **kwargs: Additional keyword arguments:

            - in_lhs_name: Name of the operator to be used in the left-hand side
              assembly. (default: "functional_operator")

    Raises:
        ValueError: residual size (bc, ic) is not a multiple of the number of labels
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)

        in_lhs_name = kwargs.get("in_lhs_name", "functional_operator")

        # operator for the interior of the domain
        self.operator = FunctionalOperator(self.pde, in_lhs_name)

        # check the labels of the functional operator to avoid cryptic error
        # message later; sampler 1 is used for "time_space" spaces, sampler 0
        # otherwise (presumably the sampler of the geometric domain -- TODO
        # confirm)
        igd = 1 if self.space.type_space in ["time_space"] else 0
        sampler = self.space.integrator.list_sampler[igd]
        allkeys: list[int] = sampler.get_list_of_labels()
        check_functional_operator_against_keys(
            self.operator, "functional_operator", allkeys
        )

        self.residual_size: int = self._checked_residual_size(
            lambda: _get_residual_size(self.pde),
            self.operator,
            "residual_size",
            "functional_operator",
        )

        self.Phi = self.operator.apply_func_to_dict_of_func(
            _transpose_i_j(-1, -2, _mjactheta),
            self.operator.apply_to_func(self.eval_func),
        )
        self.vectorized_Phi = self.vectorize_along_physical_variables(self.Phi)

        # operator for the boundaries of the geometric domain
        self.operator_bc: None | FunctionalOperator = None
        self.vectorized_Phi_bc: None | TYPE_DICT_OF_VMAPS = None
        self.bc_residual_size: int = 1
        if self.has_bc:
            self.operator_bc = FunctionalOperator(self.pde, "functional_operator_bc")

            # check the size of the functional operator to avoid cryptic error
            # message later
            bc_allkeys: list[int] = sampler.get_list_of_bc_labels()
            check_functional_operator_against_keys(
                self.operator_bc, "functional_operator_bc", bc_allkeys
            )

            self.bc_residual_size = self._checked_residual_size(
                lambda: _get_residual_size(self.pde, bc=True),
                self.operator_bc,
                "bc_residual_size",
                "functional_operator_bc",
            )

            self.Phi_bc = self.operator_bc.apply_func_to_dict_of_func(
                _transpose_i_j(-1, -2, _mjactheta),
                self.operator_bc.apply_to_func(self.eval_func),
            )
            self.vectorized_Phi_bc = self.vectorize_along_physical_variables_bc(
                self.Phi_bc
            )

        # operator for the initial condition
        self.operator_ic: None | FunctionalOperator = None
        self.vectorized_Phi_ic: None | TYPE_DICT_OF_VMAPS = None
        self.ic_residual_size: int = 1
        if self.has_ic:
            self.operator_ic = FunctionalOperator(self.pde, "functional_operator_ic")

            # check the size of the functional operator to avoid cryptic error
            # message later (checked against the interior labels)
            check_functional_operator_against_keys(
                self.operator_ic, "functional_operator_ic", allkeys
            )

            self.ic_residual_size = self._checked_residual_size(
                lambda: _get_ic_residual_size(self.pde),
                self.operator_ic,
                "ic_residual_size",
                "functional_operator_ic",
            )

            self.Phi_ic = self.operator_ic.apply_func_to_dict_of_func(
                _transpose_i_j(-1, -2, _mjactheta),
                self.operator_ic.apply_to_func(self.eval_func),
            )
            self.vectorized_Phi_ic = self.vectorize_along_physical_variables_ic(
                self.Phi_ic
            )

    @staticmethod
    def _checked_residual_size(
        get_size: Any,
        operator: FunctionalOperator,
        size_name: str,
        operator_name: str,
    ) -> int:
        """Fetch a residual size and check it against an operator's keys.

        Warnings issued while fetching the size are recorded rather than
        displayed. They only matter when the check fails, in which case the
        last one is re-emitted just before the error is raised.

        Args:
            get_size: zero-argument callable returning the residual size.
            operator: the functional operator the size must be compatible with.
            size_name: name of the size attribute, used in the error message.
            operator_name: name of the operator, used in the error message.

        Returns:
            The residual size.

        Raises:
            ValueError: the size is not a multiple of the number of flattened
                keys of ``operator``.
        """
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            size = get_size()

        nb_keys = len(operator.flatten_keys)
        if size % nb_keys:
            # Re-emit outside of the recording context: inside it, the warning
            # would only be appended to ``caught`` and never reach the user.
            if caught:
                warnings.warn(caught[-1].message, caught[-1].category)
            raise ValueError(
                "%s (%d) must be a multiple of"
                " the number of integer keys (%d) of %s"
                % (size_name, size, nb_keys, operator_name)
            )
        return size
TYPE_LIST_FLOAT_OR_INT = list[float | int] TYPE_VALUES = float | int | TYPE_LIST_FLOAT_OR_INT TYPE_DICT_OF_WEIGHTS = OrderedDict[int, TYPE_VALUES] def _is_type_list_float_or_int(arg: Any): """Check if argument has type TYPE_LIST_FLOAT_OR_INT. Args: arg: argument to be type checked. Returns: True iff key has type TYPE_LIST_FLOAT_OR_INT """ return isinstance(arg, list) and all(isinstance(el, float | int) for el in arg) def _is_type_value(arg: Any): """Check if argument has type TYPE_VALUES. Args: arg: argument to be type checked. Returns: True iff key has type TYPE_VALUES """ return ( isinstance(arg, int) or isinstance(arg, float) or _is_type_list_float_or_int(arg) ) def _is_type_dict_of_weight(weight: Any): """Check if argument has type TYPE_DICT_OF_WEIGHTS. Args: weight: argument to be type checked. Returns: True iff key has type TYPE_DICT_OF_WEIGHTS """ return ( isinstance(weight, OrderedDict) and all(isinstance(key, int) for key in weight) and all(_is_type_value(weight[key]) for key in weight) ) def _check_and_format_weight_argument( weight: Any, keys: list[int], residual_size: int ) -> OrderedDict[int, torch.Tensor]: """Format weight argument. Args: weight: the weight argument. keys: the keys (flatten) of coresponding functional operator. residual_size: the total nb of (labeled) equations to weight. Returns: the formatted weight argument. 
Raises: ValueError: weight argument is a list that has incorrect length KeyError: the weight argument has incorrect keys TypeError: the weight argument has incorrect type """ assert residual_size % len(keys) == 0 nb_eq_per_labels: int = residual_size // len(keys) if isinstance(weight, float | int): res = OrderedDict( [ ( key, torch.sqrt( torch.tensor( [weight] * nb_eq_per_labels, dtype=torch.get_default_dtype() ) ), ) for key in keys ] ) elif isinstance(weight, list) and all( isinstance(wk, float) or isinstance(wk, int) for wk in weight ): if len(weight) == 1: res = OrderedDict( [ ( key, torch.sqrt( torch.tensor( weight * nb_eq_per_labels, dtype=torch.get_default_dtype(), ) ), ) for key in keys ] ) elif len(weight) == nb_eq_per_labels: res = OrderedDict( [ ( key, torch.sqrt( torch.tensor(weight, dtype=torch.get_default_dtype()) ), ) for key in keys ] ) elif len(weight) == residual_size: res = OrderedDict( [ (key, torch.sqrt(torch.tensor(w, dtype=torch.get_default_dtype()))) for key, w in zip(keys, weight) ] ) else: raise ValueError( "weight list must have length either 1, " " %d (the nb of equations per label)" " or %d (the total nb of equations)" % (nb_eq_per_labels, nb_eq_per_labels) ) elif _is_type_dict_of_weight(weight): keys_w = [key for key in weight] keys_equal = all(k in keys for k in keys_w) and all(k in keys_w for k in keys) if not keys_equal: raise KeyError("weight dict must have keys {keys}") res = OrderedDict() for key in weight: if isinstance(weight[key], float): res[key] = torch.sqrt( torch.tensor( [weight[key]] * nb_eq_per_labels, dtype=torch.get_default_dtype(), ) ) elif isinstance(weight[key], list): if not (len(weight[key]) == nb_eq_per_labels): raise ValueError( "the values of weight dict must be either float" " or list of %d floats" % nb_eq_per_labels ) res[key] = torch.sqrt( torch.tensor(weight[key], dtype=torch.get_default_dtype()) ) else: raise TypeError( "weight argument must be of type float," "list[float], or OrderedDict[int, float | 
list[float]]" ) return res
class EnergyNaturalGradientPreconditioner(MatrixPreconditionerPinn):
    """Energy-based natural gradient preconditioner.

    The preconditioning matrix is the sum of one regularized, weighted Gram
    matrix per part of the problem (interior, boundary condition, initial
    condition); gradients are preconditioned by solving the corresponding
    least-squares problem.

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments:

            - matrix_regularization: Regularization parameter for the
              preconditioning matrix (default: 1e-6).
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6)

        # NOTE(review): in_weights / bc_weights / ic_weights are read before
        # being assigned here, so they are presumably set by a parent class
        # -- confirm. They are replaced by their formatted version (one tensor
        # of square-rooted weights per key).
        self.in_weights = _check_and_format_weight_argument(
            self.in_weights, self.operator.flatten_keys, self.residual_size
        )
        if self.has_bc:
            assert isinstance(self.operator_bc, FunctionalOperator)  # for type checking
            self.bc_weights = _check_and_format_weight_argument(
                self.bc_weights, self.operator_bc.flatten_keys, self.bc_residual_size
            )
        if self.has_ic:
            assert isinstance(self.operator_ic, FunctionalOperator)  # for type checking
            self.ic_weights = _check_and_format_weight_argument(
                self.ic_weights, self.operator_ic.flatten_keys, self.ic_residual_size
            )

    def _weighted_gram_matrix(
        self,
        Phi: torch.Tensor,
        labels: torch.Tensor,
        weights: OrderedDict[int, torch.Tensor],
        nb_points: int,
    ) -> torch.Tensor:
        """Assemble the regularized Gram matrix of a weighted ``Phi``.

        ``Phi`` is scaled in place by the weights, contracted with itself over
        its first and last axes, divided by ``nb_points``, and
        ``matrix_regularization`` times the identity is added.

        Args:
            Phi: 3-D tensor, modified in place; its first axis is indexed by
                ``labels`` and its second axis has size ``self.ndof``.
            labels: The labels tensor.
            weights: formatted weights, one tensor per label key.
            nb_points: the normalization factor.

        Returns:
            The ``(ndof, ndof)`` regularized Gram matrix.
        """
        if len(weights) == 1:
            # a single key: its weights apply to all labels
            (weight,) = weights.values()
            Phi[:, :, :] *= weight
        else:
            # apply weights for each label
            for key, weight in weights.items():
                Phi[labels == key, :, :] *= weight
        M = torch.einsum("ijk,ilk->jl", Phi, Phi) / nb_points
        # Build the identity with M's dtype and device: the in-place addition
        # fails when they differ from torch's current defaults.
        M += self.matrix_regularization * torch.eye(
            self.ndof, dtype=M.dtype, device=M.device
        )
        return M

    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; the first one gives the number of
                points (size of its first axis).
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        )
        return self.in_weight * self._weighted_gram_matrix(
            Phi, labels, self.in_weights, N
        )

    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; the first one gives the number of
                points (size of its first axis).
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        )
        return self.bc_weight * self._weighted_gram_matrix(
            Phi, labels, self.bc_weights, N
        )

    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; the first one gives the number of
                points (size of its first axis).
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        )
        return self.ic_weight * self._weighted_gram_matrix(
            Phi, labels, self.ic_weights, N
        )

    def compute_full_preconditioning_matrix(
        self, data: tuple | dict, **kwargs
    ) -> torch.Tensor:
        """Compute the full preconditioning matrix.

        Include contributions from the main operator, boundary conditions, and
        initial conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            **kwargs: Additional keyword arguments.

        Returns:
            The full preconditioning matrix.
        """
        M = self.get_preconditioning_matrix(data, **kwargs)
        if self.has_bc:
            M += self.get_preconditioning_matrix_bc(data, **kwargs)
        if self.has_ic:
            M += self.get_preconditioning_matrix_ic(data, **kwargs)
        return M

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the preconditioner to the input gradients.

        Args:
            epoch: Current training epoch (not used here).
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals (not used here).
            res_r: Right residuals (not used here).
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor.
        """
        M = self.compute_full_preconditioning_matrix(data, **kwargs)
        # least-squares solve of M x = grads
        preconditioned_grads = torch.linalg.lstsq(M, grads).solution
        return preconditioned_grads
class AnagramPreconditioner(MatrixPreconditionerPinn):
    """Anagram preconditioner.

    This preconditioner is based on the anagram method, which aims to improve
    convergence by transforming the problem into a more favorable form.

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments:

            - `svd_threshold` (:code:`float`): Threshold for singular value
              decomposition (default: 1e-6).
            - `bc_weight` (:code:`float`): Weight for boundary condition
              contributions (default: 1.0).
            - `ic_weight` (:code:`float`): Weight for initial condition
              contributions (default: 1.0).
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        self.svd_threshold = kwargs.get("svd_threshold", 1e-6)

    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        # Only the trailing axis is squeezed: a bare squeeze() would also drop
        # the sample or parameter axis whenever it has size 1.
        return self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        ).squeeze(-1)

    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        # see compute_preconditioning_matrix for the choice of squeeze(-1)
        return self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        ).squeeze(-1)

    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        # see compute_preconditioning_matrix for the choice of squeeze(-1)
        return self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        ).squeeze(-1)

    @staticmethod
    def _assemble_residual(
        operator: FunctionalOperator,
        residual_size: int,
        get_args: Any,
        data: tuple | dict,
        res_l: tuple,
        res_r: tuple,
    ) -> torch.Tensor:
        """Concatenate the residuals ``res_l - res_r`` of one operator.

        Args:
            operator: the functional operator the residuals belong to.
            residual_size: total number of residuals of this operator.
            get_args: callable mapping ``data`` to the arguments of the
                operator; only called when the operator has several entries.
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor.
        """
        diffs = tuple(left - right for left, right in zip(res_l, res_r))
        if len(operator.dict_of_operators) == 1:
            return torch.cat(diffs, dim=0)

        # Residuals are stored label by label; regroup them component by
        # component: group j holds the j-th residual of every label.
        nb_comp_per_label = residual_size // len(operator.flatten_keys)
        groups = tuple(
            tuple(diffs[i + j] for i in range(0, len(diffs), nb_comp_per_label))
            for j in range(nb_comp_per_label)
        )
        args = get_args(data)
        return torch.cat(
            tuple(
                operator.cat_tuple_of_tensors_along_flatten_keys(
                    group, args[0], args[1]
                )
                for group in groups
            ),
            dim=0,
        )

    def assemble_right_member(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side of the equation.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor.
        """
        return self._assemble_residual(
            self.operator,
            self.residual_size,
            self.get_args_for_operator,
            data,
            res_l,
            res_r,
        )

    def assemble_right_member_bc(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side for boundary conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor for boundary conditions.
        """
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        return self._assemble_residual(
            self.operator_bc,
            self.bc_residual_size,
            self.get_args_for_operator_bc,
            data,
            res_l,
            res_r,
        )

    def assemble_right_member_ic(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side for initial conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor for initial conditions.
        """
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        return self._assemble_residual(
            self.operator_ic,
            self.ic_residual_size,
            self.get_args_for_operator_ic,
            data,
            res_l,
            res_r,
        )

    @staticmethod
    def _stack_components(phi: torch.Tensor) -> torch.Tensor:
        """Stack the slices along the third axis of ``phi`` along its first axis.

        Args:
            phi: a tensor; returned unchanged when it has at most 2 axes.

        Returns:
            The concatenation along axis 0 of ``phi[:, :, i, ...]`` for all i.
        """
        if phi.ndim > 2:
            phi = torch.cat(
                tuple(phi[:, :, i, ...] for i in range(phi.shape[2])), dim=0
            )
        return phi

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the Anagram preconditioner to the input gradients.

        The matrices of the interior, boundary and initial conditions are
        stacked row-wise, the result is pseudo-inverted with a truncated SVD,
        and the pseudo-inverse is applied to the residuals stacked in the same
        order.

        Args:
            epoch: Current training epoch (not used here).
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned (not used here).
            res_l: Left residuals.
            res_r: Right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor.
        """
        phi = self._stack_components(self.get_preconditioning_matrix(data, **kwargs))
        if self.has_bc:
            phib = self._stack_components(
                self.get_preconditioning_matrix_bc(data, **kwargs)
            )
            phi = torch.cat((phi, phib), dim=0)
        if self.has_ic:
            phii = self._stack_components(
                self.get_preconditioning_matrix_ic(data, **kwargs)
            )
            phi = torch.cat((phi, phii), dim=0)

        # compute pseudo inverse via svd
        U, Delta, Vt = torch.linalg.svd(phi, full_matrices=False)
        # keep only the singular values above the (absolute) threshold
        mask = Delta > self.svd_threshold
        Delta_inv = torch.zeros_like(Delta)
        Delta_inv[mask] = 1.0 / Delta[mask]
        phi_plus = Vt.T @ torch.diag(Delta_inv) @ U.T

        # The residual tuples hold the interior residuals first, then the
        # boundary ones, then the initial ones.
        begin, end = 0, self.residual_size
        res = self.assemble_right_member(data, res_l[begin:end], res_r[begin:end])
        if self.has_bc:
            begin, end = end, end + self.bc_residual_size
            resb = self.assemble_right_member_bc(
                data, res_l[begin:end], res_r[begin:end]
            )
            res = torch.cat((res, resb), dim=0)
        if self.has_ic:
            begin, end = end, end + self.ic_residual_size
            resi = self.assemble_right_member_ic(
                data, res_l[begin:end], res_r[begin:end]
            )
            res = torch.cat((res, resi), dim=0)

        preconditioned_grads = phi_plus @ res
        return preconditioned_grads