Source code for scimba_torch.numerical_solvers.pinn_preconditioners.anagram_ng

"""Preconditioner for pinns."""

from typing import cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    ACCEPTED_PDE_TYPES,
    TYPE_DICT_OF_VMAPS,
    FunctionalOperator,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn


class AnagramPreconditioner(MatrixPreconditionerPinn):
    """Anagram preconditioner.

    This preconditioner is based on the Anagram method, which aims to improve
    convergence by transforming the problem into a more favorable form.

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments.

    Keyword Args:
        svd_threshold (:code:`float`): Singular values below this threshold
            are discarded when computing the pseudo-inverse (default: 1e-6).
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        self.svd_threshold = kwargs.get("svd_threshold", 1e-6)
        # The residual-size computations (residual_size, bc_residual_size,
        # ic_residual_size) that used to live here are now handled in
        # MatrixPreconditionerPinn.
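    # A minimal usage sketch (hypothetical names: ``my_space`` and ``my_pde``
    # stand for a concrete AbstractApproxSpace and an accepted PDE instance;
    # ``epoch``, ``data``, ``grads``, ``res_l`` and ``res_r`` come from the
    # surrounding training loop):
    #
    #     precond = AnagramPreconditioner(my_space, my_pde, svd_threshold=1e-5)
    #     new_grads = precond(epoch, data, grads, res_l, res_r)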
    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        return self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        ).squeeze()
    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        return self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        ).squeeze()
    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        return self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        ).squeeze()
    def compute_preconditioning_matrix_dl(
        self, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the Gram matrix of the network for the given input tensors.

        Args:
            *args: Input tensors for computing the Gram matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed Gram matrix.
        """
        return self.space.jacobian(*args).squeeze(-1)
    def assemble_right_member(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side of the equation.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor.
        """
        in_tup = tuple(left - right for left, right in zip(res_l, res_r))
        if len(self.operator.dict_of_operators) == 1:
            return torch.cat(in_tup, dim=0)
        # concatenate components:
        nb_comp_per_label = self.residual_size // len(self.operator.flatten_keys)
        in_tup = tuple(
            tuple(in_tup[i + j] for i in range(0, len(in_tup), nb_comp_per_label))
            for j in range(nb_comp_per_label)
        )
        args = self.get_args_for_operator(data)
        out_tup = tuple(
            self.operator.cat_tuple_of_tensors_along_flatten_keys(
                res, args[0], args[1]
            )
            for res in in_tup
        )
        return torch.cat(out_tup, dim=0)
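    # Illustration of the regrouping above (hypothetical sizes): with two
    # labels and two components per label, in_tup == (a0, a1, b0, b1), where
    # ``a``/``b`` index labels and ``0``/``1`` index components, is regrouped
    # into ((a0, b0), (a1, b1)): one inner tuple per component, each collecting
    # that component across all labels before concatenation along the
    # flatten keys.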
    def assemble_right_member_bc(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side for boundary conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor for boundary conditions.
        """
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        in_tup = tuple(left - right for left, right in zip(res_l, res_r))
        if len(self.operator_bc.dict_of_operators) == 1:
            return torch.cat(in_tup, dim=0)
        # concatenate components:
        nb_comp_per_label = self.bc_residual_size // len(self.operator_bc.flatten_keys)
        in_tup = tuple(
            tuple(in_tup[i + j] for i in range(0, len(in_tup), nb_comp_per_label))
            for j in range(nb_comp_per_label)
        )
        args = self.get_args_for_operator_bc(data)
        out_tup = tuple(
            self.operator_bc.cat_tuple_of_tensors_along_flatten_keys(
                res, args[0], args[1]
            )
            for res in in_tup
        )
        return torch.cat(out_tup, dim=0)
    def assemble_right_member_ic(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side for initial conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor for initial conditions.
        """
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        in_tup = tuple(left - right for left, right in zip(res_l, res_r))
        if len(self.operator_ic.dict_of_operators) == 1:
            return torch.cat(in_tup, dim=0)
        # concatenate components:
        nb_comp_per_label = self.ic_residual_size // len(self.operator_ic.flatten_keys)
        in_tup = tuple(
            tuple(in_tup[i + j] for i in range(0, len(in_tup), nb_comp_per_label))
            for j in range(nb_comp_per_label)
        )
        args = self.get_args_for_operator_ic(data)
        out_tup = tuple(
            self.operator_ic.cat_tuple_of_tensors_along_flatten_keys(
                res, args[0], args[1]
            )
            for res in in_tup
        )
        return torch.cat(out_tup, dim=0)
    def assemble_right_member_dl(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right member for the preconditioning.

        Args:
            data: The input data for assembling the right member.
            res_l: The left residuals.
            res_r: The right residuals.

        Returns:
            The assembled right member.
        """
        in_tup = tuple(left - right for left, right in zip(res_l, res_r))
        in_tup_cat = tuple(
            torch.cat(tuple(a[:, i : i + 1] for i in range(a.shape[1])), dim=0)
            for a in in_tup
        )
        return torch.cat(in_tup_cat, dim=0)
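    # Note: for a tensor ``a`` of shape (n, c), the column-wise concatenation
    # above is equivalent to ``a.T.reshape(-1, 1)``: all rows of column 0,
    # then all rows of column 1, and so on.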
    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the Anagram preconditioner to the input gradients.

        Args:
            epoch: Current training epoch.
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals.
            res_r: Right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor.
        """
        phi = self.get_preconditioning_matrix(data, **kwargs)
        if phi.ndim > 2:
            phi = torch.cat(
                tuple(phi[:, :, i, ...] for i in range(phi.shape[2])), dim=0
            )

        if self.has_bc:
            phib = self.get_preconditioning_matrix_bc(data, **kwargs)
            if phib.ndim > 2:
                phib = torch.cat(
                    tuple(phib[:, :, i, ...] for i in range(phib.shape[2])), dim=0
                )
            phi = torch.cat((phi, phib), dim=0)

        if self.has_ic:
            phii = self.get_preconditioning_matrix_ic(data, **kwargs)
            if phii.ndim > 2:
                phii = torch.cat(
                    tuple(phii[:, :, i, ...] for i in range(phii.shape[2])), dim=0
                )
            phi = torch.cat((phi, phii), dim=0)

        for index in range(self.nb_dl):
            phid = self.get_preconditioning_matrix_dl(
                self.args_for_dl[index], **kwargs
            )
            if phid.ndim > 2:
                phid = torch.cat(
                    tuple(phid[:, :, i, ...] for i in range(phid.shape[2])), dim=0
                )
            phi = torch.cat((phi, phid), dim=0)

        # compute the pseudo-inverse of phi via a truncated SVD
        U, Delta, Vt = torch.linalg.svd(phi, full_matrices=False)
        mask = Delta > self.svd_threshold  # keep only values greater than threshold
        Delta_inv = torch.zeros_like(Delta)
        Delta_inv[mask] = 1.0 / Delta[mask]
        phi_plus = Vt.T @ (Delta_inv[:, None] * U.T)

        # assemble the right-hand side, walking through the residual tuples
        nextindex = 0
        begin = nextindex
        length = self.residual_size
        end = nextindex + length
        res = self.assemble_right_member(data, res_l[begin:end], res_r[begin:end])
        nextindex += length

        if self.has_bc:
            self.operator_bc = cast(FunctionalOperator, self.operator_bc)
            begin = nextindex
            length = self.bc_residual_size
            end = nextindex + length
            resb = self.assemble_right_member_bc(
                data, res_l[begin:end], res_r[begin:end]
            )
            res = torch.cat((res, resb), dim=0)
            nextindex += length

        if self.has_ic:
            self.operator_ic = cast(FunctionalOperator, self.operator_ic)
            begin = nextindex
            length = self.ic_residual_size
            end = nextindex + length
            resi = self.assemble_right_member_ic(
                data, res_l[begin:end], res_r[begin:end]
            )
            res = torch.cat((res, resi), dim=0)
            nextindex += length

        for index in range(self.nb_dl):
            begin = nextindex
            length = 1
            end = nextindex + length
            resd = self.assemble_right_member_dl(
                data, res_l[begin:end], res_r[begin:end]
            )
            res = torch.cat((res, resd), dim=0)
            nextindex += length

        preconditioned_grads = phi_plus @ res
        return preconditioned_grads
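
# A minimal, self-contained sketch of the truncated-SVD pseudo-inverse used in
# ``__call__`` above (illustration only: the shapes and the 1e-6 threshold are
# assumptions, not values taken from a real solve):
if __name__ == "__main__":
    phi = torch.randn(8, 5)  # stands in for the assembled preconditioning matrix
    U, Delta, Vt = torch.linalg.svd(phi, full_matrices=False)
    mask = Delta > 1e-6  # drop near-zero singular values for stability
    Delta_inv = torch.zeros_like(Delta)
    Delta_inv[mask] = 1.0 / Delta[mask]
    phi_plus = Vt.T @ (Delta_inv[:, None] * U.T)  # Moore-Penrose pseudo-inverse
    res = torch.randn(8)  # stands in for the assembled right-hand side
    print((phi_plus @ res).shape)  # preconditioned update, shape (5,)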