"""Preconditioner for pinns."""
from typing import cast
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.numerical_solvers.functional_operator import (
ACCEPTED_PDE_TYPES,
TYPE_DICT_OF_VMAPS,
FunctionalOperator,
)
from scimba_torch.numerical_solvers.preconditioner_pinns import MatrixPreconditionerPinn
class AnagramPreconditioner(MatrixPreconditionerPinn):
    """Anagram preconditioner.

    This preconditioner is based on the anagram method, which aims to improve
    convergence by transforming the problem into a more favorable form: the
    preconditioning matrix is assembled from the operator (plus boundary /
    initial condition and data-loss contributions), pseudo-inverted through a
    truncated SVD, and applied to the stacked residuals.

    Args:
        space: The approximation space.
        pde: The PDE to be solved, which can be an instance of EllipticPDE,
            TemporalPDE, KineticPDE, or LinearOrder2PDE.
        **kwargs: Additional keyword arguments.

    Keyword Args:
        `svd_threshold` (:code:`float`): Threshold for singular value
            decomposition (default: 1e-6).
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        pde: ACCEPTED_PDE_TYPES,
        **kwargs,
    ):
        super().__init__(space, pde, **kwargs)
        # Singular values <= this threshold are treated as zero when
        # pseudo-inverting the preconditioning matrix (see __call__).
        # NOTE(review): the threshold is applied as an absolute cutoff, not
        # relative to the largest singular value — confirm this is intended.
        self.svd_threshold: float = kwargs.get("svd_threshold", 1e-6)

    @staticmethod
    def _flatten_label_dim(mat: torch.Tensor) -> torch.Tensor:
        """Unroll an extra (third) dimension of ``mat`` along dim 0.

        When ``mat`` has more than two dimensions, the slices
        ``mat[:, :, i, ...]`` are concatenated along the first dimension so
        the result is a 2-D matrix; otherwise ``mat`` is returned unchanged.

        Args:
            mat: The preconditioning matrix to flatten.

        Returns:
            The flattened matrix.
        """
        if mat.ndim > 2:
            return torch.cat(
                tuple(mat[:, :, i, ...] for i in range(mat.shape[2])), dim=0
            )
        return mat

    def _assemble_residuals(
        self,
        operator: FunctionalOperator,
        residual_size: int,
        get_args,
        data: tuple | dict,
        res_l: tuple,
        res_r: tuple,
    ) -> torch.Tensor:
        """Assemble a right-hand side vector from left/right residual tuples.

        Shared implementation behind :meth:`assemble_right_member`,
        :meth:`assemble_right_member_bc` and :meth:`assemble_right_member_ic`.

        Args:
            operator: The functional operator whose keys drive concatenation.
            residual_size: Total residual size for this operator.
            get_args: Callable mapping ``data`` to the operator arguments;
                called lazily, only when more than one operator is present
                (matches the original control flow).
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor.
        """
        differences = tuple(left - right for left, right in zip(res_l, res_r))
        if len(operator.dict_of_operators) == 1:
            return torch.cat(differences, dim=0)
        # Regroup the flat residual tuple by component before concatenating
        # along the operator's flatten keys.
        nb_comp_per_label = residual_size // len(operator.flatten_keys)
        regrouped = tuple(
            tuple(
                differences[i + j]
                for i in range(0, len(differences), nb_comp_per_label)
            )
            for j in range(nb_comp_per_label)
        )
        args = get_args(data)
        out_tup = tuple(
            operator.cat_tuple_of_tensors_along_flatten_keys(res, args[0], args[1])
            for res in regrouped
        )
        return torch.cat(out_tup, dim=0)

    def compute_preconditioning_matrix(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the preconditioning matrix using the main operator.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        return self.operator.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi, theta, labels, *args
        ).squeeze()

    def compute_preconditioning_matrix_bc(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the boundary condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The boundary condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        # Narrow Optional attributes for the type checker; no runtime effect.
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        self.vectorized_Phi_bc = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_bc)
        return self.operator_bc.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_bc, theta, labels, *args
        ).squeeze()

    def compute_preconditioning_matrix_ic(
        self, labels: torch.Tensor, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Compute the initial condition preconditioning matrix.

        Args:
            labels: The labels tensor.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            The initial condition preconditioning matrix.
        """
        theta = self.get_formatted_current_theta()
        # Narrow Optional attributes for the type checker; no runtime effect.
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        self.vectorized_Phi_ic = cast(TYPE_DICT_OF_VMAPS, self.vectorized_Phi_ic)
        return self.operator_ic.apply_dict_of_vmap_to_label_tensors(
            self.vectorized_Phi_ic, theta, labels, *args
        ).squeeze()

    def compute_preconditioning_matrix_dl(
        self, *args: torch.Tensor, **kwargs
    ) -> torch.Tensor:
        """Computes the Gram matrix of the network for the given input tensors.

        Args:
            *args: Input tensors for computing the Gram matrix.
            **kwargs: Additional keyword arguments.

        Returns:
            The computed Gram matrix.
        """
        return self.space.jacobian(*args).squeeze(-1)

    def assemble_right_member(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side of the equation.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor.
        """
        return self._assemble_residuals(
            self.operator,
            self.residual_size,
            self.get_args_for_operator,
            data,
            res_l,
            res_r,
        )

    def assemble_right_member_bc(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side for boundary conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor for boundary conditions.
        """
        self.operator_bc = cast(FunctionalOperator, self.operator_bc)
        return self._assemble_residuals(
            self.operator_bc,
            self.bc_residual_size,
            self.get_args_for_operator_bc,
            data,
            res_l,
            res_r,
        )

    def assemble_right_member_ic(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assemble the right-hand side for initial conditions.

        Args:
            data: Input data, either as a tuple or a dictionary.
            res_l: Left residuals.
            res_r: Right residuals.

        Returns:
            The assembled right-hand side tensor for initial conditions.
        """
        self.operator_ic = cast(FunctionalOperator, self.operator_ic)
        return self._assemble_residuals(
            self.operator_ic,
            self.ic_residual_size,
            self.get_args_for_operator_ic,
            data,
            res_l,
            res_r,
        )

    def assemble_right_member_dl(
        self, data: tuple | dict, res_l: tuple, res_r: tuple
    ) -> torch.Tensor:
        """Assembles the right member for the preconditioning.

        Args:
            data: The input data for assembling the right member (unused here,
                kept for signature consistency with the other assemblers).
            res_l: The left residuals.
            res_r: The right residuals.

        Returns:
            The assembled right member.
        """
        differences = tuple(left - right for left, right in zip(res_l, res_r))
        # Stack each residual column-by-column along dim 0.
        columns = tuple(
            torch.cat(tuple(a[:, i : i + 1] for i in range(a.shape[1])), dim=0)
            for a in differences
        )
        return torch.cat(columns, dim=0)

    def __call__(
        self,
        epoch: int,
        data: tuple | dict,
        grads: torch.Tensor,
        res_l: tuple,
        res_r: tuple,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the Anagram preconditioner to the input gradients.

        Args:
            epoch: Current training epoch.
            data: Input data, either as a tuple or a dictionary.
            grads: Gradient tensor to be preconditioned.
            res_l: Left residuals.
            res_r: Right residuals.
            **kwargs: Additional keyword arguments.

        Returns:
            The preconditioned gradient tensor.
        """
        # Stack the preconditioning matrices of every active contribution
        # (interior operator, then BC, IC, and data-loss blocks) along dim 0.
        phi = self._flatten_label_dim(
            self.get_preconditioning_matrix(data, **kwargs)
        )
        if self.has_bc:
            phib = self._flatten_label_dim(
                self.get_preconditioning_matrix_bc(data, **kwargs)
            )
            phi = torch.cat((phi, phib), dim=0)
        if self.has_ic:
            phii = self._flatten_label_dim(
                self.get_preconditioning_matrix_ic(data, **kwargs)
            )
            phi = torch.cat((phi, phii), dim=0)
        for index in range(self.nb_dl):
            phid = self._flatten_label_dim(
                self.get_preconditioning_matrix_dl(
                    self.args_for_dl[index], **kwargs
                )
            )
            phi = torch.cat((phi, phid), dim=0)

        # Pseudo-inverse via truncated SVD: singular values at or below the
        # threshold are discarded (their inverse is set to zero).
        U, Delta, Vt = torch.linalg.svd(phi, full_matrices=False)
        mask = Delta > self.svd_threshold
        Delta_inv = torch.zeros_like(Delta)
        Delta_inv[mask] = 1.0 / Delta[mask]
        # Equivalent to Vt.T @ diag(Delta_inv) @ U.T without materializing
        # the diagonal matrix.
        phi_plus = Vt.T @ (Delta_inv[:, None] * U.T)

        # Assemble the stacked right-hand side in the same block order as phi,
        # consuming res_l/res_r slice by slice.
        nextindex = 0
        length = self.residual_size
        res = self.assemble_right_member(
            data,
            res_l[nextindex : nextindex + length],
            res_r[nextindex : nextindex + length],
        )
        nextindex += length
        if self.has_bc:
            length = self.bc_residual_size
            resb = self.assemble_right_member_bc(
                data,
                res_l[nextindex : nextindex + length],
                res_r[nextindex : nextindex + length],
            )
            res = torch.cat((res, resb), dim=0)
            nextindex += length
        if self.has_ic:
            length = self.ic_residual_size
            resi = self.assemble_right_member_ic(
                data,
                res_l[nextindex : nextindex + length],
                res_r[nextindex : nextindex + length],
            )
            res = torch.cat((res, resi), dim=0)
            nextindex += length
        for _ in range(self.nb_dl):
            # Each data-loss block consumes exactly one residual entry.
            resd = self.assemble_right_member_dl(
                data,
                res_l[nextindex : nextindex + 1],
                res_r[nextindex : nextindex + 1],
            )
            res = torch.cat((res, resd), dim=0)
            nextindex += 1

        return phi_plus @ res