"""Preconditioner for pinns."""
from typing import cast
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
ACCEPTED_PDE_TYPES,
TYPE_DICT_OF_VMAPS,
FunctionalOperator,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn
class EnergyNaturalGradientPreconditioner(MatrixPreconditionerPinn):
    """Energy-based natural gradient preconditioner.

    Builds regularized Gram ("energy") matrices from the operator Jacobians
    and uses them to precondition the gradient via a least-squares solve.

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments forwarded to the base class.

    Keyword Args:
        matrix_regularization (float): Regularization parameter added to the
            diagonal of the preconditioning matrix (default: 1e-6).
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        # Tikhonov-style diagonal shift keeping the Gram matrix well-conditioned.
        self.matrix_regularization = kwargs.get("matrix_regularization", 1e-6)

    def _weighted_gram(
        self,
        Phi: torch.Tensor,
        labels: torch.Tensor,
        weights: dict,
        n_points: int,
    ) -> torch.Tensor:
        """Apply per-label weights to ``Phi`` and return the regularized Gram matrix.

        Shared core of the main/bc/ic preconditioning-matrix computations.
        NOTE: mutates ``Phi`` in place (as the original code did).

        Args:
            Phi: 3D tensor of operator evaluations; assumed shape
                (n_points, ndof, n_components) — TODO confirm axis order
                against ``apply_dict_of_vmap_to_label_tensors``.
            labels: Labels tensor used to select rows for per-label weighting.
            weights: Mapping from label to scalar weight.
            n_points: Number of sample points (size of ``Phi``'s first axis).

        Returns:
            The (ndof, ndof) Gram matrix plus diagonal regularization.
        """
        if len(weights) == 1:  # apply the same weight to all labels
            for key in weights:  # dummy loop to grab the single entry
                Phi[:, :, :] *= weights[key]
        else:  # apply a weight for each label
            for key in weights:
                Phi[labels == key, :, :] *= weights[key]
        M = torch.einsum("ijk,ilk->jl", Phi, Phi) / n_points
        # NOTE(review): torch.eye is created on the default device/dtype;
        # may mismatch Phi on GPU — confirm against the rest of the solver.
        M += self.matrix_regularization * torch.eye(self.ndof)
        return M

    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments (first one provides the point count).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The preconditioning matrix scaled by the global interior weight.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        Phi = self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        )
        return self.in_weight * self._weighted_gram(Phi, labels, self.in_weights, N)

    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments (first one provides the point count).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The boundary condition preconditioning matrix scaled by the
            global boundary weight.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        # Narrow the Optional types for the type checker; set by the base class.
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        Phi = self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        )
        return self.bc_weight * self._weighted_gram(Phi, labels, self.bc_weights, N)

    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments (first one provides the point count).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The initial condition preconditioning matrix scaled by the
            global initial-condition weight.
        """
        N = args[0].shape[0]
        theta = self.get_formatted_current_theta()
        # Narrow the Optional types for the type checker; set by the base class.
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        Phi = self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        )
        return self.ic_weight * self._weighted_gram(Phi, labels, self.ic_weights, N)

    def compute_preconditioning_matrix_dl(
        self, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the Gram matrix of the network for the given input tensors.

        Args:
            *args: Input tensors for computing the Gram matrix (the first one
                provides the point count).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The computed Gram matrix with diagonal regularization.
        """
        N = args[0].shape[0]
        jacobian = self.space.jacobian(*args)
        # NOTE(review): this uses self.space.ndof while the other matrices use
        # self.ndof — presumably equal; confirm against the base class.
        M = torch.einsum(
            "ijk,ilk->jl", jacobian, jacobian
        ) / N + self.matrix_regularization * torch.eye(self.space.ndof)
        return M

    def compute_full_preconditioning_matrix(
        self, data: tuple | dict, **kwargs
    ) -> torch.Tensor:
        """Compute the full preconditioning matrix.

        Include contributions from the main operator, boundary conditions,
        initial conditions, and the data-loss Gram terms.

        Args:
            data: Input data, either as a tuple or a dictionary.
            **kwargs: Additional keyword arguments passed through to the
                per-term computations.

        Returns:
            The full preconditioning matrix (sum of all active terms).
        """
        M = self.get_preconditioning_matrix(data, **kwargs)
        if self.has_bc:
            M += self.get_preconditioning_matrix_bc(data, **kwargs)
        if self.has_ic:
            M += self.get_preconditioning_matrix_ic(data, **kwargs)
        for index, coeff in enumerate(self.dl_weights):
            M += coeff * self.get_preconditioning_matrix_dl(
                self.args_for_dl[index], **kwargs
            )
        return M

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the preconditioner to the input gradients.

        Args:
            epoch: Current training epoch (unused here; kept for the
                preconditioner call interface).
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals (unused here).
            res_r: Right residuals (unused here).
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor, i.e. the least-squares
            solution of ``M x = grads``.
        """
        M = self.compute_full_preconditioning_matrix(data, **kwargs)
        # lstsq tolerates a (near-)singular M better than a direct solve.
        preconditioned_grads = torch.linalg.lstsq(M, grads).solution
        return preconditioned_grads