scimba_torch.flows.discretization_based_flows
Flows based on neural networks or classical discretization schemes.
These flows can be used to define neural ODEs or other types of flows.
Classes
ExplicitEulerFlow | Explicit Euler flow based on a given neural network.
ExplicitEulerHamiltonianFlow | Explicit Euler flow for a Hamiltonian system based on a given neural network.
NeuralFlow | Neural flow based on a given neural network.
Rk2Flow |
Rk4Flow |
SymplecticEulerFlowSep | Symplectic Euler flow for a Hamiltonian system based on a given neural network.
VerletSymplecticEulerFlow |
- class NeuralFlow(in_size, out_size, param_dim, net_type, analytic_f=None, **kwargs)
Bases: ScimbaModule
Neural flow based on a given neural network.
- Parameters:
  - in_size (int) – Size of the input to the neural network.
  - out_size (int) – Size of the output of the neural network.
  - param_dim (int) – Number of parameters in the input.
  - net_type (Module) – The neural network class used to approximate the solution.
  - analytic_f (Optional[Callable]) – An optional analytic function to be added to the network output.
  - **kwargs (Any) – Additional arguments passed to the neural network model.
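The parameter list suggests the following structure: the network receives the state together with `param_dim` parameters, returns `out_size` values, and `analytic_f` (when given) is added to that output. The plain-Python sketch below illustrates this reading only. The concatenation of `x` and `mu`, the call signature `analytic_f(x, mu)`, and the helpers `neural_flow_step` and `toy_net` are hypothetical assumptions for illustration, not scimba_torch's API.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]


def neural_flow_step(
    net: Callable[[Vector], Vector],
    x: Sequence[float],
    mu: Sequence[float],
    analytic_f: Optional[Callable[[Sequence[float], Sequence[float]], Vector]] = None,
) -> Vector:
    """Hypothetical NeuralFlow-style map: net([x, mu]) plus an optional analytic term."""
    # Assumption: state and parameters are concatenated into a single
    # network input before being passed to the network.
    out = net(list(x) + list(mu))
    if analytic_f is not None:
        # Assumption: the analytic term is added component-wise to the network output.
        out = [o + a for o, a in zip(out, analytic_f(x, mu))]
    return out


# Toy stand-in for a trained network (2 state entries, 1 parameter, 2 outputs).
def toy_net(z: Vector) -> Vector:
    x0, x1, m = z
    return [0.5 * x0 + m, 0.5 * x1 - m]


print(neural_flow_step(toy_net, [1.0, 2.0], [0.1]))
print(neural_flow_step(toy_net, [1.0, 2.0], [0.1], analytic_f=lambda x, mu: [x[1], -x[0]]))
```

Adding a known analytic term lets the network learn only a correction to it, which is a common reason for exposing such a hook.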
- class ExplicitEulerFlow(in_size, out_size, param_dim, net_type, dt=0.01, analytic_f=None, **kwargs)
Bases: ScimbaModule
Explicit Euler flow based on a given neural network.
- Parameters:
  - in_size (int) – Size of the input to the neural network.
  - out_size (int) – Size of the output of the neural network.
  - param_dim (int) – Number of parameters in the input.
  - net_type (Module) – The neural network class used to approximate the solution.
  - dt (float) – Time step for the Euler update.
  - analytic_f (Optional[Callable]) – An optional analytic function to be added to the network output.
  - **kwargs – Additional arguments passed to the neural network model.
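Explicit Euler advances the state by one step of size `dt` along a vector field, x_{n+1} = x_n + dt f(x_n, mu). The sketch below assumes that the network output (plus `analytic_f`, when given) plays the role of f; `explicit_euler_step` and `rollout` are hypothetical helpers, and the exactly known field f(x, mu) = -mu x stands in for a trained network so the result can be compared with exp(-t).

```python
import math
from typing import Callable, List, Sequence

Field = Callable[[Sequence[float], Sequence[float]], List[float]]


def explicit_euler_step(f: Field, x: Sequence[float], mu: Sequence[float], dt: float = 0.01) -> List[float]:
    """One explicit Euler step: x_{n+1} = x_n + dt * f(x_n, mu)."""
    return [xi + dt * fi for xi, fi in zip(x, f(x, mu))]


def rollout(f: Field, x0: Sequence[float], mu: Sequence[float], dt: float, n_steps: int) -> List[float]:
    """Compose the one-step map with itself n_steps times."""
    x = list(x0)
    for _ in range(n_steps):
        x = explicit_euler_step(f, x, mu, dt)
    return x


# Stand-in for the network: linear decay x' = -mu * x.
def decay(x: Sequence[float], mu: Sequence[float]) -> List[float]:
    return [-mu[0] * xi for xi in x]


x_T = rollout(decay, [1.0], [1.0], dt=0.01, n_steps=100)  # integrate up to t = 1
# Each step multiplies x by (1 - dt), so x_T = 0.99**100; the gap to exp(-1) is O(dt).
print(x_T[0], math.exp(-1.0))
```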
- class ExplicitEulerHamiltonianFlow(in_size, out_size, param_dim, net_type, dt=0.01, analytic_h=None, **kwargs)
Bases: ScimbaModule
Explicit Euler flow for a Hamiltonian system based on a given neural network.
- Parameters:
  - in_size (int) – Size of the input to the neural network.
  - out_size (int) – Size of the output of the neural network.
  - param_dim (int) – Number of parameters in the input.
  - net_type (Module) – The neural network class used to approximate the solution.
  - dt (float) – Time step for the Euler update.
  - analytic_h (Optional[Callable]) – An optional analytic function to be added to the network output.
  - **kwargs (Any) – Additional arguments passed to the neural network model.
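For a Hamiltonian system with state (q, p), explicit Euler evaluates both partial derivatives of H at the current state: q_{n+1} = q_n + dt ∂H/∂p(q_n, p_n) and p_{n+1} = p_n - dt ∂H/∂q(q_n, p_n). The sketch below assumes that reading. It uses hand-written gradients of the harmonic oscillator H = (q² + p²)/2 in place of a network (how the class obtains the gradients of its learned Hamiltonian is not shown here), and `explicit_euler_hamiltonian_step` is a hypothetical helper. It also illustrates the method's known drawback: on this problem the energy is multiplied by (1 + dt²) at every step.

```python
from typing import Callable, Tuple

Grad = Callable[[float, float], float]


def explicit_euler_hamiltonian_step(
    q: float, p: float, dH_dq: Grad, dH_dp: Grad, dt: float = 0.01
) -> Tuple[float, float]:
    """Explicit Euler: both partial derivatives are evaluated at the old state (q_n, p_n)."""
    q_new = q + dt * dH_dp(q, p)
    p_new = p - dt * dH_dq(q, p)
    return q_new, p_new


# Harmonic oscillator H(q, p) = (q**2 + p**2) / 2.
def dH_dq(q: float, p: float) -> float:
    return q


def dH_dp(q: float, p: float) -> float:
    return p


def energy(q: float, p: float) -> float:
    return 0.5 * (q * q + p * p)


q, p, dt = 1.0, 0.0, 0.01
for _ in range(1000):
    q, p = explicit_euler_hamiltonian_step(q, p, dH_dq, dH_dp, dt)

# The energy drifts upward instead of staying at 0.5.
print(energy(q, p), 0.5 * (1 + dt ** 2) ** 1000)
```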
- class SymplecticEulerFlowSep(in_size, out_size, param_dim, net_type, dt=0.01, **kwargs)
Bases: ScimbaModule
Symplectic Euler flow for a Hamiltonian system based on a given neural network.
- Parameters:
  - in_size (int) – Size of the input to the neural network.
  - out_size (int) – Size of the output of the neural network.
  - param_dim (int) – Number of parameters in the input.
  - net_type (Module) – The neural network class used to approximate the solution.
  - dt (float) – Time step for the Euler update.
  - **kwargs (Any) – Additional arguments passed to the neural network model.
- k_func(p, mu, params)
Computes the kinetic energy term.
- Parameters:
  - p (Tensor) – Momentum tensor.
  - mu (Tensor) – Parameter tensor.
  - params (dict) – Parameters of the neural network.
- Return type: Tensor
- Returns: The kinetic energy term, as output by the neural network.
- u_func(q, mu, params)
Computes the potential energy term.
- Parameters:
  - q (Tensor) – Position tensor.
  - mu (Tensor) – Parameter tensor.
  - params (dict) – Parameters of the neural network.
- Return type: Tensor
- Returns: The potential energy term, as output by the neural network.
- forward(x, mu)
Symplectic Euler update:
.. math::
   q_{n+1} = q_n + \mathrm{dt}\,\frac{\partial H}{\partial p}, \qquad p_{n+1} = p_n - \mathrm{dt}\,\frac{\partial H}{\partial q}.
- Parameters:
  - x (Tensor) – Input tensor.
  - mu (Tensor) – Parameter tensor.
- Return type: Tensor
- Returns: The output tensor after applying the Symplectic Euler update.
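For a separable Hamiltonian H(q, p) = K(p) + U(q), which the `k_func`/`u_func` split suggests, symplectic Euler becomes explicit when one update uses the value just computed by the other. The sketch below uses the variant that updates p first, p_{n+1} = p_n - dt U'(q_n), then q_{n+1} = q_n + dt K'(p_{n+1}). This page does not say which variant SymplecticEulerFlowSep uses, so the ordering is an assumption, as are the hypothetical helper `symplectic_euler_step` and the hand-written gradients of K(p) = p²/2 and U(q) = q²/2 standing in for the networks.

```python
from typing import Callable, Tuple


def symplectic_euler_step(
    q: float,
    p: float,
    dK_dp: Callable[[float], float],
    dU_dq: Callable[[float], float],
    dt: float = 0.01,
) -> Tuple[float, float]:
    """Symplectic Euler for H = K(p) + U(q): kick p with the old q, then drift q with the new p."""
    p_new = p - dt * dU_dq(q)
    q_new = q + dt * dK_dp(p_new)
    return q_new, p_new


# Harmonic oscillator: K(p) = p**2 / 2 and U(q) = q**2 / 2.
q, p, dt = 1.0, 0.0, 0.01
for _ in range(10_000):  # integrate up to t = 100
    q, p = symplectic_euler_step(q, p, lambda pp: pp, lambda qq: qq, dt)

energy = 0.5 * (q * q + p * p)
# The energy stays within O(dt) of 0.5 instead of drifting; for this linear
# problem the quantity q**2 + p**2 - dt*q*p is conserved up to round-off.
print(energy, q * q + p * p - dt * q * p)
```

Compared with the explicit Euler sketch for ExplicitEulerHamiltonianFlow, the only change is that the q update uses p_{n+1}; that single reordering is what makes the map symplectic.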
- class VerletSymplecticEulerFlow(in_size, out_size, **kwargs)
Bases: ScimbaModule
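The page gives no description for VerletSymplecticEulerFlow. If, as the name suggests, it follows the Störmer–Verlet scheme for a separable Hamiltonian H = K(p) + U(q), one step is a half kick on p, a full drift on q, and a second half kick. The sketch below shows that standard scheme (velocity Verlet form) on an exactly solvable harmonic oscillator; the helper name `stormer_verlet_step` is hypothetical, and this illustrates the method rather than the class's code.

```python
import math
from typing import Callable, Tuple


def stormer_verlet_step(
    q: float,
    p: float,
    dK_dp: Callable[[float], float],
    dU_dq: Callable[[float], float],
    dt: float = 0.01,
) -> Tuple[float, float]:
    """Störmer–Verlet: half kick, full drift, half kick."""
    p_half = p - 0.5 * dt * dU_dq(q)
    q_new = q + dt * dK_dp(p_half)
    p_new = p_half - 0.5 * dt * dU_dq(q_new)
    return q_new, p_new


# Harmonic oscillator with exact solution q(t) = cos(t), p(t) = -sin(t).
q, p, dt = 1.0, 0.0, 0.01
for _ in range(100):  # integrate up to t = 1
    q, p = stormer_verlet_step(q, p, lambda pp: pp, lambda qq: qq, dt)

# Second-order accuracy: both errors are O(dt**2).
print(q - math.cos(1.0), p + math.sin(1.0))
```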
- class Rk2Flow(in_size, out_size, **kwargs)
Bases: ScimbaModule
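Rk2Flow is also undocumented here. A second-order Runge–Kutta step has several common forms (midpoint, Heun, Ralston); the sketch below uses the explicit midpoint rule as one such choice, which may differ from the class's. The helpers `rk2_midpoint_step` and `integrate` are hypothetical, and the test field f(x) = -x, with exact solution exp(-t), makes it possible to check that halving dt divides the error by about four.

```python
import math
from typing import Callable, List, Sequence

Field = Callable[[Sequence[float]], List[float]]


def rk2_midpoint_step(f: Field, x: Sequence[float], dt: float) -> List[float]:
    """Explicit midpoint rule: evaluate f at a half step, then take the full step with that slope."""
    k1 = f(x)
    x_mid = [xi + 0.5 * dt * ki for xi, ki in zip(x, k1)]
    k2 = f(x_mid)
    return [xi + dt * ki for xi, ki in zip(x, k2)]


def integrate(f: Field, x0: Sequence[float], dt: float, t_end: float) -> List[float]:
    x = list(x0)
    for _ in range(round(t_end / dt)):
        x = rk2_midpoint_step(f, x, dt)
    return x


def decay(x: Sequence[float]) -> List[float]:
    return [-xi for xi in x]


err_coarse = abs(integrate(decay, [1.0], 0.1, 1.0)[0] - math.exp(-1.0))
err_fine = abs(integrate(decay, [1.0], 0.05, 1.0)[0] - math.exp(-1.0))
print(err_coarse / err_fine)  # close to 2**2 = 4 for a second-order method
```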
- class Rk4Flow(in_size, out_size, **kwargs)
Bases: ScimbaModule
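Rk4Flow has no description on this page either. Assuming it refers to the classical fourth-order Runge–Kutta method, as the name suggests, the sketch below shows one step of that scheme. The helpers `rk4_step` and `integrate` are hypothetical, and the same test field f(x) = -x is used so that halving dt should divide the error by roughly 2⁴ = 16.

```python
import math
from typing import Callable, List, Sequence

Field = Callable[[Sequence[float]], List[float]]


def rk4_step(f: Field, x: Sequence[float], dt: float) -> List[float]:
    """Classical RK4: weighted average of four slope evaluations."""
    k1 = f(x)
    k2 = f([xi + 0.5 * dt * ki for xi, ki in zip(x, k1)])
    k3 = f([xi + 0.5 * dt * ki for xi, ki in zip(x, k2)])
    k4 = f([xi + dt * ki for xi, ki in zip(x, k3)])
    return [
        xi + dt / 6.0 * (a + 2.0 * b + 2.0 * c + d)
        for xi, a, b, c, d in zip(x, k1, k2, k3, k4)
    ]


def integrate(f: Field, x0: Sequence[float], dt: float, t_end: float) -> List[float]:
    x = list(x0)
    for _ in range(round(t_end / dt)):
        x = rk4_step(f, x, dt)
    return x


def decay(x: Sequence[float]) -> List[float]:
    return [-xi for xi in x]


err_coarse = abs(integrate(decay, [1.0], 0.1, 1.0)[0] - math.exp(-1.0))
err_fine = abs(integrate(decay, [1.0], 0.05, 1.0)[0] - math.exp(-1.0))
print(err_coarse, err_coarse / err_fine)  # ratio close to 16 for a fourth-order method
```

Each RK4 step costs four evaluations of f (here, four network calls) against two for the midpoint rule and one for explicit Euler, which is the usual accuracy-versus-cost trade-off among these flows.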