"""Preconditioner for pinns."""
import warnings
from collections import OrderedDict
from typing import Callable, cast
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
ACCEPTED_PDE_TYPES,
TYPE_DICT_OF_VMAPS,
FunctionalOperator,
)
from scimba_torch.numerical_solvers.pinn_preconditioners import (
EnergyNaturalGradientPreconditioner, # for debug mode
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn
def _cg(
g: Callable, b: torch.Tensor, rtol: float, maxit: int, pginv: Callable
) -> tuple[torch.Tensor, int]:
"""Matrix-free conjugate gradient with preconditioner for Gx=b.
Implements Algorithm 5.1 (page 10) of [Frangella et al. 2021].
Args:
g: The matrix-vector product function for G.
b: The right-hand side.
rtol: The relative residual tolerance used to stop the iterations.
maxit: The maximum number of iterations.
pginv: The matrix-vector product function for the preconditioner,
which should approximate the inverse of G.
Returns:
The approximate solution and the number of iterations performed.
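Example:
A minimal doctest-style sketch on a small symmetric positive-definite
system, using the identity as preconditioner (illustrative only):

>>> A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
>>> b = torch.tensor([1.0, 2.0])
>>> x, n_it = _cg(lambda v: A @ v, b, rtol=1e-6, maxit=20, pginv=lambda v: v)
>>> torch.allclose(A @ x, b, atol=1e-5)
True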
[Frangella et al. 2021] Randomized Nyström Preconditioning,
Z. Frangella, J. A. Tropp, M. Udell,
https://arxiv.org/abs/2110.02820
"""
with torch.no_grad():  # otherwise the computation graph becomes huge!
x = torch.zeros_like(b)
r0 = b - g(x)
nb = torch.linalg.norm(b)
z0 = pginv(r0)
p0 = z0
it = 0
while (it <= maxit) and (torch.linalg.norm(r0) > nb * rtol):
v = g(p0)
alpha = torch.dot(r0, z0) / torch.dot(p0, v)
x += alpha * p0
r = r0 - alpha * v
z = pginv(r)
beta = torch.dot(r, z) / torch.dot(r0, z0)
p0 = z + beta * p0
r0, z0 = r, z
it += 1
# print("it: ", it)
# if it <= maxit:
# return x, 0
# else:
# return x, it
return x, it
class NystromNaturalGradientPreconditioner(MatrixPreconditionerPinn):
r"""Randomized matrix-free natural gradient preconditioner.
Implements the Nyström natural gradient preconditioner of [Bioli et al. 2025].
Args:
space: The approximation space.
pde: The PDE to be solved, which can be an instance of EllipticPDE,
TemporalPDE, KineticPDE, or LinearOrder2PDE.
**kwargs: Additional keyword arguments:
Keyword Args:
`matrix_free` (:code:`bool`): Use jvp and vjp to compute the matrix-vector
product instead of assembling the :math:`\Phi` matrix. (default: False).
`eps` (:code:`float`): The eps for adaptive matrix regularization.
(default: :code:`torch.finfo(torch.get_default_dtype()).eps`).
`debug` (:code:`bool`): Debugging mode: check that the Gram matrices computed
by the different code paths are consistent. (default: False).
Raises:
NotImplementedError: If the matrix-free option is requested together with a
data loss (not yet implemented).
[Bioli et al. 2025] Accelerating Natural Gradient Descent for PINNs with
Randomized Nyström Preconditioning, I. Bioli, C. Marcati, G. Sangalli,
https://arxiv.org/abs/2505.11638v3
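Example:
A hypothetical construction, assuming ``space`` and ``pde`` have already been
built with the usual scimba_torch tools::

preconditioner = NystromNaturalGradientPreconditioner(
    space, pde, matrix_free=True
)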
"""
def __init__(
self,
space: AbstractApproxSpace,
pde: ACCEPTED_PDE_TYPES,
**kwargs,
):
super().__init__(space, pde, **kwargs)
self.p = self.ndof
default_eps = torch.finfo(torch.get_default_dtype()).eps
self.eps = kwargs.get("eps", default_eps)
self.ell = 10  # initial sketch size (target rank of the Nystrom approximation)
self.matrix_free = kwargs.get("matrix_free", False)
self.debug = kwargs.get("debug", False)
if self.matrix_free and (len(self.dl_weights) > 0):
raise NotImplementedError(
"Nyström Natural Gradient preconditioning "
"in matrix free mode with data loss is not yet implemented"
)
self.F = self.operator.apply_to_func(self.eval_func)
self.vectorized_F = self.vectorize_along_physical_variables(self.F)
if self.has_bc:
self.operator_bc = cast(FunctionalOperator, self.operator_bc)
self.F_bc = self.operator_bc.apply_to_func(self.eval_func)
self.vectorized_F_bc = self.vectorize_along_physical_variables_bc(self.F_bc)
if self.has_ic:
self.operator_ic = cast(FunctionalOperator, self.operator_ic)
self.F_ic = self.operator_ic.apply_to_func(self.eval_func)
self.vectorized_F_ic = self.vectorize_along_physical_variables_ic(self.F_ic)
self.ENGPrec: None | EnergyNaturalGradientPreconditioner = None
if self.debug:
warnings.warn(
"debug mode in NystromNaturalGradientPreconditioner; "
"might be both time and memory consuming",
UserWarning,
)
self.ENGPrec = EnergyNaturalGradientPreconditioner(
space, pde, matrix_regularization=0.0, **kwargs
)
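# pre-squared label weights: in the matrix-free Gram products each residual
# block is weighted once through J and once through J^T, i.e. by w^2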
self.sq_in_weights = OrderedDict(
[(key, self.in_weights[key] ** 2) for key in self.in_weights]
)
self.sq_bc_weights = OrderedDict()
if self.has_bc:
self.sq_bc_weights = OrderedDict(
[(key, self.bc_weights[key] ** 2) for key in self.bc_weights]
)
self.sq_ic_weights = OrderedDict()
if self.has_ic:
self.sq_ic_weights = OrderedDict(
[(key, self.ic_weights[key] ** 2) for key in self.ic_weights]
)
def compute_preconditioning_matrix(
self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
) -> torch.Tensor:
"""Compute the preconditioning matrix using the main operator.
Args:
labels: The labels tensor.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
The preconditioning matrix.
"""
theta = self.get_formatted_current_theta()
Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
self.vectorized_Phi, theta, labels, *args
)
if len(self.in_weights) == 1: # apply the same weights to all labels
for key in self.in_weights: # dummy loop
Phi[:, :, :] *= self.in_weights[key]
else:  # apply weights for each label
for key in self.in_weights:
Phi[labels == key, :, :] *= self.in_weights[key]
return Phi
def compute_preconditioning_matrix_bc(
self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
) -> torch.Tensor:
"""Compute the boundary condition preconditioning matrix.
Args:
labels: The labels tensor.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
The boundary condition preconditioning matrix.
"""
theta = self.get_formatted_current_theta()
self.operator_bc = cast(FunctionalOperator, self.operator_bc)
self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
self.vectorized_Phi_bc, theta, labels, *args
)
if len(self.bc_weights) == 1: # apply the same weights to all labels
for key in self.bc_weights: # dummy loop
Phi[:, :, :] *= self.bc_weights[key]
else:  # apply weights for each label
for key in self.bc_weights:
Phi[labels == key, :, :] *= self.bc_weights[key]
return Phi
def compute_preconditioning_matrix_ic(
self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
) -> torch.Tensor:
"""Compute the initial condition preconditioning matrix.
Args:
labels: The labels tensor.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
The initial condition preconditioning matrix.
"""
theta = self.get_formatted_current_theta()
self.operator_ic = cast(FunctionalOperator, self.operator_ic)
self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
self.vectorized_Phi_ic, theta, labels, *args
)
if len(self.ic_weights) == 1: # apply the same weights to all labels
for key in self.ic_weights: # dummy loop
Phi[:, :, :] *= self.ic_weights[key]
else:  # apply weights for each label
for key in self.ic_weights:
Phi[labels == key, :, :] *= self.ic_weights[key]
return Phi
def compute_preconditioning_matrix_dl(
self, *args: torch.Tensor, **kwargs
) -> torch.Tensor:
"""Computes the Gram matrix of the network for the given input tensors.
Args:
*args: Input tensors for computing the Gram matrix.
**kwargs: Additional keyword arguments.
Returns:
The computed Gram matrix.
"""
return self.space.jacobian(*args)
def _randomized_nystrom_approximation(
self, func: Callable, rank: int
) -> tuple[torch.Tensor, torch.Tensor]:
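"""Randomized Nyström approximation of the symmetric PSD Gram operator.
Builds a rank-``rank`` factorization ``U @ diag(Lambda) @ U.T`` of the operator
represented by ``func``, following the shifted sketch-and-solve construction
of [Frangella et al. 2021].
Args:
func: Matrix-matrix product function for the Gram operator, applied to a
(p, rank) block of test vectors.
rank: Target rank of the approximation.
Returns:
The orthonormal factor U of shape (p, rank) and the vector Lambda of
approximate eigenvalues.
"""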
Omega_ = torch.normal(
0.0,
1.0,
size=(self.p, rank),
device=torch.get_default_device(),
)
Omega, _ = torch.linalg.qr(Omega_, mode="reduced") # reduced QR decomposition
Y = func(Omega)
# print("Y.shape: ", Y.shape)
nu = self.eps * torch.linalg.norm(Y, "fro")
Ynu = Y + nu * Omega
try:
C = torch.linalg.cholesky(Omega.T @ Ynu)
B = torch.linalg.solve_triangular(C, Ynu.T, upper=False).T
U, Sigma, _ = torch.linalg.svd(
B, full_matrices=False
) # reduced SVD decomposition
except RuntimeError:
# Cholesky failed: retry with a much larger shift
nu = 1e3 * nu
Ynu = Y + nu * Omega
C = torch.linalg.cholesky(Omega.T @ Ynu)
B = torch.linalg.solve_triangular(C, Ynu.T, upper=False).T
U, Sigma, _ = torch.linalg.svd(
B, full_matrices=False
) # reduced SVD decomposition
Lambda = torch.clamp(Sigma**2 - nu, min=0.0)
return U, Lambda
def _unflatten(self, vector: torch.Tensor) -> dict:
"""Unflatten a 1D vector as model's parameters structure.
Args:
vector: input vector.
Returns:
dict: unflattened vector as dictionary.
"""
start = 0
unflattened_params = {}
for name, param in self.named_parameters():
numel = param.numel()
param_data = vector[start : start + numel].reshape(param.shape)
unflattened_params[name] = param_data
start += numel
return unflattened_params
def _flatten(self, params: dict) -> torch.Tensor:
"""Flatten a dictionary of parameters into a 1D vector.
Args:
params: Dict of parameters (e.g., model's state_dict or named_parameters).
Returns:
torch.Tensor: Flattened 1D vector containing all parameter values.
"""
return torch.cat(
[params[name].reshape(-1) for name, _ in self.named_parameters()]
)
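# Note: self._flatten(self._unflatten(v)) reproduces v for any vector of length
# self.ndof; this round trip is asserted in debug mode in __call__.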
def __call__(
self,
epoch: int,
data: tuple | dict,
grads: torch.Tensor,
res_l: tuple,
res_r: tuple,
**kwargs,
) -> torch.Tensor:
"""Apply the preconditioner to the input gradients.
Args:
epoch: Current training epoch.
data: Input data, either as a tuple or a dictionary.
grads: Gradient tensor to be preconditioned.
res_l: Left residuals.
res_r: Right residuals.
**kwargs: Additional keyword arguments.
Returns:
The preconditioned gradient tensor.
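Example:
A hypothetical call inside a training step, assuming ``data``, ``grads``,
``res_l`` and ``res_r`` are provided by the surrounding solver::

preconditioned_grads = preconditioner(epoch, data, grads, res_l, res_r)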
"""
with torch.no_grad():
gamma = self.p
maxit = 20
kappa = 0.1
# current parameters as a dict of named tensors, for the matrix-free jvp/vjp
theta_vec = self.space.get_dof(flag_scope="all", flag_format="tensor")
theta = self._unflatten(cast(torch.Tensor, theta_vec))
args = self.get_args_for_operator(data)
labels, args = args[0], args[1:]
N = labels.shape[0]
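# Matrix-free Gram-vector products: for each residual block (interior, bc, ic),
# G_block v = (block_weight / N_block) * J^T (w_label^2 * (J v)),
# where J v is computed with a jvp and J^T (.) with the matching vjp.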
def fin_vectorized_mf(theta: dict) -> torch.Tensor:
return self.operator.apply_dict_of_vmap_to_label_tensors(
self.vectorized_F, theta, labels, *args
)
_, vjpin = torch.func.vjp(fin_vectorized_mf, theta)
def gin_vector_product_mf(v_dict: dict) -> torch.Tensor:
_, vv = torch.func.jvp(fin_vectorized_mf, (theta,), (v_dict,))
if len(self.in_weights) == 1: # apply the same weight to all labels
for key in self.in_weights: # dummy loop
vv[:, :] *= self.sq_in_weights[key]
else:
for key in self.in_weights:
vv[labels == key, :] *= self.sq_in_weights[key]
vv = self._flatten((vjpin(vv))[0])
return (self.in_weight / N) * vv
if self.has_bc:
args_bc = self.get_args_for_operator_bc(data)
labels_bc, args_bc = args_bc[0], args_bc[1:]
N_bc = labels_bc.shape[0]
def fbc_vectorized_mf(theta: dict) -> torch.Tensor:
self.operator_bc = cast(FunctionalOperator, self.operator_bc)
return self.operator_bc.apply_dict_of_vmap_to_label_tensors(
self.vectorized_F_bc, theta, labels_bc, *args_bc
)
_, vjpbc = torch.func.vjp(fbc_vectorized_mf, theta)
def gbc_vector_product_mf(v_dict: dict) -> torch.Tensor:
_, vv = torch.func.jvp(fbc_vectorized_mf, (theta,), (v_dict,))
if len(self.bc_weights) == 1: # apply the same weight to all labels
for key in self.bc_weights: # dummy loop
vv[:, :] *= self.sq_bc_weights[key]
else:  # apply weights for each label
for key in self.bc_weights:
vv[labels_bc == key, :] *= self.sq_bc_weights[key]
vv = self._flatten((vjpbc(vv))[0])
return (self.bc_weight / N_bc) * vv
if self.has_ic:
args_ic = self.get_args_for_operator_ic(data)
labels_ic, args_ic = args_ic[0], args_ic[1:]
N_ic = labels_ic.shape[0]
def fic_vectorized_mf(theta: dict) -> torch.Tensor:
self.operator_ic = cast(FunctionalOperator, self.operator_ic)
return self.operator_ic.apply_dict_of_vmap_to_label_tensors(
self.vectorized_F_ic, theta, labels_ic, *args_ic
)
_, vjpic = torch.func.vjp(fic_vectorized_mf, theta)
def gic_vector_product_mf(v_dict: dict) -> torch.Tensor:
_, vv = torch.func.jvp(fic_vectorized_mf, (theta,), (v_dict,))
if len(self.ic_weights) == 1: # apply the same weight to all labels
for key in self.ic_weights: # dummy loop
vv[:, :] *= self.sq_ic_weights[key]
else:  # apply weights for each label
for key in self.ic_weights:
vv[labels_ic == key, :] *= self.sq_ic_weights[key]
vv = self._flatten((vjpic(vv))[0])
return (self.ic_weight / N_ic) * vv
def g_vector_product_mf(v: torch.Tensor) -> torch.Tensor:
v_dict = self._unflatten(v)
res = gin_vector_product_mf(v_dict)
# print("res.shape: ", res.shape)
if self.has_bc:
res += gbc_vector_product_mf(v_dict)
if self.has_ic:
res += gic_vector_product_mf(v_dict)
return res
# batch the Gram-vector product over the columns of a (p, ell) test matrix
g_matrix_product_mf = torch.func.vmap(g_vector_product_mf, in_dims=1, out_dims=1)
Phi = torch.ones(1)  # placeholder; assembled below when not in matrix-free mode
def g_vector_product_m(v: torch.Tensor) -> torch.Tensor:
Phi_v = torch.einsum("np,p->n", Phi, v)
# print("non mf vv.shape: ", Phi_v.shape)
res = torch.einsum("pn,n->p", Phi.transpose(0, 1), Phi_v)
return res
def g_matrix_product_m(v: torch.Tensor) -> torch.Tensor:
Phi_v = torch.einsum("np,pl->nl", Phi, v)
# print("non mf vv.shape: ", Phi_v.shape)
res = torch.einsum("pn,nl->pl", Phi.transpose(0, 1), Phi_v)
return res
g_vector_product = g_vector_product_mf
g_matrix_product = g_matrix_product_mf
if (not self.matrix_free) or self.debug:
Phi = self.get_preconditioning_matrix(data, **kwargs)
Phi = torch.cat(
tuple(Phi[:, :, i, ...] for i in range(Phi.shape[2])), dim=0
)
Phi *= torch.sqrt(torch.tensor(self.in_weight / N))
if self.has_bc:
Phib = self.get_preconditioning_matrix_bc(data, **kwargs)
Phib = torch.cat(
tuple(Phib[:, :, i, ...] for i in range(Phib.shape[2])), dim=0
)
Phib *= torch.sqrt(torch.tensor(self.bc_weight / N_bc))
Phi = torch.cat([Phi, Phib], dim=0)
if self.has_ic:
Phii = self.get_preconditioning_matrix_ic(data, **kwargs)
Phii = torch.cat(
tuple(Phii[:, :, i, ...] for i in range(Phii.shape[2])), dim=0
)
Phii *= torch.sqrt(torch.tensor(self.ic_weight / N_ic))
Phi = torch.cat([Phi, Phii], dim=0)
for index, coeff in enumerate(self.dl_weights):
Phil = self.get_preconditioning_matrix_dl(
self.args_for_dl[index], **kwargs
)
N_l = Phil.shape[0]
Phil = torch.cat(
tuple(Phil[:, :, i, ...] for i in range(Phil.shape[2])), dim=0
)
Phil *= torch.sqrt(torch.tensor(coeff / N_l))
Phi = torch.cat([Phi, Phil], dim=0)
if not self.matrix_free:
g_vector_product = g_vector_product_m
g_matrix_product = g_matrix_product_m
if self.debug:
assert isinstance(self.ENGPrec, EnergyNaturalGradientPreconditioner)
testv = torch.rand(
(self.p,),
device=torch.get_default_device(),
dtype=torch.get_default_dtype(),
)
assert torch.all(testv == self._flatten(self._unflatten(testv)))
testm = torch.rand(
(self.p, self.ell),
device=torch.get_default_device(),
dtype=torch.get_default_dtype(),
)
if len(self.dl_weights) == 0:
# check that matrix and matrix free version agree
assert torch.allclose(
g_vector_product_mf(testv), g_vector_product_m(testv)
)
assert torch.allclose(
g_matrix_product_mf(testm), g_matrix_product_m(testm)
)
# check correctness of non regularized energy matrix
G_ENG = self.ENGPrec.compute_full_preconditioning_matrix(data)
assert torch.allclose(g_vector_product_m(testv), G_ENG @ testv)
assert torch.allclose(g_matrix_product_m(testm), G_ENG @ testm)
# randomized Nystrom approximation U diag(Lambda) U^T of the Gram operator
U, Lambda = self._randomized_nystrom_approximation(
g_matrix_product, self.ell
)
# adaptive matrix regularization
lambda_0, lambda_l = Lambda[0].item(), Lambda[-1]
mu = gamma * self.eps * lambda_0
# inverse of preconditioner
reg_lambda_inv = torch.diag(1.0 / (Lambda + mu))
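# P^{-1} v = (lambda_l + mu) * U (Lambda + mu I)^{-1} U^T v + (v - U U^T v),
# i.e. the inverse of the Nystrom preconditioner of [Frangella et al. 2021]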
def preconditioner(v: torch.Tensor) -> torch.Tensor:
P_inv_left = U @ (reg_lambda_inv @ (U.T @ v))
P_inv_right = v - U @ (U.T @ v)
return (lambda_l + mu) * P_inv_left + P_inv_right
# stopping criterion for conjugate gradient
rel_tol = min(kappa, torch.linalg.norm(grads))
def g_vector_product_reg(v: torch.Tensor) -> torch.Tensor:
return g_vector_product(v) + mu * v
# solve the regularized system (G + mu * I) x = grads with preconditioned CG
preconditioned_grads, _ = _cg(
g_vector_product_reg, grads, rel_tol, maxit, preconditioner
)
# adapt the sketch size self.ell for the next call
if lambda_l > 10 * mu:
self.ell = min(2 * self.ell, self.p)
else:
self.ell = int(torch.sum(Lambda >= 10 * mu).item())
return preconditioned_grads