Source code for scimba_torch.physical_models.temporal_pde.abstract_temporal_pde

"""Module for abstract temporal PDEs."""

from abc import ABC, abstractmethod
from typing import Callable, Generator, cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.physical_models.elliptic_pde.abstract_elliptic_pde import (
    StrongFormEllipticPDE,
)
from scimba_torch.physical_models.elliptic_pde.linear_order_2 import LinearOrder2PDE
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable


class TemporalPDE(ABC):
    """Base class for representing time-dependent Partial Differential Equations.

    The PDE operator is split into a spatial part (:meth:`space_operator`) and a
    temporal part (:meth:`time_operator`); :meth:`operator` returns their sum.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments; only ``exact_solution`` is read
            here (stored as :attr:`exact_solution`, ``None`` if absent)
    """

    space: AbstractApproxSpace
    linear: bool
    exact_solution: Callable | None

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.space = space
        self.linear = linear
        # handle kwargs: optional reference solution, None when not provided
        self.exact_solution = kwargs.get("exact_solution", None)

    def grad(
        self,
        w: torch.Tensor | MultiLabelTensor,
        y: torch.Tensor | LabelTensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the gradient of the tensor `w` with respect to the tensor `y`.

        Delegates to ``self.space.grad``.

        Args:
            w: Input tensor
            y: Tensor with respect to which the gradient is computed

        Returns:
            Gradient tensor, or a tuple of tensors when the approximation space
            yields the gradient as a generator
        """
        res = self.space.grad(w, y)
        # A generator can only be consumed once; materialize it so callers may
        # index it and iterate over it several times.
        if isinstance(res, Generator):
            return tuple(res)
        return res

    @abstractmethod
    def rhs(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the right-hand side (RHS) of the PDE.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            RHS tensor
        """

    @abstractmethod
    def space_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the spatial part of the PDE operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Spatial operator tensor
        """

    @abstractmethod
    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the temporal part of the PDE operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Time operator tensor
        """

    def operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the full PDE operator: spatial operator plus time operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Operator tensor, or a tuple with one tensor per component when the
            spatial and time operators return tuples

        Raises:
            AssertionError: If one operator returns a tensor and the other a tuple.
            ValueError: If both return tuples of different lengths.
        """
        space = self.space_operator(w, t, x, mu)
        time = self.time_operator(w, t, x, mu)
        error_message = (
            "space and time operators must both be Tensors or both be tuples of "
            "Tensors"
        )
        if isinstance(space, tuple):
            assert isinstance(time, tuple), error_message
            # strict=True: a mismatch in the number of components is an error,
            # not a silent truncation of the residual.
            return tuple(sp + ti for sp, ti in zip(space, time, strict=True))
        assert isinstance(time, torch.Tensor), error_message
        return space + time

    @abstractmethod
    def bc_rhs(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Compute the boundary condition RHS.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition RHS tensor
        """

    @abstractmethod
    def bc_operator(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Apply the boundary condition operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition operator tensor
        """

    @abstractmethod
    def initial_condition(
        self, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the initial condition.

        Args:
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Initial condition tensor
        """
class FirstOrderTemporalPDE(TemporalPDE):
    """Base class for PDEs that are first order in time.

    The time operator is the first time derivative ``du/dt`` of each solution
    component.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments, forwarded to :class:`TemporalPDE`
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        super().__init__(space, linear, **kwargs)

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the time operator: the first time derivative of the solution.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            ``du/dt`` as a tensor, or a tuple with one derivative per component
            when ``w.get_components()`` returns several components
        """
        u = w.get_components()
        if isinstance(u, torch.Tensor):
            return self.grad(u, t)
        return tuple(cast(torch.Tensor, self.grad(ui, t)) for ui in u)

    def functional_time_operator(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        x: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the functional time operator.

        Differentiates ``func(t, x, mu, theta)`` with respect to its first
        argument ``t`` using :func:`torch.func.jacrev` and returns the slice at
        index 0 of the resulting Jacobian.
        NOTE(review): presumably ``t`` holds a single time value and the output
        is scalar-like — confirm against callers.

        Args:
            func: Callable representing the function
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional time operator tensor
        """
        time_der = torch.func.jacrev(func, 0)(t, x, mu, theta)
        return time_der[0]

    def functional_operator_ic(
        self,
        func: VarArgCallable,
        x: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the functional operator for initial conditions.

        Evaluates ``func`` at time ``t = 0``.

        Args:
            func: Callable representing the function
            x: Spatial coordinate tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional operator tensor for initial conditions
        """
        # Build t = 0 with x's dtype and device so the call also works for
        # float64 inputs and for inputs living on a GPU.
        t = torch.zeros(1, dtype=x.dtype, device=x.device)
        return func(t, x, mu, theta)
class SecondOrderTemporalPDE(TemporalPDE):
    """Base class for PDEs that are second order in time.

    The time operator is the second time derivative ``d2u/dt2`` of each
    solution component.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments, forwarded to :class:`TemporalPDE`
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        # Forward kwargs (e.g. ``exact_solution``) instead of silently dropping
        # them, consistently with FirstOrderTemporalPDE.
        super().__init__(space, linear, **kwargs)

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the time operator: the second time derivative of the solution.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            ``d2u/dt2`` as a tensor, or a tuple with one second derivative per
            component when ``w.get_components()`` returns several components
        """
        u = w.get_components()
        if isinstance(u, torch.Tensor):
            dudt = self.grad(u, t)
            assert isinstance(dudt, torch.Tensor)
            return self.grad(dudt, t)
        # Differentiate each component twice with respect to t.
        dudt = tuple(cast(torch.Tensor, self.grad(ui, t)) for ui in u)
        return tuple(cast(torch.Tensor, self.grad(duidt, t)) for duidt in dudt)
def zeros_rhs(
    w: MultiLabelTensor,
    t: LabelTensor,
    x: LabelTensor,
    mu: LabelTensor,
    nb_func: int = 1,
) -> torch.Tensor:
    """Function returning a zero right-hand side.

    Args:
        w: Solution tensor.
        t: Temporal coordinates tensor.
        x: Spatial coordinates tensor.
        mu: Parameter tensor.
        nb_func: Number of functions to return (default is 1).

    Returns:
        A tensor of zeros with shape (number of points, nb_func), with the same
        dtype and device as ``x.x``.
    """
    coords = x.x
    # Match the dtype/device of the coordinates so the result can be combined
    # with operator outputs computed on the same device (e.g. a GPU).
    return torch.zeros(
        coords.size(dim=0), nb_func, dtype=coords.dtype, device=coords.device
    )
def zeros_bc_rhs(
    w: MultiLabelTensor,
    t: LabelTensor,
    x: LabelTensor,
    n: LabelTensor,
    mu: LabelTensor,
    nb_func: int = 1,
) -> torch.Tensor:
    """Function returning a zero right-hand side for the boundary conditions.

    Args:
        w: Solution tensor.
        t: Temporal coordinates tensor.
        x: Spatial coordinates tensor.
        n: Normal vector tensor.
        mu: Parameter tensor.
        nb_func: Number of functions to return (default is 1).

    Returns:
        A tensor of zeros with shape (number of points, nb_func), with the same
        dtype and device as ``x.x``.
    """
    coords = x.x
    # Match the dtype/device of the boundary coordinates so the result can be
    # combined with operator outputs computed on the same device.
    return torch.zeros(
        coords.size(dim=0), nb_func, dtype=coords.dtype, device=coords.device
    )
def zeros_init(
    x: LabelTensor,
    mu: LabelTensor,
    nb_func: int = 1,
) -> torch.Tensor:
    """Function returning a zero initial condition.

    Args:
        x: Spatial coordinates tensor.
        mu: Parameter tensor.
        nb_func: Number of functions to return (default is 1).

    Returns:
        A tensor of zeros with shape (number of points, nb_func), with the same
        dtype and device as ``x.x``.
    """
    coords = x.x
    # Match the dtype/device of the coordinates so the result can be combined
    # with network outputs computed on the same device.
    return torch.zeros(
        coords.size(dim=0), nb_func, dtype=coords.dtype, device=coords.device
    )
# Stationary PDE types accepted as the spatial part (``space_component``) of
# GenericFirstOrderTemporalPDE.
SPACE_COMPONENT_TYPE = StrongFormEllipticPDE | LinearOrder2PDE
class GenericFirstOrderTemporalPDE(FirstOrderTemporalPDE):
    """First-order-in-time PDE built on top of a stationary spatial PDE.

    The spatial operator, boundary operator and their functional variants are
    delegated to ``space_component`` (which takes no time argument). The source
    term, boundary data and initial condition are the user-supplied callables
    ``f``, ``g`` and ``init``. Any callable that is not given falls back to the
    module's zero-valued default.

    Args:
        space: The approximation space for the problem
        space_component: The stationary part of the equation as a PDE
        init: Callable for the initial condition (default is zero)
        f: Source term function (default is zero)
        g: Boundary condition function (default is zero)
        **kwargs: Additional keyword arguments, forwarded to the parent class
    """

    # Static attribute (presumably the number of initial-condition residual
    # components — confirm against the trainers that read it).
    ic_residual_size: int = 1

    # Instance attributes set in __init__.
    f: Callable
    g: Callable
    init: Callable
    space_component: SPACE_COMPONENT_TYPE

    def __init__(
        self,
        space: AbstractApproxSpace,
        space_component: SPACE_COMPONENT_TYPE,
        init: Callable | None = None,
        f: Callable | None = None,
        g: Callable | None = None,
        **kwargs,
    ):
        super().__init__(space, linear=True, **kwargs)
        self.space_component = space_component
        # Unspecified callables default to zero-valued functions.
        self.f = f if f is not None else zeros_rhs
        self.g = g if g is not None else zeros_bc_rhs
        self.init = init if init is not None else zeros_init

    def space_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Evaluate the stationary component's operator (``t`` is not used).

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor (ignored)
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Spatial operator tensor
        """
        stationary = self.space_component
        return stationary.operator(w, x, mu)

    def bc_operator(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Evaluate the stationary component's boundary operator (``t`` unused).

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor (ignored)
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition operator tensor
        """
        stationary = self.space_component
        return stationary.bc_operator(w, x, n, mu)

    def rhs(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor:
        """Evaluate the source term callable ``f``.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            RHS tensor
        """
        source = self.f
        return source(w, t, x, mu)

    def bc_rhs(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Evaluate the boundary data callable ``g``.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition RHS tensor
        """
        boundary_data = self.g
        return boundary_data(w, t, x, n, mu)

    def initial_condition(self, x: LabelTensor, mu: LabelTensor) -> torch.Tensor:
        """Evaluate the initial condition callable ``init``.

        Args:
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Initial condition tensor
        """
        initial = self.init
        return initial(x, mu)

    def functional_operator(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        x: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Sum of the functional time operator and the functional space operator.

        Args:
            func: Callable representing the function
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional operator tensor
        """

        def at_fixed_time(x, mu, theta):
            # Freeze the time argument so the stationary operator only sees a
            # function of (x, mu, theta).
            return func(t, x, mu, theta)

        space_op = self.space_component.functional_operator(
            at_fixed_time, x, mu, theta
        )
        time_op = self.functional_time_operator(func, t, x, mu, theta)
        return time_op + space_op

    def functional_operator_bc(
        self,
        func: VarArgCallable,
        t: torch.Tensor,
        x: torch.Tensor,
        n: torch.Tensor,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Functional boundary operator of the stationary component at time ``t``.

        Args:
            func: Callable representing the function
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional operator tensor for boundary conditions
        """

        def at_fixed_time(x, mu, theta):
            # Same time-freezing wrapper as in functional_operator.
            return func(t, x, mu, theta)

        stationary = self.space_component
        return stationary.functional_operator_bc(at_fixed_time, x, n, mu, theta)
### TODO: how should operator splitting be handled with several approximation
### spaces and several right-hand sides? Perhaps a list of temporal PDEs, with a
### method that makes W hold all the variables of all the spaces?