# Source code for scimba_torch.physical_models.temporal_pde.abstract_temporal_pde

"""Module for abstract temporal PDEs."""

from abc import ABC, abstractmethod
from typing import Generator, cast

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor


class TemporalPDE(ABC):
    """Base class for representing time-dependent Partial Differential Equations (PDEs).

    Subclasses provide the spatial operator, the temporal operator, the
    right-hand side, the boundary condition and the initial condition. The
    full PDE operator returned by :meth:`operator` is the sum of the spatial
    and temporal operators.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments; ``exact_solution`` is read and
            stored as ``self.exact_solution`` (``None`` when not given)
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.space = space
        self.linear = linear
        # handle kwargs
        self.exact_solution = kwargs.get("exact_solution", None)

    def grad(
        self,
        w: torch.Tensor | MultiLabelTensor,
        y: torch.Tensor | LabelTensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the gradient of the tensor `w` with respect to the tensor `y`.

        Delegates to ``self.space.grad``. If that call returns a generator, it
        is materialized into a tuple so the result can be indexed and iterated
        more than once.

        Args:
            w: Input tensor
            y: Tensor with respect to which the gradient is computed

        Returns:
            Gradient tensor, or a tuple of gradient tensors
        """
        res = self.space.grad(w, y)
        if isinstance(res, Generator):
            return tuple(res)
        return res

    @abstractmethod
    def rhs(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the right-hand side (RHS) of the PDE.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            RHS tensor, or a tuple of RHS tensors
        """

    @abstractmethod
    def space_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the spatial part of the PDE operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Spatial operator tensor, or a tuple of tensors (one per component)
        """

    @abstractmethod
    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the temporal part of the PDE operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Temporal operator tensor, or a tuple of tensors (one per component)
        """

    def operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the full PDE operator: spatial operator plus temporal operator.

        Both :meth:`space_operator` and :meth:`time_operator` must return the
        same kind of result: either a single tensor, or tuples of equal length
        that are summed component-wise.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Operator tensor, or a tuple of operator tensors

        Raises:
            AssertionError: If one operator returns a tensor and the other a
                tuple, or if both return tuples of different lengths.
        """
        space = self.space_operator(w, t, x, mu)
        time = self.time_operator(w, t, x, mu)
        error_message = (
            "space and time operators must both be Tensors or tuples of Tensors"
        )
        if isinstance(space, tuple):
            assert isinstance(time, tuple), error_message
            # zip() would silently drop trailing components on a length mismatch.
            assert len(space) == len(time), (
                "space and time operators must return the same number of components"
            )
            return tuple(sp + ti for sp, ti in zip(space, time))
        assert isinstance(time, torch.Tensor), error_message
        return space + time

    @abstractmethod
    def bc_rhs(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Compute the boundary condition RHS.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition RHS tensor
        """

    @abstractmethod
    def bc_operator(
        self,
        w: MultiLabelTensor,
        t: LabelTensor,
        x: LabelTensor,
        n: LabelTensor,
        mu: LabelTensor,
    ) -> torch.Tensor:
        """Apply the boundary condition operator.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor
            n: Normal vector tensor
            mu: Parameter tensor

        Returns:
            Boundary condition operator tensor
        """

    @abstractmethod
    def initial_condition(
        self, x: LabelTensor, mu: LabelTensor
    ) -> tuple[torch.Tensor, ...]:
        """Compute the initial condition.

        Args:
            x: Spatial coordinate tensor
            mu: Parameter tensor

        Returns:
            Tuple of initial condition tensors
        """
class FirstOrderTemporalPDE(TemporalPDE):
    """Base class for PDEs that are first order in time.

    Provides :meth:`time_operator` as the first time derivative of each
    solution component. The remaining abstract methods of
    :class:`TemporalPDE` must still be implemented by subclasses.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments, forwarded unchanged to
            :class:`TemporalPDE`
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        super().__init__(space=space, linear=linear, **kwargs)

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Differentiate every solution component once with respect to time.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor (not used by this operator)
            mu: Parameter tensor (not used by this operator)

        Returns:
            The time derivative as a single tensor when ``w`` holds one tensor,
            otherwise a tuple holding one derivative per component
        """
        components = w.get_components()
        if not isinstance(components, torch.Tensor):
            # Several components: differentiate each one separately.
            return tuple(
                cast(torch.Tensor, self.grad(component, t))
                for component in components
            )
        return self.grad(components, t)
class SecondOrderTemporalPDE(TemporalPDE):
    """Base class for PDEs that are second order in time.

    Provides :meth:`time_operator` as the second time derivative of each
    solution component. The remaining abstract methods of
    :class:`TemporalPDE` must still be implemented by subclasses.

    Args:
        space: Approximation space used for the PDE
        linear: Whether the PDE is linear
        **kwargs: Additional keyword arguments, forwarded to
            :class:`TemporalPDE` (e.g. ``exact_solution``)
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        linear: bool = False,
        **kwargs,
    ):
        # Forward **kwargs so base-class options such as ``exact_solution`` are
        # not silently dropped (FirstOrderTemporalPDE forwards them as well).
        super().__init__(space, linear, **kwargs)

    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, x: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Differentiate every solution component twice with respect to time.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            x: Spatial coordinate tensor (not used by this operator)
            mu: Parameter tensor (not used by this operator)

        Returns:
            The second time derivative as a single tensor when ``w`` holds one
            tensor, otherwise a tuple holding one second derivative per component
        """
        u = w.get_components()
        if isinstance(u, torch.Tensor):
            dudt = self.grad(u, t)
            assert isinstance(dudt, torch.Tensor)
            return self.grad(dudt, t)
        # Several components: apply d/dt twice to each one independently.
        dudt = tuple(cast(torch.Tensor, self.grad(ui, t)) for ui in u)
        return tuple(cast(torch.Tensor, self.grad(duidt, t)) for duidt in dudt)
### TODO: how should splitting be handled with several approximation spaces and
### several RHS? Perhaps lists of temporal PDEs, with a method so that W gathers
### all the variables of all the spaces?