"""Preconditioner for pinns."""
from typing import cast
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
ACCEPTED_PDE_TYPES,
TYPE_DICT_OF_VMAPS,
FunctionalOperator,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn
# def _diagonalize(matrix: torch.Tensor) -> torch.Tensor:
# mask = torch.triu(torch.ones_like(matrix, dtype=torch.bool), diagonal=1)
# return matrix * ~mask + matrix.T * mask
class SketchyNaturalGradientPreconditioner(MatrixPreconditionerPinn):
"""Sketchy natural gradient preconditioner.
Implements [McKay and al. 2025].
Args:
space: The approximation space.
pde: The PDE to be solved, which can be an instance of EllipticPDE,
TemporalPDE, KineticPDE, or LinearOrder2PDE.
**kwargs: Additional keyword arguments:
Keyword Args:
`tol` (:code:`float`): The threshold for deciding the size of the sketch matrix,
r in the paper. (default: 1e-13).
`single_pass` (:code:`bool`): single pass or two pass
(default False, means two pass).
[McKay and al. 2025] Near-optimal Sketchy Natural Gradients
for Physics-Informed Neural Networks, M. B. McKay, A. Kaur, C. Greif, B. Wetton
Proceedings of the 42 nd International Conference on Machine
Learning, Vancouver, Canada. PMLR 267, 2025
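
    Notes:
        Informal summary of the implementation below (see the paper for the precise
        algorithm and its guarantees): the weighted Gram matrix G of the residual,
        boundary, initial and data-loss Jacobians is never assembled. Instead, a
        random sketch A = G @ M with M of size t x min(t, p + r) is computed
        matrix-free, a thin QR factorization A = Q R is taken, and the small matrix
        T (Q^T G Q in the two-pass variant, (Q^T A)(Q^T M)^{-1} in the single-pass
        one) is eigendecomposed. Eigenvalues below ``tol`` times the largest one are
        discarded, r is updated to the number of kept eigenvalues, and the gradient
        is multiplied by the low-rank pseudo-inverse U diag(1/lambda) U^T with
        U = Q S.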
"""
def __init__(
self,
space: AbstractApproxSpace,
pde: ACCEPTED_PDE_TYPES,
**kwargs,
):
super().__init__(space, pde, **kwargs)
self.tol = kwargs.get("tol", 1e-13)
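        # Sketch sizes (following the paper's notation): t is the number of trainable
        # parameters, r the current sketch rank (updated after each call from the
        # eigenvalue threshold), and p an extra oversampling margin added to r when
        # drawing the random sketch matrix.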
self.t = self.ndof
self.p = 10
# self.p = int(self.t//10)
# self.r = self.t
self.r = int(self.t // 10)
# self.r = int(self.t // 2)
# self.eps = 1e-16
self.single_pass = kwargs.get("single_pass", False)
# self.diagonalize = kwargs.get("diagonalize", False)
# print("tol: %d, t: %d, r: %d, p: %d" % (self.tol, self.t, self.r, self.p))
def compute_preconditioning_matrix(
self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
) -> torch.Tensor:
"""Compute the preconditioning matrix using the main operator.
Args:
labels: The labels tensor.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
The preconditioning matrix.
"""
theta = self.get_formatted_current_theta()
Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
self.vectorized_Phi, theta, labels, *args
)
        if len(self.in_weights) == 1:  # apply the same weight to all labels
            for key in self.in_weights:  # single-entry loop
                Phi[:, :, :] *= self.in_weights[key]
        else:  # apply a different weight to each label
            for key in self.in_weights:
                Phi[labels == key, :, :] *= self.in_weights[key]
return Phi
def compute_preconditioning_matrix_bc(
self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
) -> torch.Tensor:
"""Compute the boundary condition preconditioning matrix.
Args:
labels: The labels tensor.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
The boundary condition preconditioning matrix.
"""
theta = self.get_formatted_current_theta()
self.operator_bc = cast(FunctionalOperator, self.operator_bc)
self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
self.vectorized_Phi_bc, theta, labels, *args
)
        if len(self.bc_weights) == 1:  # apply the same weight to all labels
            for key in self.bc_weights:  # single-entry loop
                Phi[:, :, :] *= self.bc_weights[key]
        else:  # apply a different weight to each label
            for key in self.bc_weights:
                Phi[labels == key, :, :] *= self.bc_weights[key]
return Phi
def compute_preconditioning_matrix_ic(
self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
) -> torch.Tensor:
"""Compute the initial condition preconditioning matrix.
Args:
labels: The labels tensor.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
The initial condition preconditioning matrix.
"""
theta = self.get_formatted_current_theta()
self.operator_ic = cast(FunctionalOperator, self.operator_ic)
self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
self.vectorized_Phi_ic, theta, labels, *args
)
        if len(self.ic_weights) == 1:  # apply the same weight to all labels
            for key in self.ic_weights:  # single-entry loop
                Phi[:, :, :] *= self.ic_weights[key]
        else:  # apply a different weight to each label
            for key in self.ic_weights:
                Phi[labels == key, :, :] *= self.ic_weights[key]
return Phi
def compute_preconditioning_matrix_dl(
self, *args: torch.Tensor, **kwargs
) -> torch.Tensor:
"""Computes the Gram matrix of the network for the given input tensors.
Args:
*args: Input tensors for computing the Gram matrix.
**kwargs: Additional keyword arguments.
Returns:
The computed Gram matrix.
"""
return self.space.jacobian(*args)
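    # None of the Jacobian blocks returned by the methods above are assembled into a
    # full t x t Gram matrix; `_fast_gram_mul` below only applies their weighted Gram
    # operator to a (small) sketch matrix.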
    def _fast_gram_mul(
        self,
        phi: torch.Tensor,
        m: torch.Tensor,
        phib: torch.Tensor | None = None,
        phii: torch.Tensor | None = None,
        phil: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Apply the weighted Gram operator to the sketch matrix m.

        Computes sum_i (w_i / N_i) Phi_i^T Phi_i m over the residual, boundary,
        initial and data-loss blocks without forming the Gram matrix itself. Each
        block Phi_i is expected to have shape (n_points, n_parameters, n_components)
        and m shape (n_parameters, n_sketch, 1); the trailing singleton dimension of
        m is broadcast over the components.
        """
        if phil is None:  # avoid a mutable default argument
            phil = []
        N = phi.shape[0]
        # residual block: (w / N) * Phi^T (Phi m), contracted per output component
        MtPhi = torch.einsum("jik,ljk->ilk", m, phi)
        res = (self.in_weight / N) * torch.einsum("jik,ljk->il", phi, MtPhi)
        if phib is not None:  # boundary condition block
            Nb = phib.shape[0]
            MtPhib = torch.einsum("jik,ljk->ilk", m, phib)
            res += (self.bc_weight / Nb) * torch.einsum("jik,ljk->il", phib, MtPhib)
        if phii is not None:  # initial condition block
            Ni = phii.shape[0]
            MtPhii = torch.einsum("jik,ljk->ilk", m, phii)
            res += (self.ic_weight / Ni) * torch.einsum("jik,ljk->il", phii, MtPhii)
        for index, coeff in enumerate(self.dl_weights):  # data-loss blocks
            Nl = phil[index].shape[0]
            MtPhil = torch.einsum("jik,ljk->ilk", m, phil[index])
            res += (coeff / Nl) * torch.einsum("jik,ljk->il", phil[index], MtPhil)
        return res
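    # For reference (not executed): the residual block above equals the dense product
    #     (self.in_weight / N) * torch.einsum("lik,ljk->ij", phi, phi) @ m[:, :, 0],
    # i.e. the weighted Gram matrix of the per-point Jacobians applied to the sketch
    # matrix, without ever materialising the n_parameters x n_parameters Gram matrix.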
def __call__(
self,
epoch: int,
data: tuple | dict,
grads: torch.Tensor,
res_l: tuple,
res_r: tuple,
**kwargs,
) -> torch.Tensor:
"""Apply the preconditioner to the input gradients.
Args:
epoch: Current training epoch.
data: Input data, either as a tuple or a dictionary.
grads: Gradient tensor to be preconditioned.
res_l: Left residuals.
res_r: Right residuals.
**kwargs: Additional keyword arguments.
Returns:
The preconditioned gradient tensor.
"""
Phi = self.get_preconditioning_matrix(data, **kwargs)
Phib = None
Phii = None
if self.has_bc:
Phib = self.get_preconditioning_matrix_bc(data, **kwargs)
if self.has_ic:
Phii = self.get_preconditioning_matrix_ic(data, **kwargs)
        Phil = [
            self.get_preconditioning_matrix_dl(args_dl, **kwargs)
            for args_dl in self.args_for_dl  # avoid shadowing the `data` argument
        ]
        # Random Gaussian sketch matrix of size t x min(t, p + r); the trailing
        # singleton dimension is broadcast over the output components
        M = torch.normal(
            1.0,
            1.0,
            size=(self.t, min(self.t, self.p + self.r), 1),
            device=torch.get_default_device(),
        )
        # A = G @ M, computed without forming the Gram matrix G
        A = self._fast_gram_mul(Phi, M, Phib, Phii, Phil)
        # Thin QR factorization of the sketch; range(Q) approximates the dominant
        # eigenspace of the Gram matrix
        Q, _ = torch.linalg.qr(A, mode="reduced")
        Qt = torch.transpose(Q, 0, 1)
        if self.single_pass:
            # single-pass estimate: T = (Q^T A) (Q^T M)^{-1}, reusing the sketch A
            QtM = torch.einsum("ij,jkl->ik", Qt, M)
            T = (Qt @ A) @ torch.linalg.inv(QtM)
        else:
            # two-pass estimate: T = Q^T G Q, with a second matrix-free Gram product
            # (the trailing singleton dimension of Q[:, :, None] is broadcast)
            T = Qt @ self._fast_gram_mul(Phi, Q[:, :, None], Phib, Phii, Phil)
        # if self.diagonalize:
        #     T = _diagonalize(T)
        # Eigendecomposition of the small sketched matrix T; T approximates the
        # symmetric matrix Q^T G Q, so the spurious imaginary parts are discarded
        Eva, S = torch.linalg.eig(T)
        Eva = Eva.real
        S = S.real
        # keep the eigenvalues above the relative threshold and adapt the sketch
        # rank for the next call
        mEva = torch.max(Eva).item()
        mask_r = torch.abs(Eva) > self.tol * mEva
        nr = int(torch.sum(mask_r).item())
        self.r = nr
        # low-rank pseudo-inverse of the sketched Gram matrix:
        # G^+ ~= U diag(1 / Eva) U^T with U = Q S
        invEva = torch.zeros_like(Eva, device=torch.get_default_device())
        invEva[mask_r] = 1.0 / Eva[mask_r]
        invEva = torch.diag(invEva)
        U = Q @ S
        preconditioned_grads = (U @ invEva @ U.T) @ grads
        return preconditioned_grads
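

# A minimal, self-contained sketch of the randomized low-rank (two-pass, Nystrom-style)
# pseudo-inverse used in `SketchyNaturalGradientPreconditioner.__call__`, run on a
# small dense PSD matrix instead of the matrix-free Gram operator. Purely
# illustrative: the names and sizes below (`J`, `G`, `omega`, ...) are hypothetical
# and not part of the class API, float64 is used to keep round-off well below the
# threshold, and the symmetric `eigh` replaces the `eig` used on the sketched matrix.
if __name__ == "__main__":
    torch.manual_seed(0)
    n_pts, n_params, true_rank, sketch = 400, 200, 25, 40
    J = torch.randn(n_pts, true_rank, dtype=torch.float64) @ torch.randn(
        true_rank, n_params, dtype=torch.float64
    )
    G = J.T @ J / n_pts  # dense PSD "Gram" matrix of rank true_rank
    omega = torch.randn(n_params, sketch, dtype=torch.float64)  # random sketch matrix
    A = G @ omega  # first pass over G
    Q, _ = torch.linalg.qr(A, mode="reduced")
    T = Q.T @ (G @ Q)  # second pass: small matrix Q^T G Q
    eva, S = torch.linalg.eigh(T)
    keep = eva > 1e-13 * eva.max()  # same relative threshold as `tol`
    U = Q @ S[:, keep]
    G_pinv = U @ torch.diag(1.0 / eva[keep]) @ U.T  # low-rank pseudo-inverse of G
    err = torch.linalg.norm(G @ G_pinv @ G - G) / torch.linalg.norm(G)
    print(f"relative error of G @ G^+ @ G vs G: {err:.2e}")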