Source code for scimba_torch.numerical_solvers.preconditioner_pinns

"""Preconditioners for pinns."""

import warnings
from collections import OrderedDict
from typing import Callable, cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    TYPE_ARGS,
    # TYPE_DICT_OF_FUNC_ARGS,
    TYPE_DICT_OF_VMAPS,
    TYPE_FUNC_ARGS,
    # TYPE_FUNC_FUNC_ARGS,
    # TYPE_THETA,
    FunctionalOperator,
    # vectorize_dict_of_func,
)
from scimba_torch.numerical_solvers.preconditioner_solvers import (
    MatrixPreconditionerSolver,
    _mjactheta,
    # functional_operator_id
    _transpose_i_j,
)
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    EllipticPDE,
    # StrongForm_EllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import (
    LinearOrder2PDE,
)
from scimba_torch.physical_models.kinetic_pde.abstract_kinetic_pde import KineticPDE
from scimba_torch.physical_models.temporal_pde.abstract_temporal_pde import TemporalPDE

ACCEPTED_PDE_TYPES = EllipticPDE | TemporalPDE | KineticPDE | LinearOrder2PDE


def _element(i: int, func: TYPE_FUNC_ARGS) -> TYPE_FUNC_ARGS:
    """Wrap ``func`` so that only one entry of its output is returned.

    Args:
        i: Index, along the leading dimension of ``func``'s output, of the entry
            to keep.
        func: The function whose output is indexed.

    Returns:
        A function taking the same positional arguments as ``func`` and
        returning ``func(*args)[i, ...]``.
    """

    def select_entry(*args):
        output = func(*args)
        return output[i, ...]

    return select_entry


class MatrixPreconditionerPinn(MatrixPreconditionerSolver):
    """Matrix-based preconditioner for pinns.

    Builds the functional operator of the PDE and, when the solver reports
    boundary / initial conditions (``has_bc`` / ``has_ic``), the corresponding
    boundary and initial condition operators.

    Args:
        space: The approximation space.
        pde: The PDE to be solved.
        **kwargs: Additional keyword arguments:

            - in_lhs_name: Name of the operator to be used in the left-hand
              side assembly. (default: "functional_operator")
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)

        lhs_operator_name = kwargs.get("in_lhs_name", "functional_operator")
        self.operator = FunctionalOperator(self.pde, lhs_operator_name)

        # Condition operators only exist when the base solver says the problem
        # has the corresponding conditions; otherwise they stay None.
        self.operator_bc: None | FunctionalOperator = (
            FunctionalOperator(self.pde, "functional_operator_bc")
            if self.has_bc
            else None
        )
        self.operator_ic: None | FunctionalOperator = (
            FunctionalOperator(self.pde, "functional_operator_ic")
            if self.has_ic
            else None
        )
class EnergyNaturalGradientPreconditioner(MatrixPreconditionerPinn):
    """Energy-based natural gradient preconditioner.

    The preconditioning matrix is the regularized Gram matrix of the parameter
    Jacobians of the PDE operator (plus, when present, weighted contributions
    of the boundary and initial condition operators). Gradients are
    preconditioned by solving a least-squares system with this matrix.

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        is_operator_linear: Whether the operator is linear (default: False).
        **kwargs: Additional keyword arguments:

            - matrix_regularization: Regularization parameter for the
              preconditioning matrix (default: 1e-6).
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        is_operator_linear: bool = False,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        self.is_operator_linear = is_operator_linear
        self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6)

        self.vectorized_Phi_bc: None | TYPE_DICT_OF_VMAPS = None
        self.vectorized_Phi_ic: None | TYPE_DICT_OF_VMAPS = None

        if self.is_operator_linear:
            # Grad along theta and grad along physical variables commute:
            # differentiate (operator applied to the network) w.r.t. theta.
            self.linear_Phi = self._linear_phi_dict(self.operator)
            self.vectorized_Phi = self.vectorize_along_physical_variables(
                self.linear_Phi
            )
            if self.has_bc:
                self.operator_bc = cast(FunctionalOperator, self.operator_bc)
                self.linear_Phi_bc = self._linear_phi_dict(self.operator_bc)
                self.vectorized_Phi_bc = self.vectorize_along_physical_variables_bc(
                    self.linear_Phi_bc
                )
            if self.has_ic:
                self.operator_ic = cast(FunctionalOperator, self.operator_ic)
                self.linear_Phi_ic = self._linear_phi_dict(self.operator_ic)
                self.vectorized_Phi_ic = self.vectorize_along_physical_variables_ic(
                    self.linear_Phi_ic
                )
        else:
            # Grad along theta and grad along physical variables do NOT
            # commute: apply each operator to the Jacobian of the network.
            self.non_linear_Phi = self._non_linear_phi_dict(self.operator)
            self.vectorized_Phi = self.vectorize_along_physical_variables(
                self.non_linear_Phi
            )
            if self.has_bc:
                self.operator_bc = cast(FunctionalOperator, self.operator_bc)
                self.non_linear_Phi_bc = self._non_linear_phi_dict(self.operator_bc)
                self.vectorized_Phi_bc = self.vectorize_along_physical_variables_bc(
                    self.non_linear_Phi_bc
                )
            if self.has_ic:
                self.operator_ic = cast(FunctionalOperator, self.operator_ic)
                self.non_linear_Phi_ic = self._non_linear_phi_dict(self.operator_ic)
                self.vectorized_Phi_ic = self.vectorize_along_physical_variables_ic(
                    self.non_linear_Phi_ic
                )

    def _linear_phi_dict(self, operator: FunctionalOperator):
        """Differentiate ``operator`` applied to the network w.r.t. theta.

        Only valid for linear operators, where this equals the operator
        applied to the parameter Jacobian of the network.

        Args:
            operator: The functional operator (main, bc or ic).

        Returns:
            The dict of functions built by
            ``operator.apply_func_to_dict_of_func``.
        """
        return operator.apply_func_to_dict_of_func(
            _transpose_i_j(-1, -2, _mjactheta),
            operator.apply_to_func(self.eval_func),
        )

    def _non_linear_phi_dict(self, operator: FunctionalOperator) -> OrderedDict:
        """Apply each operator of ``operator`` to the parameter Jacobian.

        Args:
            operator: The functional operator (main, bc or ic).

        Returns:
            An OrderedDict mapping each key of ``operator.dict_of_operators`` to
            a function of ``*args`` returning the operator for that key applied
            to the Jacobian of the network w.r.t. its parameters.
        """
        phi: OrderedDict = OrderedDict()
        for key in operator.dict_of_operators:
            # A factory is used so that each entry captures its own ``key``; a
            # bare lambda here would late-bind ``key`` and every entry would
            # use the last operator of the dict.
            phi[key] = self._non_linear_phi_entry(operator, key)
        return phi

    def _non_linear_phi_entry(self, operator: FunctionalOperator, key) -> Callable:
        """Return a function applying ``operator.dict_of_operators[key]`` to the Jacobian."""

        def phi_entry(*args):
            return self.eval_and_jactheta_and_func(
                operator.dict_of_operators[key], *args
            )

        return phi_entry

    def eval_and_jactheta(self, *args: TYPE_ARGS) -> torch.Tensor:
        """Evaluate the Jacobian of the network with respect to its parameters.

        Args:
            *args: Arguments to be passed to the network, with the last argument
                being the parameters of the network.

        Returns:
            The Jacobian matrix of the network with respect to its parameters.
        """
        return (_transpose_i_j(-1, -2, _mjactheta))(self.eval_func, *args)

    def eval_and_jactheta_and_func(
        self, func: Callable, *args: TYPE_ARGS
    ) -> torch.Tensor:
        """Evaluate the Jacobian of a given function.

        With respect to the network parameters.

        Args:
            func: The function whose Jacobian is to be computed. It is called as
                ``func(f_i, *args)`` where ``f_i`` returns the i-th row of the
                network's parameter Jacobian.
            *args: Arguments to be passed to the function, with the last
                argument being the parameters of the network.

        Returns:
            The Jacobian matrix of the function with respect to the network
            parameters, stacked along dim 0 (one entry per parameter).
        """
        nbparams = self.ndof
        return torch.stack(
            [func(_element(i, self.eval_and_jactheta), *args) for i in range(nbparams)],
            0,
        )

    def _regularized_gram_matrix(self, Phi: torch.Tensor, N: int) -> torch.Tensor:
        """Return ``einsum("ijk,ilk->jl", Phi, Phi) / N + reg * I``.

        Args:
            Phi: A 3-D tensor with ``self.ndof`` entries along dim 1.
            N: Normalization factor (number of points).

        Returns:
            A ``(ndof, ndof)`` matrix. The identity is created with ``Phi``'s
            dtype and device so the sum also works off-CPU.
        """
        identity = torch.eye(self.ndof, dtype=Phi.dtype, device=Phi.device)
        gram = torch.einsum("ijk,ilk->jl", Phi, Phi) / N
        return gram + self.matrix_regularization * identity

    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; ``args[0].shape[0]`` is used as the
                number of points.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        )
        return self._regularized_gram_matrix(Phi, N)

    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; ``args[0].shape[0]`` is used as the
                number of points.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        )
        return self._regularized_gram_matrix(Phi, N)

    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments; ``args[0].shape[0]`` is used as the
                number of points.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        )
        return self._regularized_gram_matrix(Phi, N)

    def compute_full_preconditioning_matrix(
        self, data: tuple | dict, **kwargs
    ) -> torch.Tensor:
        """Compute the full preconditioning matrix.

        Include contributions from the main operator, boundary conditions, and
        initial conditions (weighted by ``bc_weight`` / ``ic_weight``).

        Args:
            data: Input data, either as a tuple or a dictionary.
            **kwargs: Additional keyword arguments.

        Returns:
            The full preconditioning matrix.
        """
        M = self.get_preconditioning_matrix(data, **kwargs)
        # Out-of-place additions: never mutate the tensor returned by the getter.
        if self.has_bc:
            M = M + self.bc_weight * self.get_preconditioning_matrix_bc(data, **kwargs)
        if self.has_ic:
            M = M + self.ic_weight * self.get_preconditioning_matrix_ic(data, **kwargs)
        return M

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the preconditioner to the input gradients.

        Args:
            epoch: Current training epoch.
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals.
            res_r: Right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor, i.e. the least-squares solution
            of ``M x = grads`` with ``M`` the full preconditioning matrix.
        """
        M = self.compute_full_preconditioning_matrix(data, **kwargs)
        return torch.linalg.lstsq(M, grads).solution
class AnagramPreconditioner(MatrixPreconditionerPinn):
    """Anagram preconditioner.

    This preconditioner is based on the anagram method, which aims to improve
    convergence by transforming the problem into a more favorable form.

    The parameter Jacobians of the PDE operator (and of the boundary / initial
    condition operators when present) are stacked into one matrix ``phi``. The
    returned update is ``pinv(phi) @ (res_l - res_r)``, where the
    pseudo-inverse is computed by an SVD that keeps only singular values
    strictly greater than ``svd_threshold`` (an absolute threshold).

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments:

            - `svd_threshold` (:code:`float`): Threshold for singular value
              decomposition (default: 1e-6).
            - `bc_weight` (:code:`float`): Weight for boundary condition
              contributions (default: 1.0).
            - `ic_weight` (:code:`float`): Weight for initial condition
              contributions (default: 1.0).

    Raises:
        AttributeError: If the `residual_size`, `bc_residual_size`, or
            `ic_residual_size` attributes are not integers.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        self.svd_threshold = kwargs.get("svd_threshold", 1e-6)

        # Number of residual components of each part of the problem; read from
        # the pde (or pde.space_component), 1 with a warning when absent.
        self.residual_size: int = self._read_residual_size(pde, "residual_size")
        self.bc_residual_size: int = 1
        self.ic_residual_size: int = 1
        if self.has_bc:
            self.bc_residual_size = self._read_residual_size(
                pde, "bc_residual_size"
            )
        if self.has_ic:
            # The initial condition size is only looked up on the pde itself.
            self.ic_residual_size = self._read_residual_size(
                pde, "ic_residual_size", look_in_space_component=False
            )

        self.Phi = self._jacobian_dict_of_func(self.operator)
        self.vectorized_Phi = self.vectorize_along_physical_variables(self.Phi)

        self.vectorized_Phi_bc: None | TYPE_DICT_OF_VMAPS = None
        self.vectorized_Phi_ic: None | TYPE_DICT_OF_VMAPS = None

        if self.has_bc:
            self.operator_bc = cast(FunctionalOperator, self.operator_bc)
            self.Phi_bc = self._jacobian_dict_of_func(self.operator_bc)
            self.vectorized_Phi_bc = self.vectorize_along_physical_variables_bc(
                self.Phi_bc
            )

        if self.has_ic:
            self.operator_ic = cast(FunctionalOperator, self.operator_ic)
            self.Phi_ic = self._jacobian_dict_of_func(self.operator_ic)
            self.vectorized_Phi_ic = self.vectorize_along_physical_variables_ic(
                self.Phi_ic
            )

    @staticmethod
    def _read_residual_size(
        pde: ACCEPTED_PDE_TYPES, name: str, look_in_space_component: bool = True
    ) -> int:
        """Read an integer residual-size attribute from ``pde``.

        Args:
            pde: The PDE.
            name: Name of the attribute to read.
            look_in_space_component: Whether to fall back on
                ``pde.space_component`` when ``pde`` lacks the attribute.

        Returns:
            The attribute value, or 1 (with a UserWarning) when not found.

        Raises:
            AttributeError: If the attribute exists but is not an integer.
        """
        warning_message = (
            "input pde or pde.space_component does not have a %s "
            "attribute; 1 used instead"
        )
        error_message = (
            "attribute %s of input pde or pde.space_component must be an integer"
        )
        if hasattr(pde, name):
            holder = pde
        elif (
            look_in_space_component
            and hasattr(pde, "space_component")
            and hasattr(pde.space_component, name)
        ):
            holder = pde.space_component
        else:
            warnings.warn(warning_message % name, UserWarning)
            return 1

        value = getattr(holder, name)
        if not isinstance(value, int):
            raise AttributeError(error_message % name)
        return value

    def _jacobian_dict_of_func(self, operator: FunctionalOperator):
        """Differentiate ``operator`` applied to the network w.r.t. theta."""
        return operator.apply_func_to_dict_of_func(
            _transpose_i_j(-1, -2, _mjactheta),
            operator.apply_to_func(self.eval_func),
        )

    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix (with all size-1 dimensions squeezed).
        """
        theta = self.get_formatted_current_theta()
        # NOTE(review): squeeze() also drops the point / parameter dimensions
        # when they have size 1 — confirm this cannot happen in practice.
        return self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        ).squeeze()

    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix (squeezed).
        """
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        return self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        ).squeeze()

    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix (squeezed).
        """
        theta = self.get_formatted_current_theta()
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        return self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        ).squeeze()

    @staticmethod
    def _assemble_right_member_for(
        operator: FunctionalOperator,
        residual_size: int,
        get_args: Callable,
        data: tuple | dict,
        res_l: tuple,
        res_r: tuple,
    ) -> torch.Tensor:
        """Concatenate the residual differences ``res_l - res_r`` into one vector.

        With a single label, the differences are simply concatenated. Otherwise
        the differences appear to be ordered label by label, each label giving
        ``residual_size // nb_labels`` consecutive components. They are
        regrouped per component, merged across labels with
        ``operator.cat_tuple_of_tensors`` and the components concatenated.

        Args:
            operator: The functional operator (main, bc or ic).
            residual_size: Total number of residual entries for this operator.
            get_args: Callable returning the operator arguments from ``data``;
                only called when there are several labels.
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor.

        Raises:
            ValueError: If ``residual_size`` is smaller than the number of
                labels.
        """
        diffs = tuple(left - right for left, right in zip(res_l, res_r))
        nb_labels = len(operator.dict_of_operators)
        if nb_labels == 1:
            return torch.cat(diffs, dim=0)

        nb_comp_per_label = residual_size // nb_labels
        if nb_comp_per_label < 1:
            raise ValueError(
                f"residual size ({residual_size}) is smaller than the number of "
                f"labels ({nb_labels})"
            )
        per_component = tuple(
            tuple(diffs[i + j] for i in range(0, len(diffs), nb_comp_per_label))
            for j in range(nb_comp_per_label)
        )
        args = get_args(data)
        return torch.cat(
            tuple(
                operator.cat_tuple_of_tensors(res, args[0], args[1])
                for res in per_component
            ),
            dim=0,
        )

    def assemble_right_member(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side of the equation.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor.
        """
        return self._assemble_right_member_for(
            self.operator,
            self.residual_size,
            self.get_args_for_operator,
            data,
            res_l,
            res_r,
        )

    def assemble_right_member_bc(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side for boundary conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor for boundary conditions.
        """
        return self._assemble_right_member_for(
            cast(FunctionalOperator, self.operator_bc),
            self.bc_residual_size,
            self.get_args_for_operator_bc,
            data,
            res_l,
            res_r,
        )

    def assemble_right_member_ic(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side for initial conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor for initial conditions.
        """
        return self._assemble_right_member_for(
            cast(FunctionalOperator, self.operator_ic),
            self.ic_residual_size,
            self.get_args_for_operator_ic,
            data,
            res_l,
            res_r,
        )

    @staticmethod
    def _flatten_components(phi: torch.Tensor) -> torch.Tensor:
        """Stack the slices of ``phi`` along dim 2 on top of each other (dim 0).

        Tensors with at most two dimensions are returned unchanged.
        """
        if phi.ndim > 2:
            phi = torch.cat(phi.unbind(dim=2), dim=0)
        return phi

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the Anagram preconditioner to the input gradients.

        Args:
            epoch: Current training epoch.
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned (not used: the update
                is recomputed from the residuals).
            res_l: Left residuals, ordered main / bc / ic.
            res_r: Right residuals, ordered main / bc / ic.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor ``pinv(phi) @ (res_l - res_r)``.
        """
        phi = self._flatten_components(self.get_preconditioning_matrix(data, **kwargs))
        if self.has_bc:
            phib = self._flatten_components(
                self.get_preconditioning_matrix_bc(data, **kwargs)
            )
            phi = torch.cat((phi, phib), dim=0)
        if self.has_ic:
            phii = self._flatten_components(
                self.get_preconditioning_matrix_ic(data, **kwargs)
            )
            phi = torch.cat((phi, phii), dim=0)

        # Pseudo-inverse via a truncated SVD (absolute threshold).
        U, Delta, Vt = torch.linalg.svd(phi, full_matrices=False)
        mask = Delta > self.svd_threshold
        Delta_inv = torch.zeros_like(Delta)
        Delta_inv[mask] = 1.0 / Delta[mask]
        # Same as Vt.T @ diag(Delta_inv) @ U.T, without building the diagonal.
        phi_plus = Vt.T @ (Delta_inv[:, None] * U.T)

        # Residuals are sliced in the same main / bc / ic order as phi.
        end = self.residual_size
        res = self.assemble_right_member(data, res_l[:end], res_r[:end])
        if self.has_bc:
            begin, end = end, end + self.bc_residual_size
            resb = self.assemble_right_member_bc(
                data, res_l[begin:end], res_r[begin:end]
            )
            res = torch.cat((res, resb), dim=0)
        if self.has_ic:
            begin, end = end, end + self.ic_residual_size
            resi = self.assemble_right_member_ic(
                data, res_l[begin:end], res_r[begin:end]
            )
            res = torch.cat((res, resi), dim=0)

        return phi_plus @ res