Source code for scimba_torch.numerical_solvers.pinn_preconditioners.sketchy_ng

"""Preconditioner for pinns."""

from typing import cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    ACCEPTED_PDE_TYPES,
    TYPE_DICT_OF_VMAPS,
    FunctionalOperator,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn

# def _diagonalize(matrix: torch.Tensor) -> torch.Tensor:
#     mask = torch.triu(torch.ones_like(matrix, dtype=torch.bool), diagonal=1)
#     return matrix * ~mask + matrix.T * mask


class SketchyNaturalGradientPreconditioner(MatrixPreconditionerPinn):
    """Sketchy natural gradient preconditioner.

    Implements [McKay et al. 2025].

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments.

    Keyword Args:
        `tol` (:code:`float`): Threshold used to decide the size of the sketch
            matrix, r in the paper (default: 1e-13).
        `single_pass` (:code:`bool`): Use the single-pass variant instead of the
            two-pass variant (default: False, i.e. two-pass).

    [McKay et al. 2025] Near-optimal Sketchy Natural Gradients for
    Physics-Informed Neural Networks, M. B. McKay, A. Kaur, C. Greif,
    B. Wetton. Proceedings of the 42nd International Conference on Machine
    Learning, Vancouver, Canada. PMLR 267, 2025.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        self.tol = kwargs.get("tol", 1e-13)
        # t: number of degrees of freedom (trainable parameters) of the space
        self.t = self.ndof
        # p: oversampling size, r: initial sketch rank (updated at each call)
        self.p = 10
        # self.p = int(self.t // 10)
        # self.r = self.t
        self.r = int(self.t // 10)
        # self.r = int(self.t // 2)
        # self.eps = 1e-16
        self.single_pass = kwargs.get("single_pass", False)
        # self.diagonalize = kwargs.get("diagonalize", False)
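
    # Illustrative note (not from the original source): with the defaults
    # above, a network with t = 1000 trainable parameters gives an initial
    # sketch rank r = t // 10 = 100 and oversampling p = 10, so the random
    # test matrix M drawn in __call__ has shape (t, min(t, p + r), 1),
    # i.e. (1000, 110, 1). The rank r is then refreshed at every call from
    # the eigenvalues of the sketched Gram matrix whose magnitude exceeds
    # tol times the largest eigenvalue.
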
    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        )
        if len(self.in_weights) == 1:
            # apply the same weight to all labels
            for key in self.in_weights:  # dummy loop
                Phi[:, :, :] *= self.in_weights[key]
        else:
            # apply a weight per label
            for key in self.in_weights:
                Phi[labels == key, :, :] *= self.in_weights[key]
        return Phi
    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        )
        if len(self.bc_weights) == 1:
            # apply the same weight to all labels
            for key in self.bc_weights:  # dummy loop
                Phi[:, :, :] *= self.bc_weights[key]
        else:
            # apply a weight per label
            for key in self.bc_weights:
                Phi[labels == key, :, :] *= self.bc_weights[key]
        return Phi
    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        )
        if len(self.ic_weights) == 1:
            # apply the same weight to all labels
            for key in self.ic_weights:  # dummy loop
                Phi[:, :, :] *= self.ic_weights[key]
        else:
            # apply a weight per label
            for key in self.ic_weights:
                Phi[labels == key, :, :] *= self.ic_weights[key]
        return Phi
    def compute_preconditioning_matrix_dl(
        self, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Computes the Gram matrix of the network for the given input tensors.

        Args:
            *args: Input tensors for computing the Gram matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed Gram matrix.
        """
        return self.space.jacobian(*args)
    def _fast_gram_mul(
        self,
        phi: torch.Tensor,
        m: torch.Tensor,
        phib: torch.Tensor | None = None,
        phii: torch.Tensor | None = None,
        phil: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        # Compute G @ m without assembling the (t, t) Gram matrix G, where G is
        # the weighted sum of (1 / N_block) * Phi_block^T Phi_block over the
        # residual, boundary, initial and data blocks. Each Phi block has shape
        # (n_points, n_dof, k); m has shape (n_dof, sketch_size, 1) and is
        # broadcast over the trailing dimension k.
        if phil is None:  # avoid a mutable default argument
            phil = []
        N = phi.shape[0]
        MtPhi = torch.einsum("jik,ljk->ilk", m, phi)
        res = (self.in_weight / N) * torch.einsum("jik,ljk->il", phi, MtPhi)
        if phib is not None:
            Nb = phib.shape[0]
            MtPhib = torch.einsum("jik,ljk->ilk", m, phib)
            res += (self.bc_weight / Nb) * torch.einsum("jik,ljk->il", phib, MtPhib)
        if phii is not None:
            Ni = phii.shape[0]
            MtPhii = torch.einsum("jik,ljk->ilk", m, phii)
            res += (self.ic_weight / Ni) * torch.einsum("jik,ljk->il", phii, MtPhii)
        for index, coeff in enumerate(self.dl_weights):
            Nl = phil[index].shape[0]
            MtPhil = torch.einsum("jik,ljk->ilk", m, phil[index])
            res += (coeff / Nl) * torch.einsum("jik,ljk->il", phil[index], MtPhil)
        return res

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the preconditioner to the input gradients.

        Args:
            epoch: Current training epoch.
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals.
            res_r: Right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor.
        """
        # Assemble the per-block tensors that enter the Gram matrix.
        Phi = self.get_preconditioning_matrix(data, **kwargs)
        Phib = None
        Phii = None
        if self.has_bc:
            Phib = self.get_preconditioning_matrix_bc(data, **kwargs)
        if self.has_ic:
            Phii = self.get_preconditioning_matrix_ic(data, **kwargs)
        Phil = [
            self.get_preconditioning_matrix_dl(args_dl, **kwargs)
            for args_dl in self.args_for_dl
        ]

        # Gaussian test matrix; the trailing dimension of size 1 is broadcast
        # over the output dimension of the Phi tensors in _fast_gram_mul.
        M = torch.normal(
            1.0,
            1.0,
            size=(self.t, min(self.t, self.p + self.r), 1),
            device=torch.get_default_device(),
        )
        # Sketch of the Gram matrix, A = G M, and orthonormal basis Q of its range.
        A = self._fast_gram_mul(Phi, M, Phib, Phii, Phil)
        Q, _ = torch.linalg.qr(A, mode="reduced")
        Qt = torch.transpose(Q, 0, 1)
        if self.single_pass:
            # Single-pass estimate: T = (Q^T A)(Q^T M)^{-1}.
            QtM = torch.einsum("ij,jkl->ik", Qt, M)
            T = (Qt @ A) @ torch.linalg.inv(QtM)
        else:
            # Two-pass estimate: T = Q^T G Q (last dimension is broadcast).
            T = Qt @ self._fast_gram_mul(Phi, Q[:, :, None], Phib, Phii, Phil)
        # if self.diagonalize:
        #     T = _diagonalize(T)

        # Eigendecomposition of the small matrix T; keep the eigenvalues whose
        # magnitude exceeds tol times the largest one and update the rank r.
        Eva, S = torch.linalg.eig(T)
        Eva = Eva.real
        S = S.real
        mEva = torch.max(Eva).item()
        mask_r = torch.abs(Eva) > self.tol * mEva
        # mask_r = torch.abs(Eva) > self.tol
        nr = int(torch.sum(mask_r).item())
        self.r = nr

        # Apply the truncated pseudo-inverse U diag(1 / Eva) U^T to the gradient.
        invEva = torch.zeros_like(Eva, device=torch.get_default_device())
        invEva[mask_r] = 1.0 / Eva[mask_r]
        invEva = torch.diag(invEva)
        U = Q @ S
        preconditioned_grads = (U @ invEva @ U.T) @ grads
        return preconditioned_grads
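

# ---------------------------------------------------------------------------
# Illustrative example (not part of the original module): a minimal,
# self-contained sketch of the two-pass randomized preconditioning performed
# in SketchyNaturalGradientPreconditioner.__call__, written for a dense Gram
# matrix G instead of the matrix-free products of _fast_gram_mul. The helper
# name _sketchy_preconditioned_gradient_demo and all sizes used below are
# assumptions chosen for illustration only.
# ---------------------------------------------------------------------------
def _sketchy_preconditioned_gradient_demo(
    G: torch.Tensor, grads: torch.Tensor, r: int, p: int = 10, tol: float = 1e-13
) -> torch.Tensor:
    """Return an approximation of pinv(G) @ grads via a rank-(r + p) sketch."""
    t = G.shape[0]
    s = min(t, r + p)
    # Gaussian test matrix and sketch of the Gram matrix: A = G M.
    M = torch.normal(1.0, 1.0, size=(t, s))
    A = G @ M
    # Orthonormal basis of the sketch range (thin QR), then the small
    # projected matrix T = Q^T G Q (two-pass variant).
    Q, _ = torch.linalg.qr(A, mode="reduced")
    T = Q.T @ (G @ Q)
    # Eigendecomposition of T; keep eigenvalues above tol * max(eigenvalue).
    Eva, S = torch.linalg.eig(T)
    Eva, S = Eva.real, S.real
    mask = torch.abs(Eva) > tol * torch.max(Eva)
    inv_eva = torch.zeros_like(Eva)
    inv_eva[mask] = 1.0 / Eva[mask]
    # Low-rank approximate pseudo-inverse applied to the gradient.
    U = Q @ S
    return U @ torch.diag(inv_eva) @ U.T @ grads


# Example usage on a synthetic symmetric positive semi-definite Gram matrix.
if __name__ == "__main__":
    torch.manual_seed(0)
    J = torch.randn(200, 50)  # stand-in for a Jacobian (N points, t dofs)
    G = J.T @ J / J.shape[0]  # Gram matrix, shape (t, t)
    g = torch.randn(50)  # stand-in for the gradient
    precond_g = _sketchy_preconditioned_gradient_demo(G, g, r=20)
    print(precond_g.shape)  # torch.Size([50])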