"""Module for abstract temporal PDEs."""
from abc import ABC, abstractmethod
from typing import Callable, Generator, cast
import torch
from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
StrongFormEllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import LinearOrder2PDE
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable
class TemporalPDE(ABC):
    """Abstract base class for time-dependent Partial Differential Equations (PDEs).

    Subclasses provide the space and time operators, the right-hand sides, the
    boundary operator and the initial condition; :meth:`operator` combines the
    space and time operators by summing them.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments; only ``exact_solution`` is read
            (it defaults to ``None`` when absent)
    """

    space: AbstractApproxSpace
    linear: bool
    exact_solution: Callable | None

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.space = space
        self.linear = linear
        # handle kwargs
        self.exact_solution = kwargs.get("exact_solution", None)

    def grad(
        self,
        w: torch.Tensor | MultiLabelTensor,
        y: torch.Tensor | LabelTensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the gradient of the tensor `w` with respect to the tensor `y`.

        The computation is delegated to the approximation space; when the space
        yields its result as a generator, it is materialized into a tuple.

        Args:
            w: Input tensor
            y: Tensor with respect to which the gradient is computed

        Returns:
            Gradient tensor, or a tuple of gradient tensors
        """
        res = self.space.grad(w, y)
        if isinstance(res, Generator):
            return tuple(res)
        return res

    @abstractmethod
    def rhs(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the right-hand side (RHS) of the PDE.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            RHS tensor
        """

    @abstractmethod
    def space_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the space part of the PDE operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Space operator tensor, or a tuple of such tensors
        """

    @abstractmethod
    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the time part of the PDE operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Time operator tensor, or a tuple of such tensors
        """

    def operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the full PDE operator: the sum of the space and time operators.

        Both operators must return the same kind of value: either a single
        tensor each, or tuples of tensors that are summed pairwise.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Operator tensor, or a tuple of operator tensors

        Raises:
            AssertionError: If one operator returns a tuple and the other does not
            ValueError: If both operators return tuples of different lengths
        """
        space = self.space_operator(w, t, x, mu)
        time = self.time_operator(w, t, x, mu)
        error_message = (
            "space and time operators must both be Tensors or tuples of Tensors"
        )
        if isinstance(space, tuple):
            assert isinstance(time, tuple), error_message
            # strict=True: a mismatch in the number of components raises instead
            # of silently dropping the trailing ones
            return tuple(sp + ti for sp, ti in zip(space, time, strict=True))
        else:
            assert isinstance(time, torch.Tensor), error_message
            return space + time

    @abstractmethod
    def bc_rhs(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Compute the boundary condition RHS.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition RHS tensor
        """

    @abstractmethod
    def bc_operator(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Apply the boundary condition operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition operator tensor
        """

    @abstractmethod
    def initial_condition(
        self, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the initial condition.

        Args:
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Initial condition tensor
        """
class FirstOrderTemporalPDE(TemporalPDE):
    """Base class for temporal PDEs whose time operator is a first-order derivative.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments, forwarded to the parent class
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        super().__init__(space, linear, **kwargs)

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the time operator: the gradient of the solution with respect to `t`.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor (unused here)
            mu: Parameter tensor (unused here)

        Returns:
            Gradient with respect to `t` of the components of `w`: a single
            result when `w` has one tensor component, otherwise a tuple with
            one entry per component
        """
        u = w.get_components()
        if isinstance(u, torch.Tensor):
            return self.grad(u, t)
        else:
            return tuple(cast(torch.Tensor, self.grad(ui, t)) for ui in u)

    def functional_time_operator(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        x: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the functional time operator.

        Differentiates `func` with respect to its first argument (`t`) using
        reverse-mode automatic differentiation.

        Args:
            func: Callable representing the function, called as
                ``func(t, x, mu, theta)``
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Entry at index 0 of the Jacobian of `func` with respect to `t`
        """
        time_der = torch.func.jacrev(func, 0)(t, x, mu, theta)
        return time_der[0]

    def functional_operator_ic(
        self,
        func: VarArgCallable,
        x: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the functional operator for initial conditions.

        Evaluates `func` at time zero.

        Args:
            func: Callable representing the function, called as
                ``func(t, x, mu, theta)``
            x: Spatial coordinate tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional operator tensor for initial conditions
        """
        # Build t = 0 with the dtype and device of x, so that func does not mix
        # devices or dtypes when x does not live on the default ones
        t = torch.zeros(1, dtype=x.dtype, device=x.device)
        return func(t, x, mu, theta)
class SecondOrderTemporalPDE(TemporalPDE):
    """Base class for temporal PDEs whose time operator is a second-order derivative.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments, forwarded to the parent class
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        # Forward kwargs so that options read by the parent class (such as
        # ``exact_solution``) are not silently discarded
        super().__init__(space, linear, **kwargs)

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the time operator: the gradient with respect to `t`, applied twice.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor (unused here)
            mu: Parameter tensor (unused here)

        Returns:
            Result of applying the gradient with respect to `t` twice to the
            components of `w`: a single result when `w` has one tensor
            component, otherwise a tuple with one entry per component
        """
        u = w.get_components()
        if isinstance(u, torch.Tensor):
            dudt = self.grad(u, t)
            # the first derivative must be a single tensor to be differentiated again
            assert isinstance(dudt, torch.Tensor)
            d2ud2t = self.grad(dudt, t)
        else:
            dudt = tuple(cast(torch.Tensor, self.grad(ui, t)) for ui in u)
            d2ud2t = tuple(cast(torch.Tensor, self.grad(duidt, t)) for duidt in dudt)
        return d2ud2t
def zeros_rhs(
    w: MultiLabelTensor,
    t: LabelTensor,
    x: LabelTensor,
    mu: LabelTensor,
    nb_func: int = 1,
) -> torch.Tensor:
    """Return an identically zero right-hand side.

    Only `x` and `nb_func` are used; the other arguments are accepted so that
    the function matches the signature expected for a right-hand side.

    Args:
        w: Solution tensor.
        t: Temporal coordinates tensor.
        x: Spatial coordinates tensor; its first dimension gives the row count.
        mu: Parameter tensor.
        nb_func: Number of functions to return (default is 1).

    Returns:
        A tensor of zeros with shape (number of points, nb_func).
    """
    n_points = x.x.shape[0]
    return torch.zeros((n_points, nb_func))
def zeros_bc_rhs(
    w: MultiLabelTensor,
    t: LabelTensor,
    x: LabelTensor,
    n: LabelTensor,
    mu: LabelTensor,
    nb_func: int = 1,
) -> torch.Tensor:
    """Return an identically zero right-hand side for the boundary conditions.

    Only `x` and `nb_func` are used; the other arguments are accepted so that
    the function matches the signature expected for a boundary right-hand side.

    Args:
        w: Solution tensor.
        t: Temporal coordinates tensor.
        x: Spatial coordinates tensor; its first dimension gives the row count.
        n: Normal vector tensor.
        mu: Parameter tensor.
        nb_func: Number of functions to return (default is 1).

    Returns:
        A tensor of zeros with shape (number of points, nb_func).
    """
    n_points = x.x.shape[0]
    return torch.zeros((n_points, nb_func))
def zeros_init(
    x: LabelTensor,
    mu: LabelTensor,
    nb_func: int = 1,
) -> torch.Tensor:
    """Return an identically zero initial condition.

    Only `x` and `nb_func` are used; `mu` is accepted so that the function
    matches the signature expected for an initial condition.

    Args:
        x: Spatial coordinates tensor; its first dimension gives the row count.
        mu: Parameter tensor.
        nb_func: Number of functions to return (default is 1).

    Returns:
        A tensor of zeros with shape (number of points, nb_func).
    """
    n_points = x.x.shape[0]
    return torch.zeros((n_points, nb_func))
# Types accepted as the stationary (spatial) component of a generic temporal PDE.
SPACE_COMPONENT_TYPE = StrongFormEllipticPDE | LinearOrder2PDE
class GenericFirstOrderTemporalPDE(FirstOrderTemporalPDE):
    """First-order temporal equation built on top of a stationary spatial PDE.

    The spatial operator and the boundary operator are delegated to
    `space_component`; the source term, boundary data and initial condition are
    user-supplied callables that default to zero. The parent class is always
    initialized with ``linear=True``.

    Args:
        space: The approximation space for the problem
        space_component: The stationary part of the equation as a PDE
        init: Callable for the initial condition (default is zero)
        f: Source term function (default is zero)
        g: Boundary condition function (default is zero)
        **kwargs: Additional keyword arguments, forwarded to the parent class
    """

    # static attribute
    ic_residual_size: int = 1

    # other attributes
    f: Callable
    g: Callable
    init: Callable
    space_component: SPACE_COMPONENT_TYPE

    def __init__(
        self,
        space: AbstractApproxSpace,
        space_component: SPACE_COMPONENT_TYPE,
        init: Callable | None = None,
        f: Callable | None = None,
        g: Callable | None = None,
        **kwargs,
    ):
        super().__init__(space, linear=True, **kwargs)
        self.space_component = space_component
        # any callable left unspecified falls back to its zero-valued default
        self.f = f if f is not None else zeros_rhs
        self.g = g if g is not None else zeros_bc_rhs
        self.init = init if init is not None else zeros_init

    def space_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the spatial operator of the stationary component.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor (not passed to the spatial component)
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Spatial operator tensor
        """
        spatial_pde = self.space_component
        return spatial_pde.operator(w, x, mu)

    def bc_operator(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Apply the boundary condition operator of the stationary component.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor (not passed to the spatial component)
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition operator tensor
        """
        spatial_pde = self.space_component
        return spatial_pde.bc_operator(w, x, n, mu)

    def rhs(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor:
        """Evaluate the source term `f` as the right-hand side (RHS) of the PDE.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            RHS tensor
        """
        source_term = self.f
        return source_term(w, t, x, mu)

    def bc_rhs(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Evaluate the boundary function `g` as the boundary condition RHS.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition RHS tensor
        """
        boundary_term = self.g
        return boundary_term(w, t, x, n, mu)

    def initial_condition(self, x: LabelTensor, mu: LabelTensor) -> torch.Tensor:
        """Evaluate the initial condition callable `init`.

        Args:
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Initial condition tensor
        """
        initial_state = self.init
        return initial_state(x, mu)

    def functional_operator(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        x: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the functional operator: time operator plus spatial operator.

        The spatial component receives `func` with its time argument frozen
        at `t`.

        Args:
            func: Callable representing the function, called as
                ``func(t, x, mu, theta)``
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional operator tensor
        """

        def func_at_t(x, mu, theta):
            # same function with the time argument fixed to the enclosing t
            return func(t, x, mu, theta)

        spatial_part = self.space_component.functional_operator(
            func_at_t, x, mu, theta
        )
        temporal_part = self.functional_time_operator(func, t, x, mu, theta)
        return temporal_part + spatial_part

    def functional_operator_bc(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        x: torch.Tensor,
        n: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the functional operator for boundary conditions.

        Delegates to the spatial component, giving it `func` with its time
        argument frozen at `t`.

        Args:
            func: Callable representing the function, called as
                ``func(t, x, mu, theta)``
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional operator tensor for boundary conditions
        """

        def func_at_t(x, mu, theta):
            # same function with the time argument fixed to the enclosing t
            return func(t, x, mu, theta)

        return self.space_component.functional_operator_bc(
            func_at_t, x, n, mu, theta
        )
### TODO: how should splitting be handled with several spaces and several RHS?
### Lists of temporal PDEs, with a method so that W gathers all the variables
### of all the spaces?