"""Module for abstract ODEs."""
from abc import ABC, abstractmethod
from functools import partial
from typing import Callable, Generator

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable
[docs]
def zeros_rhs(
    w: MultiLabelTensor,
    t: LabelTensor,
    mu: LabelTensor,
    nb_func: int = 1,
) -> torch.Tensor:
    """Function returning a zero right-hand side.

    ``w`` and ``t`` are accepted so the signature matches a right-hand side
    called as ``f(w, t, mu)``; they are not used.

    Args:
        w: Solution tensor (unused).
        t: Temporal coordinates tensor (unused).
        mu: Parameter tensor; its first dimension gives the number of points.
        nb_func: Number of functions to return (default is 1).

    Returns:
        A tensor of zeros with shape ``(mu.shape[0], nb_func)``, allocated on
        the same device as ``mu`` when that device can be determined
        (otherwise on torch's default device).
    """
    # Assumption: a LabelTensor exposes its underlying tensor as ``x``; for a
    # plain torch.Tensor (or anything without ``x``) fall back to ``mu``
    # itself. If no ``device`` is found, ``device=None`` keeps torch's
    # default-device behavior.
    data = getattr(mu, "x", mu)
    device = getattr(data, "device", None)
    return torch.zeros(mu.shape[0], nb_func, device=device)
[docs]
class AbstractODE(ABC):
    """Base class for representing Ordinary Differential Equations (ODEs).

    Subclasses must implement :meth:`time_operator`; :meth:`operator`
    delegates to it.

    Args:
        space: Approximation space used for the ODE.
        init: Initial condition function, called as ``init(mu)``.
        n_equations: Number of equations in the system (default is 1).
        rhs_func: Right-hand side function of the ODE, called as
            ``rhs_func(w, t, mu)``. When ``None`` (the default), a zero
            right-hand side with ``n_equations`` columns is used
            (see :func:`zeros_rhs`).
        **kwargs: Additional keyword arguments; currently ignored by this
            base class.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        init: Callable,
        n_equations: int = 1,
        rhs_func: Callable | None = None,
        **kwargs,
    ):
        super().__init__()
        self.space = space
        # Defaults to True here; subclasses may override it.
        self.linear = True
        self.init = init  # initial condition function
        # Fall back to a zero right-hand side. ``functools.partial`` over the
        # module-level ``zeros_rhs`` is used instead of a lambda closure,
        # which cannot be pickled.
        self.rhs_func = (
            rhs_func
            if rhs_func is not None
            else partial(zeros_rhs, nb_func=n_equations)
        )
        # One residual component per equation, for both the ODE residual and
        # the initial-condition residual.
        self.residual_size = n_equations
        self.ic_residual_size = n_equations

    def grad(
        self,
        w: torch.Tensor | MultiLabelTensor,
        y: torch.Tensor | LabelTensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the gradient of the tensor `w` with respect to the tensor `y`.

        Delegates to ``self.space.grad``.

        Args:
            w: Input tensor
            y: Tensor with respect to which the gradient is computed

        Returns:
            Gradient tensor, or a tuple of tensors when the approximation
            space yields its result as a generator.
        """
        res = self.space.grad(w, y)
        # Materialize generators so the result can be indexed and reused.
        if isinstance(res, Generator):
            return tuple(res)
        return res

    def rhs(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the right-hand side (RHS) of the ODE.

        Evaluates the ``rhs_func`` configured at construction time.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            mu: Parameter tensor

        Returns:
            RHS tensor
        """
        return self.rhs_func(w, t, mu)

    @abstractmethod
    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the time operator of the ODE.

        Must be implemented by subclasses.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            mu: Parameter tensor

        Returns:
            Operator tensor
        """

    def operator(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the ODE operator.

        By default this is exactly :meth:`time_operator`.

        Args:
            w: Solution tensor
            t: Temporal coordinate tensor
            mu: Parameter tensor

        Returns:
            Operator tensor
        """
        return self.time_operator(w, t, mu)

    def initial_condition(self, mu: LabelTensor) -> tuple[torch.Tensor, ...]:
        """Compute the initial condition.

        Args:
            mu: Parameter tensor

        Returns:
            Initial condition tensor, as returned by ``self.init(mu)``.

        Raises:
            ValueError: If no initial condition function is defined.
        """
        # Explicit check rather than ``assert``: assertions are stripped
        # under ``python -O``, which would turn this into an opaque
        # "'NoneType' object is not callable" error.
        if self.init is None:
            raise ValueError("Initial condition function is not defined")
        return self.init(mu)

    def functional_operator_ic(
        self,
        func: VarArgCallable,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the functional operator for initial conditions.

        Evaluates ``func`` at the initial time ``t = 0``.

        Args:
            func: Callable representing the function, called as
                ``func(t, mu, theta)``.
            mu: Parameter tensor
            theta: Additional parameters tensor

        Returns:
            Functional operator tensor for initial conditions
        """
        # Create t = 0 on the same device as ``mu`` so ``func`` never
        # receives tensors living on different devices (e.g. CPU vs CUDA).
        t = torch.zeros(1, device=mu.device)
        return func(t, mu, theta)