Source code for scimba_torch.numerical_solvers.pinn_preconditioners.nystrom_ng

"""Preconditioner for pinns."""

import warnings
from collections import OrderedDict
from typing import Callable, cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
    ACCEPTED_PDE_TYPES,
    TYPE_DICT_OF_VMAPS,
    FunctionalOperator,
)
from scimba_torch.numerical_solvers.pinn_preconditioners import (
    EnergyNaturalGradientPreconditioner,  # for debug mode
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn


def _cg(
    g: Callable, b: torch.Tensor, rtol: float, maxit: int, pginv: Callable
) -> tuple[torch.Tensor, int]:
    """Matrix-free conjugate gradient with preconditioner for Gx=b.

    Implements algo 5.1 (page 10) in [Frangella and al. 2021].

    Args:
        g: The product matrix-vector function for G.
        b: The right member.
        rtol: the relative error to stop.
        maxit: the max number of iterations.
        pginv: The product matrix-vector function for the preconditioner,
            which should approximate inverse of G.

    Returns:
        the approximate solution and the numer of iterations.

    [Frangella and al. 2021] RANDOMIZED NYSTROM PRECONDITIONING,
        Z. Frangella, J. A. Tropp, M. Udell
        https://arxiv.org/abs/2110.02820
    """
    with torch.no_grad():  # otherwise the computation graph is huge!
        x = torch.zeros_like(b)
        r0 = b - g(x)
        nb = torch.linalg.norm(b)
        z0 = pginv(r0)
        p0 = z0
        it = 0
        while (it <= maxit) and (torch.linalg.norm(r0) > nb * rtol):
            v = g(p0)
            alpha = torch.dot(r0, z0) / torch.dot(p0, v)
            x += alpha * p0
            r = r0 - alpha * v
            z = pginv(r)
            beta = torch.dot(r, z) / torch.dot(r0, z0)
            p0 = z + beta * p0
            r0, z0 = r, z
            it += 1

        # print("it: ", it)
        # if it <= maxit:
        #     return x, 0
        # else:
        #     return x, it

        return x, it

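# Illustrative sketch (assumed helper, not part of the public API): shows how the
# matrix-free interface of `_cg` can be exercised on a small dense SPD system,
# here with the exact inverse as preconditioner so CG converges almost immediately.
def _cg_usage_sketch() -> None:
    torch.manual_seed(0)
    n = 16
    a = torch.randn(n, n)
    spd = a @ a.T + n * torch.eye(n)  # a well-conditioned SPD matrix
    rhs = torch.randn(n)

    def matvec(v: torch.Tensor) -> torch.Tensor:
        # matrix-vector product function for G
        return spd @ v

    def precond(v: torch.Tensor) -> torch.Tensor:
        # ideal preconditioner: the exact inverse of the matrix
        return torch.linalg.solve(spd, v)

    x, _ = _cg(matvec, rhs, rtol=1e-5, maxit=50, pginv=precond)
    assert torch.allclose(spd @ x, rhs, atol=1e-4)
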

class NystromNaturalGradientPreconditioner(MatrixPreconditionerPinn):
    r"""Randomized matrix-free natural gradient preconditioner.

    Implements [Bioli et al. 2025].

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments.

    Keyword Args:
        `matrix_free` (:code:`bool`): Use jvp and vjp to compute the
            matrix-vector product function instead of assembling the \Phi matrix.
            (default: False).
        `eps` (:code:`float`): The eps used for adaptive matrix regularization.
            (default: :code:`torch.finfo(torch.get_default_dtype()).eps`).
        `debug` (:code:`bool`): Debugging mode; checks that the Gram matrices
            are coherent. (default: False).

    Raises:
        NotImplementedError: If the matrix-free option is combined with a data
            loss, which is not yet implemented.

    [Bioli et al. 2025] Accelerating Natural Gradient Descent for PINNs with
        Randomized Nyström Preconditioning,
        I. Bioli, C. Marcati, G. Sangalli,
        https://arxiv.org/abs/2505.11638v3
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)

        self.p = self.ndof
        default_eps = torch.finfo(torch.get_default_dtype()).eps
        self.eps = kwargs.get("eps", default_eps)
        self.ell = 10
        self.matrix_free = kwargs.get("matrix_free", False)
        self.debug = kwargs.get("debug", False)

        if self.matrix_free and (len(self.dl_weights) > 0):
            raise NotImplementedError(
                "Nyström Natural Gradient preconditioning "
                "in matrix-free mode with data loss is not yet implemented"
            )

        self.F = self.operator.apply_to_func(self.eval_func)
        self.vectorized_F = self.vectorize_along_physical_variables(self.F)

        if self.has_bc:
            self.operator_bc = cast(FunctionalOperator, self.operator_bc)
            self.F_bc = self.operator_bc.apply_to_func(self.eval_func)
            self.vectorized_F_bc = self.vectorize_along_physical_variables_bc(self.F_bc)

        if self.has_ic:
            self.operator_ic = cast(FunctionalOperator, self.operator_ic)
            self.F_ic = self.operator_ic.apply_to_func(self.eval_func)
            self.vectorized_F_ic = self.vectorize_along_physical_variables_ic(self.F_ic)

        self.ENGPrec: None | EnergyNaturalGradientPreconditioner = None
        if self.debug:
            warnings.warn(
                "debug mode in NystromNaturalGradientPreconditioner; "
                "might be both time and memory consuming",
                UserWarning,
            )
            self.ENGPrec = EnergyNaturalGradientPreconditioner(
                space, pde, matrix_regularization=0.0, **kwargs
            )

        self.sq_in_weights = OrderedDict(
            [(key, self.in_weights[key] ** 2) for key in self.in_weights]
        )
        self.sq_bc_weights = OrderedDict()
        if self.has_bc:
            self.sq_bc_weights = OrderedDict(
                [(key, self.bc_weights[key] ** 2) for key in self.bc_weights]
            )
        self.sq_ic_weights = OrderedDict()
        if self.has_ic:
            self.sq_ic_weights = OrderedDict(
                [(key, self.ic_weights[key] ** 2) for key in self.ic_weights]
            )
    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()

        Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        )

        if len(self.in_weights) == 1:
            # apply the same weight to all labels
            for key in self.in_weights:  # dummy loop
                Phi[:, :, :] *= self.in_weights[key]
        else:
            # apply the weight associated with each label
            for key in self.in_weights:
                Phi[labels == key, :, :] *= self.in_weights[key]

        return Phi
    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()

        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        )

        if len(self.bc_weights) == 1:
            # apply the same weight to all labels
            for key in self.bc_weights:  # dummy loop
                Phi[:, :, :] *= self.bc_weights[key]
        else:
            # apply the weight associated with each label
            for key in self.bc_weights:
                Phi[labels == key, :, :] *= self.bc_weights[key]

        return Phi
    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()

        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        )

        if len(self.ic_weights) == 1:
            # apply the same weight to all labels
            for key in self.ic_weights:  # dummy loop
                Phi[:, :, :] *= self.ic_weights[key]
        else:
            # apply the weight associated with each label
            for key in self.ic_weights:
                Phi[labels == key, :, :] *= self.ic_weights[key]

        return Phi
    def compute_preconditioning_matrix_dl(
        self, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Computes the Gram matrix of the network for the given input tensors.

        Args:
            *args: Input tensors for computing the Gram matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed Gram matrix.
        """
        return self.space.jacobian(*args)
    def _randomized_nystrom_approximation(
        self, func: Callable, rank: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute a randomized Nyström approximation of G of the given rank.

        Args:
            func: The matrix-matrix product function for G.
            rank: The rank of the approximation (number of sketch columns).

        Returns:
            The factors U and Lambda such that G is approximated by
            U @ diag(Lambda) @ U.T.
        """
        Omega_ = torch.normal(
            0.0,
            1.0,
            size=(self.p, rank),
            device=torch.get_default_device(),
        )
        Omega, _ = torch.linalg.qr(Omega_, mode="reduced")  # reduced QR decomposition

        Y = func(Omega)

        nu = self.eps * torch.linalg.norm(Y, "fro")
        Ynu = Y + nu * Omega
        try:
            C = torch.linalg.cholesky(Omega.T @ Ynu)
            B = torch.linalg.solve_triangular(C, Ynu.T, upper=False).T
            U, Sigma, _ = torch.linalg.svd(
                B, full_matrices=False
            )  # reduced SVD decomposition
        except RuntimeError:
            # retry with a larger shift if the Cholesky factorization fails
            nu = 1e3 * nu
            Ynu = Y + nu * Omega
            C = torch.linalg.cholesky(Omega.T @ Ynu)
            B = torch.linalg.solve_triangular(C, Ynu.T, upper=False).T
            U, Sigma, _ = torch.linalg.svd(
                B, full_matrices=False
            )  # reduced SVD decomposition

        Lambda = torch.clamp(Sigma**2 - nu, min=0.0)

        return U, Lambda

    def _unflatten(self, vector: torch.Tensor) -> dict:
        """Unflatten a 1D vector into the model's parameter structure.

        Args:
            vector: Input vector.

        Returns:
            dict: The unflattened vector as a dictionary of parameters.
        """
        start = 0
        unflattened_params = {}
        for name, param in self.named_parameters():
            numel = param.numel()
            param_data = vector[start : start + numel].reshape(param.shape)
            unflattened_params[name] = param_data
            start += numel
        return unflattened_params

    def _flatten(self, params: dict) -> torch.Tensor:
        """Flatten a dictionary of parameters into a 1D vector.

        Args:
            params: Dict of parameters (e.g., model's state_dict or named_parameters).

        Returns:
            torch.Tensor: Flattened 1D vector containing all parameter values.
        """
        return torch.cat(
            [params[name].reshape(-1) for name, _ in self.named_parameters()]
        )

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the preconditioner to the input gradients.

        Args:
            epoch: Current training epoch.
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals.
            res_r: Right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor.
""" with torch.no_grad(): gamma = self.p maxit = 20 kappa = 0.1 # For matrix free # theta = self.get_formatted_current_theta() theta_vec = self.space.get_dof(flag_scope="all", flag_format="tensor") theta = self._unflatten(cast(torch.Tensor, theta_vec)) args = self.get_args_for_operator(data) labels, args = args[0], args[1:] N = labels.shape[0] def fin_vectorized_mf(theta: dict) -> torch.Tensor: return self.operator.apply_dict_of_vmap_to_label_tensors( self.vectorized_F, theta, labels, *args ) _, vjpin = torch.func.vjp(fin_vectorized_mf, theta) def gin_vector_product_mf(v_dict: dict) -> torch.Tensor: _, vv = torch.func.jvp(fin_vectorized_mf, (theta,), (v_dict,)) if len(self.in_weights) == 1: # apply the same weight to all labels for key in self.in_weights: # dummy loop vv[:, :] *= self.sq_in_weights[key] else: for key in self.in_weights: vv[labels == key, :] *= self.sq_in_weights[key] vv = self._flatten((vjpin(vv))[0]) return (self.in_weight / N) * vv if self.has_bc: args_bc = self.get_args_for_operator_bc(data) labels_bc, args_bc = args_bc[0], args_bc[1:] N_bc = labels_bc.shape[0] def fbc_vectorized_mf(theta: dict) -> torch.Tensor: self.operator_bc = cast(FunctionalOperator, self.operator_bc) return self.operator_bc.apply_dict_of_vmap_to_label_tensors( self.vectorized_F_bc, theta, labels_bc, *args_bc ) _, vjpbc = torch.func.vjp(fbc_vectorized_mf, theta) def gbc_vector_product_mf(v_dict: dict) -> torch.Tensor: _, vv = torch.func.jvp(fbc_vectorized_mf, (theta,), (v_dict,)) if len(self.bc_weights) == 1: # apply the same weight to all labels for key in self.bc_weights: # dummy loop vv[:, :] *= self.sq_bc_weights[key] else: # apply weights for each labels for key in self.bc_weights: vv[labels_bc == key, :] *= self.sq_bc_weights[key] vv = self._flatten((vjpbc(vv))[0]) return (self.bc_weight / N_bc) * vv if self.has_ic: args_ic = self.get_args_for_operator_ic(data) labels_ic, args_ic = args_ic[0], args_ic[1:] N_ic = labels_ic.shape[0] def fic_vectorized_mf(theta: dict) -> torch.Tensor: self.operator_ic = cast(FunctionalOperator, self.operator_ic) return self.operator_ic.apply_dict_of_vmap_to_label_tensors( self.vectorized_F_ic, theta, labels_ic, *args_ic ) _, vjpic = torch.func.vjp(fic_vectorized_mf, theta) def gic_vector_product_mf(v_dict: dict) -> torch.Tensor: _, vv = torch.func.jvp(fic_vectorized_mf, (theta,), (v_dict,)) if len(self.ic_weights) == 1: # apply the same weight to all labels for key in self.ic_weights: # dummy loop vv[:, :] *= self.sq_ic_weights[key] else: # apply weights for each labels for key in self.ic_weights: vv[labels_ic == key, :] *= self.sq_ic_weights[key] vv = self._flatten((vjpic(vv))[0]) return (self.ic_weight / N_ic) * vv def g_vector_product_mf(v: torch.Tensor) -> torch.Tensor: v_dict = self._unflatten(v) res = gin_vector_product_mf(v_dict) # print("res.shape: ", res.shape) if self.has_bc: res += gbc_vector_product_mf(v_dict) if self.has_ic: res += gic_vector_product_mf(v_dict) return res g_matrix_product_mf = torch.func.vmap(g_vector_product_mf, (1), (1)) Phi = torch.ones(1) def g_vector_product_m(v: torch.Tensor) -> torch.Tensor: Phi_v = torch.einsum("np,p->n", Phi, v) # print("non mf vv.shape: ", Phi_v.shape) res = torch.einsum("pn,n->p", Phi.transpose(0, 1), Phi_v) return res def g_matrix_product_m(v: torch.Tensor) -> torch.Tensor: Phi_v = torch.einsum("np,pl->nl", Phi, v) # print("non mf vv.shape: ", Phi_v.shape) res = torch.einsum("pn,nl->pl", Phi.transpose(0, 1), Phi_v) return res g_vector_product = g_vector_product_mf g_matrix_product = g_matrix_product_mf 
            if (not self.matrix_free) or self.debug:
                Phi = self.get_preconditioning_matrix(data, **kwargs)
                Phi = torch.cat(
                    tuple(Phi[:, :, i, ...] for i in range(Phi.shape[2])), dim=0
                )
                Phi *= torch.sqrt(torch.tensor(self.in_weight / N))

                if self.has_bc:
                    Phib = self.get_preconditioning_matrix_bc(data, **kwargs)
                    Phib = torch.cat(
                        tuple(Phib[:, :, i, ...] for i in range(Phib.shape[2])), dim=0
                    )
                    Phib *= torch.sqrt(torch.tensor(self.bc_weight / N_bc))
                    Phi = torch.cat([Phi, Phib], dim=0)

                if self.has_ic:
                    Phii = self.get_preconditioning_matrix_ic(data, **kwargs)
                    Phii = torch.cat(
                        tuple(Phii[:, :, i, ...] for i in range(Phii.shape[2])), dim=0
                    )
                    Phii *= torch.sqrt(torch.tensor(self.ic_weight / N_ic))
                    Phi = torch.cat([Phi, Phii], dim=0)

                for index, coeff in enumerate(self.dl_weights):
                    Phil = self.get_preconditioning_matrix_dl(
                        self.args_for_dl[index], **kwargs
                    )
                    N_l = Phil.shape[0]
                    Phil = torch.cat(
                        tuple(Phil[:, :, i, ...] for i in range(Phil.shape[2])), dim=0
                    )
                    Phil *= torch.sqrt(torch.tensor(coeff / N_l))
                    Phi = torch.cat([Phi, Phil], dim=0)

                if not self.matrix_free:
                    g_vector_product = g_vector_product_m
                    g_matrix_product = g_matrix_product_m

            if self.debug:
                assert isinstance(self.ENGPrec, EnergyNaturalGradientPreconditioner)
                testv = torch.rand(
                    (self.p,),
                    device=torch.get_default_device(),
                    dtype=torch.get_default_dtype(),
                )
                assert torch.all(testv == self._flatten(self._unflatten(testv)))
                testm = torch.rand(
                    (self.p, self.ell),
                    device=torch.get_default_device(),
                    dtype=torch.get_default_dtype(),
                )
                if len(self.dl_weights) == 0:
                    # check that the matrix and matrix-free versions agree
                    assert torch.allclose(
                        g_vector_product_mf(testv), g_vector_product_m(testv)
                    )
                    assert torch.allclose(
                        g_matrix_product_mf(testm), g_matrix_product_m(testm)
                    )
                # check correctness of the non-regularized energy matrix
                G_ENG = self.ENGPrec.compute_full_preconditioning_matrix(data)
                assert torch.allclose(g_vector_product_m(testv), G_ENG @ testv)
                assert torch.allclose(g_matrix_product_m(testm), G_ENG @ testm)

            # eigendecomposition of the randomized Nyström approximation
            U, Lambda = self._randomized_nystrom_approximation(
                g_matrix_product, self.ell
            )

            # adaptive matrix regularization
            lambda_0, lambda_l = Lambda[0].item(), Lambda[-1]
            mu = gamma * self.eps * lambda_0

            # inverse of the preconditioner
            reg_lambda_inv = torch.diag(1.0 / (Lambda + mu))

            def preconditioner(v: torch.Tensor) -> torch.Tensor:
                P_inv_left = U @ (reg_lambda_inv @ (U.T @ v))
                P_inv_right = v - U @ (U.T @ v)
                return (lambda_l + mu) * P_inv_left + P_inv_right

            # stopping criterion for conjugate gradient
            rel_tol = min(kappa, torch.linalg.norm(grads))

            def g_vector_product_reg(v: torch.Tensor) -> torch.Tensor:
                return g_vector_product(v) + mu * v

            # solve the regularized system (G + mu I) x = grads with preconditioned CG
            preconditioned_grads, exit_code = _cg(
                g_vector_product_reg, grads, rel_tol, maxit, preconditioner
            )

            # adjust the sketch size self.ell for the next call
            if lambda_l > 10 * mu:
                self.ell = min(2 * self.ell, self.p)
            else:
                self.ell = int(torch.sum(~(Lambda < 10 * mu)).item())

            return preconditioned_grads
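

# Illustrative sketch (assumed helper, standalone and not used by the class above):
# the randomized Nyström preconditioning recipe on a small synthetic Gram matrix
# G = Phi.T @ Phi, mirroring `_randomized_nystrom_approximation` and the
# preconditioner built in `__call__`; all names below are local to this sketch.
def _nystrom_pcg_sketch() -> None:
    torch.manual_seed(0)
    n_samples, p, rank = 200, 64, 16
    phi = torch.randn(n_samples, p) / n_samples**0.5
    gram = phi.T @ phi  # symmetric positive semi-definite Gram matrix

    # rank-`rank` randomized Nyström approximation: G ~ U diag(Lambda) U^T
    omega, _ = torch.linalg.qr(torch.randn(p, rank), mode="reduced")
    y = gram @ omega
    nu = torch.finfo(y.dtype).eps * torch.linalg.norm(y, "fro")
    y_nu = y + nu * omega
    c = torch.linalg.cholesky(omega.T @ y_nu)
    b_mat = torch.linalg.solve_triangular(c, y_nu.T, upper=False).T
    u, sigma, _ = torch.linalg.svd(b_mat, full_matrices=False)
    lam = torch.clamp(sigma**2 - nu, min=0.0)

    # Nyström preconditioner for the regularized system (G + mu I) x = b
    mu = 1e-6 * lam[0]
    lam_l = lam[-1]

    def pinv(v: torch.Tensor) -> torch.Tensor:
        utv = u.T @ v
        return (lam_l + mu) * (u @ (utv / (lam + mu))) + v - u @ utv

    def matvec(v: torch.Tensor) -> torch.Tensor:
        return gram @ v + mu * v

    b = torch.randn(p)
    x, _ = _cg(matvec, b, rtol=1e-5, maxit=200, pginv=pinv)
    assert torch.allclose(matvec(x), b, atol=1e-3)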