Source code for scimba_torch.physical_models.ode.abstract_ode

"""Module for abstract ODEs."""

from abc import ABC, abstractmethod
from typing import Callable, Generator

import torch

from scimba_torch.approximation_space.abstract_space import AbstractApproxSpace
from scimba_torch.utils.scimba_tensors import LabelTensor, MultiLabelTensor
from scimba_torch.utils.typing_protocols import VarArgCallable


def zeros_rhs(
    w: MultiLabelTensor,
    t: LabelTensor,
    mu: LabelTensor,
    nb_func: int = 1,
) -> torch.Tensor:
    """Build an all-zero right-hand side.

    Only the leading dimension of ``mu`` is inspected; ``w`` and ``t`` are
    accepted so the function matches the ``rhs_func(w, t, mu)`` calling
    convention, but they are ignored.

    Args:
        w: Solution tensor (unused).
        t: Temporal coordinates tensor (unused).
        mu: Parameter tensor; its first dimension gives the number of points.
        nb_func: Number of functions (columns) to return (default is 1).

    Returns:
        A tensor of zeros with shape ``(mu.shape[0], nb_func)``, created with
        torch's default dtype and device.
    """
    n_points = mu.shape[0]
    out_shape = (n_points, nb_func)
    return torch.zeros(out_shape)
class AbstractODE(ABC):
    """Base class for representing Ordinary Differential Equations (ODEs).

    Subclasses must implement :meth:`time_operator`; :meth:`operator` simply
    forwards to it.

    Args:
        space: Approximation space used for the ODE.
        init: Initial condition function, called as ``init(mu)``.
        n_equations: Number of equations in the system (default is 1). Also
            stored as ``residual_size`` and ``ic_residual_size``.
        rhs_func: Right-hand side function of the ODE, called as
            ``rhs_func(w, t, mu)``. When ``None`` (default), ``zeros_rhs`` with
            ``nb_func=n_equations`` is used.
        **kwargs: Additional keyword arguments; accepted but not used by this
            base class.
    """

    def __init__(
        self,
        space: AbstractApproxSpace,
        init: Callable,
        n_equations: int = 1,
        rhs_func: Callable | None = None,
        **kwargs,
    ):
        super().__init__()
        self.space = space
        # NOTE(review): flag consumed outside this class; presumably marks the
        # operator as linear — confirm against its readers.
        self.linear = True

        self.rhs_func = rhs_func  # right-hand side function
        self.init = init  # initial condition function

        # Fall back to an all-zero right-hand side with one column per equation.
        if self.rhs_func is None:
            self.rhs_func = lambda w, t, mu: zeros_rhs(w, t, mu, n_equations)

        self.residual_size = n_equations
        self.ic_residual_size = n_equations

    def grad(
        self,
        w: torch.Tensor | MultiLabelTensor,
        y: torch.Tensor | LabelTensor,
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the gradient of the tensor `w` with respect to the tensor `y`.

        Delegates to ``self.space.grad``. If the space returns a generator, it
        is materialized into a tuple so callers can index and reuse it.

        Args:
            w: Input tensor.
            y: Tensor with respect to which the gradient is computed.

        Returns:
            Gradient tensor, or a tuple of gradient tensors.
        """
        res = self.space.grad(w, y)
        if isinstance(res, Generator):
            return tuple(res)
        return res

    def rhs(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Compute the right-hand side (RHS) of the ODE.

        Args:
            w: Solution tensor.
            t: Temporal coordinate tensor.
            mu: Parameter tensor.

        Returns:
            The result of ``self.rhs_func(w, t, mu)``.
        """
        return self.rhs_func(w, t, mu)

    @abstractmethod
    def time_operator(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the time operator of the ODE.

        Must be implemented by subclasses; :meth:`operator` returns its result.

        Args:
            w: Solution tensor.
            t: Temporal coordinate tensor.
            mu: Parameter tensor.

        Returns:
            Operator tensor.
        """

    def operator(
        self, w: MultiLabelTensor, t: LabelTensor, mu: LabelTensor
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        """Apply the ODE operator.

        For an ODE this is exactly :meth:`time_operator`.

        Args:
            w: Solution tensor.
            t: Temporal coordinate tensor.
            mu: Parameter tensor.

        Returns:
            Operator tensor.
        """
        return self.time_operator(w, t, mu)

    def initial_condition(self, mu: LabelTensor) -> tuple[torch.Tensor, ...]:
        """Compute the initial condition.

        Args:
            mu: Parameter tensor.

        Returns:
            Whatever ``self.init(mu)`` returns.

        Raises:
            AssertionError: If no initial condition function is set.
        """
        assert self.init is not None, "Initial condition function is not defined"
        return self.init(mu)

    def functional_operator_ic(
        self,
        func: VarArgCallable,
        mu: torch.Tensor,
        theta: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the functional operator for initial conditions.

        Evaluates ``func`` at time ``t = 0``.

        Args:
            func: Callable representing the function, called as
                ``func(t, mu, theta)``.
            mu: Parameter tensor.
            theta: Additional parameters tensor.

        Returns:
            Functional operator tensor for initial conditions.
        """
        # Allocate t on mu's device so func never mixes CPU and accelerator
        # tensors when mu lives off the default device.
        t = torch.zeros(1, device=mu.device)
        return func(t, mu, theta)