Defining a physical model¶
This tutorial presents the main steps to model a physical problem involving Partial Differential Equations (PDEs) in scimba_jax, for 4 kinds of problems:
1-dimensional unknown function of space in homogeneous medium
1-dimensional unknown function of time and space in homogeneous medium
2-dimensional unknown function of time and space in homogeneous medium
1-dimensional unknown function of space in non-homogeneous medium
1-dimensional unknown function of space in homogeneous medium¶
In this first section, we will implement a 2D Laplacian with Dirichlet boundary conditions in a homogeneous medium.
We consider the equation:
\[-\mu \Delta u(x, y) = f(x, y; \mu) := 2 (\mu\pi)^2 \sin(\pi x) \sin(\pi y)\]
on \((x,y) \in \Omega := [-1, 1]\times[-1, 1]\) with parameter \(\mu\in[1,2]\) subject to the Dirichlet condition:
\[u(x, y) = 0\]
on \(\partial\Omega\).
The strong form of this problem will be implemented as a class, and an instance of this class will be used to approximate the solution numerically with a PINN.
Remark: the strong form of this problem is already implemented in the class ParametricLaplacianDirichletND and can be directly instantiated by users with:
[1]:
import jax
import jax.numpy as jnp
def f_rhs(xy: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    x, y = xy[..., 0:1], xy[..., 1:2]
    return 2 * (mu[..., 0:1] * jnp.pi) ** 2 * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)

from scimba_jax.physical_models.elliptic_pde.laplacians import ParametricLaplacianDirichletND

try:
    model = ParametricLaplacianDirichletND(omega, f_rhs, bc="weak")  # noqa: F821
except Exception:
    pass
where omega is the geometric domain \(\Omega\), f_rhs is the right-hand side of the residual, and bc="weak" means that boundary conditions are handled weakly; see the complete example about Scimba basics.
The rest of this section describes the implementation of this class for pedagogical purposes.
Let us first define domains and samplers for this problem (see the tutorial on domains and samplers).
[2]:
from scimba_jax.domains.meshless_domains.domains_2d import Square2D
from scimba_jax.nonlinear_approximation.integration.monte_carlo import (
DomainSampler,
TensorizedSampler,
)
from scimba_jax.nonlinear_approximation.integration.monte_carlo_parameters import (
UniformParametricSampler,
)
domain_x = Square2D([(-1.0, 1.0), (-1.0, 1.0)], is_main_domain=True)
domain_mu = [(1., 2.)]
sampler = TensorizedSampler(
[DomainSampler(domain_x), UniformParametricSampler(domain_mu)], bc=True
)
In scimba_jax, a problem is modeled with a dictionary of residuals; a residual is modeled by a functional operator that applies to the variable(s) \(u\) of the problem, and a right-hand side. We first introduce how we manipulate variables in scimba_jax.
Variables and Operators¶
In the PINN approach, an approximation of the solution \(u^*\) of a problem is sought in a so-called approximation space, that is, a finite-dimensional parameterized space of surrogate functions \(u_\theta\) that are Neural Networks (NN) with weights \(\theta\). An approximation is found by tuning the weights \(\theta\) to minimize the residuals:
where the \(\mathcal{R}_i\) are functional operators.
Implementing a residual in scimba_jax means implementing such a functional operator \(\mathcal{R}_i\).
For a parametric stationary problem, such a variable is a callable object like:
[3]:
from typing import Any

WEIGHTS_TYPE = Any  # placeholder type for the weights of the network

def u_theta(theta: WEIGHTS_TYPE, xy: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    pass
Let us implement the left-hand side of a Laplacian residual \(\mathcal{R}\) such that for any \(u\):
\[\mathcal{R}(u): ({\bf x}, \mu) \mapsto -\mu \Delta_{\bf x} u({\bf x}, \mu)\]
[4]:
from typing import Callable
def LaplacianResidualOperator(var: Callable) -> Callable:
    # hessian of var
    hessian = jax.hessian(var, argnums=1)  # hessian is a function
    # trace of hessian of var
    trace = lambda *args: jnp.trace(hessian(*args), axis1=-2, axis2=-1)  # trace is a function
    # parametric laplacian operator
    return lambda theta, xy, mu: -mu * trace(theta, xy, mu)
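The shapes involved in the Hessian-trace construction can be checked in isolation. The following is a minimal sketch in plain JAX (no scimba_jax class involved); the function `u` and the evaluation point are illustrative assumptions. For a variable returning an array of shape `(1,)` and differentiated with respect to an input of shape `(2,)`, `jax.hessian` returns an array of shape `(1, 2, 2)`, and the trace over the last two axes gives the Laplacian with shape `(1,)`.

```python
import jax
import jax.numpy as jnp


def u(theta, xy, mu):
    # illustrative variable for ONE point: xy has shape (2,), mu has shape (1,)
    x, y = xy[0:1], xy[1:2]
    return mu * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)  # shape (1,)


# differentiate twice with respect to argument 1 (xy)
hessian_u = jax.hessian(u, argnums=1)

xy = jnp.array([0.3, -0.2])
mu = jnp.array([1.5])

H = hessian_u(None, xy, mu)                    # shape (1, 2, 2)
laplacian = jnp.trace(H, axis1=-2, axis2=-1)   # shape (1,)

# for this u, the Laplacian is -2 * pi^2 * u
expected = -2.0 * jnp.pi**2 * u(None, xy, mu)
```

Because the function is written for a single point, batching is added afterwards with `jax.vmap`.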
Since we know the analytic expression of the solution of the problem, we can verify soundness of our implementation by defining the variable \(u^*\):
[5]:
def analytic_solution(xy: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    x, y = xy[..., 0:1], xy[..., 1:2]
    return mu * jnp.sin(jnp.pi * x) * jnp.sin(jnp.pi * y)

def u_star(theta: WEIGHTS_TYPE, xy: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    return analytic_solution(xy, mu)
R_u_star = LaplacianResidualOperator(u_star)
On any point of the geometric domain, \(\mathcal{R}(u^*)\) and \(f\) must match:
[6]:
N_COLLOC = 1000
import jax
key = jax.random.PRNGKey(0) #initialize the random generator
key, sample_dict = sampler.sample(key)
xy, mu = sample_dict["interior"]
batched_f_rhs = jax.vmap(f_rhs, in_axes=(0, 0))
batched_R_u_star = jax.vmap(R_u_star, in_axes=(None, 0, 0))
assert jnp.allclose( batched_f_rhs(xy, mu), batched_R_u_star(None, xy, mu) )
scimba_jax simplifies this implementation by providing special classes to manipulate variables: the ParamFunction class and its subclasses:
ParamScalarFunction for scalar variables (most cases),
ParamFieldFunction for field variables (for instance gradients of scalar variables),
ParamVecFunction for vector variables,
ParamMatrixFunction, mostly used internally.
ParamFunction instances are callable objects, but offer more functionalities. For instance, our functional operator for the Laplacian residual could be rewritten:
[7]:
from scimba_jax.nonlinear_approximation.model_class.funcparam_scalar import ParamScalarFunction
def LaplacianResidualOperator(var: ParamScalarFunction) -> ParamScalarFunction:
    lap = var.laplacian_x()  # laplacian w.r.t x is implemented for ParamScalarFunctions
    return -lap * (lambda x, mu: mu)  # scalar multiplication is overloaded for ParamScalarFunctions and regular callable objects
To test it, we create a ParamScalarFunction instance that encapsulates u_star (this serves a pedagogical purpose; in practice, users don’t have to create ParamFunctions on their own):
[8]:
u_star_p = ParamScalarFunction(
{"x":2, "mu":1}, # a dictionary symbol: dimension
u_star, # the variable
"x_mu", # the ordering of symbols in the call
)
Then apply the operator and verify:
[9]:
R_u_star = LaplacianResidualOperator(u_star_p)
batched_R_u_star = jax.vmap(R_u_star, in_axes=(None, 0, 0))
assert jnp.allclose( batched_f_rhs(xy, mu), batched_R_u_star(None, xy, mu) )
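To illustrate the idea of a callable object that also supports arithmetic, here is a toy sketch. The class `Fn` is hypothetical and is not the ParamFunction implementation; it only shows how overloading `__neg__` and `__mul__` lets one combine functions into new functions without evaluating anything.

```python
import jax.numpy as jnp


class Fn:
    """Toy callable wrapper with pointwise arithmetic (illustration only)."""

    def __init__(self, f):
        self.f = f

    def __call__(self, *args):
        return self.f(*args)

    def __neg__(self):
        # returns a new function, nothing is evaluated here
        return Fn(lambda *args: -self.f(*args))

    def __mul__(self, other):
        # `other` may be another Fn or any plain callable with the same arguments
        return Fn(lambda *args: self.f(*args) * other(*args))


square = Fn(lambda x, mu: x**2)
scaled = -square * (lambda x, mu: mu)  # the function (x, mu) -> -(x**2) * mu

value = scaled(jnp.array(3.0), jnp.array(2.0))
```

The expression `-square * (lambda x, mu: mu)` mirrors the form of the return statement in the operator above: the result is again a callable, evaluated only when it is given arguments.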
The functional operator will now be encapsulated in a residual.
Residuals¶
There are 4 types of residuals in scimba_jax, implemented by 4 classes:
InteriorResidual models residuals in the interior of the geometric domain,
BoundaryResidual models boundary conditions,
InitialResidual models initial conditions,
DataResidual allows implementing data losses.
Let us implement a class for the parametric Laplacian operator as a subclass of InteriorResidual. Such a subclass must implement a method construct_residual which constructs the left-hand side of the residual for a list of variables; here, this amounts to applying the Laplacian operator.
[10]:
from scimba_jax.domains.meshless_domains.base import VolumetricDomain
from scimba_jax.physical_models.abstract_residuals import InteriorResidual
class ParametricLaplacianResidual(InteriorResidual):
    def __init__(
        self,
        domain: VolumetricDomain,
        f_rhs: Callable | None = None,
    ):
        super().__init__(domain=domain, size=1, model_type="x_mu", f_rhs=f_rhs)

    def construct_residual(self, *vars: ParamScalarFunction) -> ParamScalarFunction:
        return LaplacianResidualOperator(vars[0])
Initializing the InteriorResidual part requires:
domain: the geometric domain, a scimba_jax instance of VolumetricDomain,
size: the size of the operator (i.e. the dimension of the target space of \(\mathcal{R}(u)\)),
model_type: a str describing the input arguments of the variables and their order in a call,
f_rhs: the right-hand side of the residual (None stands for the identically zero function),
for time-dependent problems, the time domain (as explained in the next parts).
[11]:
laplacian_residual = ParametricLaplacianResidual(domain_x, f_rhs)
R_u_star = laplacian_residual.construct_residual(u_star_p)
batched_R_u_star = jax.vmap(R_u_star, in_axes=(None, 0, 0))
assert jnp.allclose( batched_f_rhs(xy, mu), batched_R_u_star(None, xy, mu) )
Remark: When implementing the construct_residual method of a residual class corresponding to a functional operator \(\mathcal{R}\), it is important to keep in mind that:
construct_residual takes a ParamFunction as input and returns a ParamFunction; it does not evaluate the residual;
the output of construct_residual will be vmapped to apply to batches of vectors, hence you do not have to worry about batches;
construct_residual is applied once before the training; as a consequence, there is little to gain from trading readability for efficiency inside it.
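The three points of this remark can be sketched in plain JAX, without any scimba_jax class. The names `make_residual` and `var` are hypothetical; the residual is built once from a variable written for a single point, and batching is added afterwards with `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def make_residual(var):
    """Build (once) a pointwise residual function from a variable."""
    hess = jax.hessian(var, argnums=1)
    return lambda theta, x, mu: -mu * jnp.trace(hess(theta, x, mu), axis1=-2, axis2=-1)


def var(theta, x, mu):
    # illustrative variable for ONE point: x has shape (2,), mu has shape (1,)
    return theta * mu * jnp.sum(x**2, keepdims=True)  # shape (1,)


residual = make_residual(var)  # a function is returned, nothing is evaluated
batched_residual = jax.vmap(residual, in_axes=(None, 0, 0))  # batching added afterwards

xs = jnp.ones((5, 2))
mus = jnp.full((5, 1), 2.0)
values = batched_residual(3.0, xs, mus)  # shape (5, 1)
```

Here the Laplacian of `theta * mu * |x|^2` in 2D is `4 * theta * mu`, so every entry of `values` equals `-mu * 4 * theta * mu = -48`.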
As a second example, we implement a subclass of BoundaryResidual for the Dirichlet residual defined as:
where the operator \(\mathcal{B}\) is identity: \(\mathcal{B}(u): ({\bf x}, \mu) \mapsto u({\bf x}, \mu)\).
Because the boundaries of a geometric domain can be composite, a BoundaryResidual instance is attached to a dictionary of instances of SurfacicDomain, the class modeling domains of dimension \(d'<d\) embedded in \(\mathbb{R}^d\):
[12]:
from scimba_jax.domains.meshless_domains.base import SurfacicDomain
from scimba_jax.physical_models.abstract_residuals import BoundaryResidual
class DirichletResidual(BoundaryResidual):
    def __init__(
        self,
        domain: dict[str, SurfacicDomain],
        size: int = 1,
        model_type: str = "x_mu",
        f_rhs: Callable | None = None,
        time_domain: tuple[float, ...] | None = None,
    ):
        super().__init__(
            domain=domain,
            size=size,
            model_type=model_type,
            f_rhs=f_rhs,
            time_domain=time_domain,
        )

    def construct_residual(self, *vars):
        # Here assume that variables are functions of normals
        rho = vars[0]
        return rho
Notice that although the model type here is "x_mu", the right-hand side, if provided, is required to depend explicitly on the normals, as in:
[13]:
def f_bc_rhs(xy: jnp.ndarray, n: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    x = xy[..., 0:1]
    return jnp.zeros_like(x)
and it is assumed in the method construct_residual that input variables depend also on the normals:
[14]:
def u_star_with_normals(theta: WEIGHTS_TYPE, xy: jnp.ndarray, n: jnp.ndarray, mu: jnp.ndarray) -> jnp.ndarray:
    return analytic_solution(xy, mu)
u_star_wn = ParamScalarFunction(
{"x":2, "n":2, "mu":1}, # a dictionary symbol: dimension
u_star_with_normals, # the variable
"x_n_mu", # the ordering of symbols in the call
)
Remark: variables with normals as argument are automatically created in instances of BoundaryResidual.
By default, all the boundary elements of a VolumetricDomain which is a main domain are grouped in a dictionary; this dictionary has label "boundary":
[15]:
boundaries_dict_of_dict = domain_x.get_all_boundaries()
#by default, this returns a dictionary with a unique entry with label "boundary"
#which is a dictionary
print(boundaries_dict_of_dict)
dirichlet_residual = DirichletResidual(boundaries_dict_of_dict["boundary"], 1, "x_mu", f_bc_rhs, None)
B_u_star = dirichlet_residual.construct_residual(u_star_wn)
batched_R_u_star = jax.vmap(B_u_star, in_axes=(None, 0, 0, 0))
batched_f_bc_rhs = jax.vmap(f_bc_rhs, in_axes=(0, 0, 0))
xy, n, mu = sample_dict["boundary"]
assert jnp.allclose( batched_f_bc_rhs(xy, n, mu), batched_R_u_star(None, xy, n, mu) )
{'boundary': {'bc south': SurfacicDomain with label "bc south" of type Segment2D, 'bc east': SurfacicDomain with label "bc east" of type Segment2D, 'bc north': SurfacicDomain with label "bc north" of type Segment2D, 'bc west': SurfacicDomain with label "bc west" of type Segment2D}}
InteriorResidual, BoundaryResidual and InitialResidual are all subclasses of PhysicalResidual.
Physical models¶
We are ready to implement the class ParametricLaplacianDirichletND that models the considered problem. It will inherit from the abstract class AbstractPhysicalModel, of which a simplified version is reproduced below.
A physical model in scimba_jax is a dictionary of PhysicalResidual instances (and possibly a dictionary of DataResidual instances) attached to geometric and temporal domains:
[16]:
import warnings
from abc import ABC
from scimba_jax.physical_models.abstract_residuals import DataResidual, PhysicalResidual
PHYSICAL_RESIDUALS_TYPE = dict[str, PhysicalResidual]
DATA_RESIDUALS_TYPE = dict[str, DataResidual]
class SimplifiedAbstractPhysicalModel(ABC):
    main_domain: VolumetricDomain
    time_domain: tuple[float, ...]
    dim: int
    boundaries: dict[str, dict[str, SurfacicDomain]]
    physical_residuals: PHYSICAL_RESIDUALS_TYPE
    data_residuals: DATA_RESIDUALS_TYPE

    def __init__(
        self,
        main_domain: VolumetricDomain,
        time_domain: tuple[float, float] | None = None,
    ):
        self.main_domain = main_domain
        self.time_domain = tuple() if time_domain is None else time_domain
        self.dim = self.main_domain.dim + int(time_domain is None)
        assert self.main_domain.is_main_domain
        try:
            self.boundaries = self.main_domain.get_all_boundaries()
        except NotImplementedError:
            self.boundaries = {group: {} for group in self.main_domain.boundaries}
        self.physical_residuals = {}
        self.data_residuals = {}

    def get_physical_residuals(self) -> PHYSICAL_RESIDUALS_TYPE:
        return self.physical_residuals

    def get_data_residuals(self) -> DATA_RESIDUALS_TYPE:
        return self.data_residuals

    # user can add data_residuals later
    def add_data_residual(self, label: str, residual: DataResidual):
        if label in self.data_residuals:
            warnings.warn("replacing data residual %s" % label, RuntimeWarning)
        self.data_residuals[label] = residual
Notice the automatic construction of the dictionary of dictionaries of boundary elements.
Here comes the class ParametricLaplacianDirichletND which gathers the interior and boundary residuals in the dictionary of physical residuals:
[17]:
from scimba_jax.physical_models.abstract_physical_model import AbstractPhysicalModel
class ParametricLaplacianDirichletND(AbstractPhysicalModel):
    """A n D laplacian equation with one parameter."""

    def __init__(
        self,
        main_domain: VolumetricDomain,
        f_rhs: Callable | None = None,
        bc: str = "weak",
        f_bc_rhs: Callable | None = None,
    ):
        super().__init__(main_domain=main_domain)
        self.physical_residuals: PHYSICAL_RESIDUALS_TYPE = {
            self.main_domain.get_label(): ParametricLaplacianResidual(
                domain=main_domain,
                f_rhs=f_rhs,
            ),
        }
        if bc == "weak":
            for boundary in self.boundaries:
                self.physical_residuals[boundary] = DirichletResidual(
                    domain=self.boundaries[boundary],
                    model_type="x_mu",
                    f_rhs=f_bc_rhs,
                )
The boundary residual is created for all dictionaries of boundary elements.
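To give an intuition of why residuals are stored in a dictionary, here is a minimal sketch, in plain JAX and independent of scimba_jax internals, of how a dictionary of residuals and a matching dictionary of samples could be combined into a single least-squares loss. The function `loss_from_residuals` and the 1D toy problem are hypothetical and are not the library's implementation.

```python
import jax
import jax.numpy as jnp


def loss_from_residuals(residuals, rhs, samples, theta):
    """Sum over labels of the mean squared difference (residual - rhs)."""
    total = 0.0
    for label, res in residuals.items():
        points = samples[label]
        lhs = jax.vmap(res, in_axes=(None, 0))(theta, points)
        f = jax.vmap(rhs[label])(points)
        total = total + jnp.mean((lhs - f) ** 2)
    return total


# toy problem: -u'' = 1 on (0, 1) with u = 0 at both ends
def u(theta, x):
    return theta * x * (1.0 - x)  # x has shape (1,), output has shape (1,)


def interior(theta, x):
    d2u = jax.hessian(u, argnums=1)(theta, x)  # shape (1, 1, 1)
    return -jnp.trace(d2u, axis1=-2, axis2=-1)


residuals = {"interior": interior, "boundary": u}
rhs = {
    "interior": lambda x: jnp.ones_like(x),
    "boundary": lambda x: jnp.zeros_like(x),
}
samples = {
    "interior": jnp.linspace(0.1, 0.9, 9).reshape(-1, 1),
    "boundary": jnp.array([[0.0], [1.0]]),
}

loss_exact = loss_from_residuals(residuals, rhs, samples, 0.5)  # theta = 1/2 solves the toy problem
loss_wrong = loss_from_residuals(residuals, rhs, samples, 1.0)
```

The labels of the residuals and of the samples must match, which is the role played in this tutorial by the domain labels and the boundary group labels.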
Let us finally instantiate our model:
[18]:
model = ParametricLaplacianDirichletND(domain_x, f_rhs, bc="weak")
Define and train a PINN for our physical model:¶
Finally, we define and train a PINN to approximate the solution of the considered problem:
[19]:
import matplotlib.pyplot as plt
from scimba_jax.nonlinear_approximation.networks.mlp import MLP
from scimba_jax.nonlinear_approximation.approximation_spaces.approximation_spaces import (
ApproximationSpace,
)
from scimba_jax.nonlinear_approximation.numerical_solvers.projectors import Projector
from scimba_jax.plots.plots_nd import plot_abstract_approx_space
key = jax.random.PRNGKey(0) #initialize the random generator
nn = MLP(in_size=3, out_size=1, hidden_sizes=[16, 16, 16], key=key)
space = ApproximationSpace(
{"x": 2, "mu": 1}, # x has dim 2, mu has dim 1
[(nn, "scalar", None)], # nn models a scalar function u
model_type="x_mu", # u depends of x and mu
)
The ApproximationSpace object automatically generates the variables of the problem.
[20]:
pinn = Projector(model, space, sampler, matrix_regularization=1e-5)
key, pinn = pinn.project(key, space, 100, 1000)
plot_abstract_approx_space(
pinn.space, # the approximation space
domain_x, # the spatial domain
domain_mu, # the parameter's domain
loss=pinn.losses, # for plot of the loss: the losses
residual=pinn.model, # for plot of the residual: the pde
error=analytic_solution, # for plot of the error with respect to a func: the func
draw_contours=True,
n_drawn_contours=20,
)
plt.show()
Training: 100%|||||||||||||||||| 100/100[00:03<00:00] , loss: 6.3e+02 -> 2.2e-05
1-dimensional unknown function of time and space in homogeneous medium¶
We consider a time-dependent 1D heat equation in a homogeneous medium:
\[\partial_t u(t, x) - \partial_{xx} u(t, x) = 0\]
on \(x \in \Omega := [0, 1]\) subject to the Dirichlet condition:
\[u(t, x) = 0\]
on \(\partial\Omega\), and the initial condition:
\[u(0, x) = \sin(\pi x)\]
The strong form of this problem will be implemented as a class, an instance of which will be used to solve the problem with a PINN.
Remark: the strong form of this problem is already implemented in the class HeatND, see the complete example about Scimba basics.
The rest of this section describes the implementation of this class for pedagogical purposes.
Let us first define domains and samplers for this problem (see the tutorial on domains and samplers).
[21]:
from scimba_jax.domains.meshless_domains.domains_1d import Segment1D
from scimba_jax.nonlinear_approximation.integration.monte_carlo_time import (
UniformTimeSampler,
)
domain_t = (0.0, 1.0)
domain_x = Segment1D((0.0, 1.0), is_main_domain=True)
sampler = TensorizedSampler(
[
UniformTimeSampler(domain_t),
DomainSampler(domain_x),
],
model_type="t_x",
bc=True,
ic=True,
)
The analytic expression of the solution and the right-hand sides of the interior residual and the boundary conditions are time-dependent:
[22]:
def analytic_solution(t: jnp.ndarray, x: jnp.ndarray):
    return jnp.exp(-(t * jnp.pi**2)) * jnp.sin(jnp.pi * x)

def f_rhs(t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    return jnp.zeros_like(x)

def f_bc_rhs(t: jnp.ndarray, x: jnp.ndarray, n: jnp.ndarray) -> jnp.ndarray:
    return jnp.zeros_like(x)
while the initial condition is not:
[23]:
def f_init(x: jnp.ndarray):
    t = jnp.zeros_like(x)
    return analytic_solution(t, x)
Variables¶
Variables manipulated in time-dependent problems are callable objects with prototypes:
[24]:
def u_theta(theta: WEIGHTS_TYPE, t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    pass
u_theta_p = ParamScalarFunction(
{"t":1, "x":1}, # a dictionary symbol: dimension
u_theta, # the variable
"t_x", # the ordering of symbols in the call
)
def u_star(theta: WEIGHTS_TYPE, t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    return analytic_solution(t, x)
u_star_p = ParamScalarFunction(
{"t":1, "x":1},
u_star,
"t_x",
)
Residuals¶
We will now implement the three residuals of our problem.
First the heat equation residual, which applies to the interior of the domain. Notice now the initialization of the InteriorResidual object with named argument time_domain; notice also the model_type="t_x":
[25]:
class HeatResidual(InteriorResidual):
    def __init__(
        self,
        domain: VolumetricDomain,
        time_domain: tuple[float, float],
        f_rhs: Callable | None = None,
    ):
        super().__init__(
            domain=domain,
            size=1,
            model_type="t_x",
            f_rhs=f_rhs,
            time_domain=time_domain,
        )

    def construct_residual(self, *vars):
        rho = vars[0]
        assert isinstance(rho, ParamScalarFunction)
        lap = rho.laplacian_x()
        dt_rho = rho.d_t()
        return dt_rho - lap
We instantiate it and check soundness on the analytic solution:
[26]:
heat_residual = HeatResidual(domain_x, domain_t, f_rhs)
R_u_star = heat_residual.construct_residual(u_star_p)
key, sample_dict = sampler.sample(key)
t, x = sample_dict["interior"]
batched_f_rhs = jax.vmap(f_rhs, in_axes=(0, 0))
batched_R_u_star = jax.vmap(R_u_star, in_axes=(None, 0, 0))
assert jnp.allclose( batched_f_rhs(t, x), batched_R_u_star(None, t, x) )
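The same check can be reproduced in plain JAX, which may help to see what the residual computes. This is a minimal sketch that does not use scimba_jax; the scalar-in, scalar-out function `u` and the evaluation points are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def u(t, x):
    # scalar version of the analytic solution of this section
    return jnp.exp(-t * jnp.pi**2) * jnp.sin(jnp.pi * x)


du_dt = jax.grad(u, argnums=0)                          # time derivative
d2u_dx2 = jax.grad(jax.grad(u, argnums=1), argnums=1)   # second space derivative


def heat_residual(t, x):
    return du_dt(t, x) - d2u_dx2(t, x)


ts = jnp.linspace(0.0, 1.0, 11)
xs = jnp.linspace(0.0, 1.0, 11)
values = jax.vmap(heat_residual)(ts, xs)  # vanishes up to floating point error
```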
The class DirichletResidual is already defined above, and it remains to define a class for the initial condition that will inherit from InitialResidual:
[27]:
from scimba_jax.physical_models.abstract_residuals import InitialResidual
class ProjectionResidual(InitialResidual):
    def __init__(
        self,
        domain: VolumetricDomain,
        time_domain: float | tuple[float] = (0.0,),
        size: int = 1,
        model_type: str = "t_x",
        f_rhs: Callable | None = None,  # a function that does not depend on time
    ):
        super().__init__(
            domain=domain,
            time_domain=(
                (time_domain,) if isinstance(time_domain, float) else time_domain
            ),
            size=size,
            model_type=model_type,
            f_rhs=f_rhs,
        )

    def construct_residual(self, *vars):
        rho = vars[0]
        return rho.set_t_0(self.time_domain[0])
The set_t_0 method of a ParamScalarFunction fixes the value of \(t\); here construct_residual constructs a ParamScalarFunction of type x from a ParamScalarFunction of type t_x:
[28]:
initial_residual = ProjectionResidual(domain_x, domain_t[0], f_rhs=f_init)
I_u_star = initial_residual.construct_residual(u_star_p)
assert I_u_star.f_type == "x"
key, sample_dict = sampler.sample(key)
x, = sample_dict["ic interior"]
batched_f_init = jax.vmap(f_init, in_axes=(0))
batched_I_u_star = jax.vmap(I_u_star, in_axes=(None, 0))
assert jnp.allclose( batched_f_init(x), batched_I_u_star(None, x) )
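Fixing the time argument amounts to building a new function with one argument fewer. A minimal sketch in plain Python/JAX, where `fix_time` is a hypothetical helper and not the implementation of set_t_0:

```python
import jax.numpy as jnp


def fix_time(var, t0):
    """Return a function of (theta, x) from a function of (theta, t, x)."""

    def var_at_t0(theta, x):
        t = jnp.full_like(x[..., :1], t0)  # constant time, same leading shape as x
        return var(theta, t, x)

    return var_at_t0


def u(theta, t, x):
    # analytic solution of the heat problem of this section
    return jnp.exp(-t * jnp.pi**2) * jnp.sin(jnp.pi * x)


u0 = fix_time(u, 0.0)  # a function of type "x" built from a function of type "t_x"
x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
initial_values = u0(None, x)
```

At \(t = 0\) the exponential factor equals 1, so `initial_values` coincides with \(\sin(\pi x)\).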
Physical model:¶
As previously, we construct a class to gather the three residuals in a dictionary:
[29]:
class HeatND(AbstractPhysicalModel):
    def __init__(
        self,
        main_domain: VolumetricDomain,
        time_domain: tuple[float, float],
        f_rhs: Callable | None = None,
        bc: str = "weak",
        f_bc_rhs: Callable | None = None,
        ic: str = "weak",
        f_ic_rhs: Callable | None = None,
    ):
        super().__init__(main_domain=main_domain, time_domain=time_domain)
        self.physical_residuals: PHYSICAL_RESIDUALS_TYPE = {
            self.main_domain.get_label(): HeatResidual(
                domain=main_domain,
                time_domain=time_domain,
                f_rhs=f_rhs,
            ),
        }
        if bc == "weak":
            for boundary in self.boundaries:
                self.physical_residuals[boundary] = DirichletResidual(
                    domain=self.boundaries[boundary],
                    time_domain=time_domain,
                    model_type="t_x",
                    f_rhs=f_bc_rhs,
                )
        if ic == "weak":
            label = "ic " + self.main_domain.get_label()
            self.physical_residuals[label] = ProjectionResidual(
                domain=main_domain,
                time_domain=(self.time_domain[0],),
                model_type="t_x",
                f_rhs=f_ic_rhs,
            )
Let us finally instantiate our model:
[30]:
model = HeatND(domain_x, domain_t, bc="weak", ic="weak", f_ic_rhs=f_init)
The right-hand sides for the interior and the boundary residual are zero functions.
Define and train a PINN for our physical model:¶
Finally, we define and train a PINN to approximate the solution of the considered problem:
[31]:
key = jax.random.PRNGKey(0)
nn = MLP(in_size=2, out_size=1, hidden_sizes=[16, 16], key=key)
space = ApproximationSpace(
    {"x": 1},  # x has dim 1; t is implicitly of dim 1
[(nn, "scalar", None)],
model_type="t_x",
)
pinn = Projector(model, space, sampler)
key, pinn = pinn.project(
key, space, 50, 1000, 1000, 1000
)
plot_abstract_approx_space(
pinn.space,
domain_x,
time_domain=domain_t,
time_values=[0.0, .5, 1.],
loss=pinn.losses,
residual=pinn.model,
solution=analytic_solution,
error=analytic_solution,
title="learning sol of 1D heat equation with TemporalPinns",
)
plt.show()
Training: 100%|||||||||||||||||||| 50/50[00:01<00:00] , loss: 6.5e+00 -> 3.4e-07
2-dimensional unknown function of time and space in homogeneous medium¶
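We consider a time-dependent 1D problem with a 2-dimensional unknown function \(u=(u_1, u_2)\) (called the linearized Euler problem):
\[\partial_t u_1 + \partial_x u_2 = 0, \qquad \partial_t u_2 + \partial_x u_1 = 0\]
on \(x \in \Omega := [-1, 3]\) subject to the Dirichlet condition:
\[u(t, x) = 0\]
on \(\partial\Omega\), and the initial condition:
\[u_1(0, x) = \frac{1}{\sqrt{4\pi D}} \exp\left(-\frac{(x - 1)^2}{4D}\right), \qquad u_2(0, x) = 0\]
where \(D = 0.02\).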
We first define the domains and a sampler; once again, the parametric domain is empty.
[32]:
t_min, t_max = 0.0, 0.5
domain_t = (t_min, t_max)
domain_x = Segment1D((-1.0, 3.0), is_main_domain=True)
sampler = TensorizedSampler(
[
UniformTimeSampler(domain_t),
DomainSampler(domain_x),
],
model_type="t_x",
bc=True,
ic=True,
)
Rather than seeing the residuals as systems of two equations in two variables \(u_1\) and \(u_2\), we consider one vectorial variable \({\bf u}\) taking its values in \(\mathbb{R}^2\):
\[\partial_t {\bf u} + \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix} \partial_x {\bf u} = 0\]
The considered problem has a solution with analytic expression:
[33]:
def analytic_solution(t, x):
    D = 0.02
    coeff = 1 / (4 * jnp.pi * D) ** 0.5
    u1_plus_u2 = coeff * jnp.exp(-((x - t - 1) ** 2) / (4 * D))
    u1_minus_u2 = coeff * jnp.exp(-((x + t - 1) ** 2) / (4 * D))
    u1 = (u1_plus_u2 + u1_minus_u2) / 2
    u2 = (u1_plus_u2 - u1_minus_u2) / 2
    return jnp.concatenate((u1, u2), axis=-1)
which reduces at \(t = 0\) to the initial condition:
[34]:
def initial_solution(x):
    return analytic_solution(jnp.zeros_like(x), x)
This time, we will manipulate ParamVecFunction instances of size 2 like:
[35]:
from scimba_jax.nonlinear_approximation.model_class.funcparam_vectorial import ParamVecFunction
def u_star(theta: WEIGHTS_TYPE, t: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    return analytic_solution(t, x)
u_star_p = ParamVecFunction(
2, #the out size
{"t":1, "x":1}, #the in sizes
u_star,
"t_x",
)
Here is how one can implement the Linearized Euler residual:
[36]:
class LinearizedEulerResidual(InteriorResidual):
    def __init__(
        self,
        domain: VolumetricDomain,
        time_domain: tuple[float, float],
        f_rhs: Callable | None = None,
    ):
        super().__init__(
            domain=domain,
            size=2,
            model_type="t_x",
            f_rhs=f_rhs,
            time_domain=time_domain,
        )

    def construct_residual(self, *vars):
        rho = vars[0]
        assert isinstance(rho, ParamVecFunction)
        dt_rho = rho.d_t()
        dx_rho = rho.partial_derivative_x(0)
        dx_rho = dx_rho.roll(1)
        return dt_rho + dx_rho
We instantiate it and check soundness on the analytic solution:
[37]:
linearized_euler_residual = LinearizedEulerResidual(domain_x, domain_t)
R_u_star = linearized_euler_residual.construct_residual(u_star_p)
key, sample_dict = sampler.sample(key)
t, x = sample_dict["interior"]
batched_f_rhs = jax.vmap(f_rhs, in_axes=(0, 0))
batched_R_u_star = jax.vmap(R_u_star, in_axes=(None, 0, 0))
assert jnp.allclose( batched_f_rhs(t, x), batched_R_u_star(None, t, x) )
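The role of the roll in this residual can be seen in a plain JAX sketch that does not rely on scimba_jax. The scalar-in, vector-out function `u` reproduces the analytic solution above; using `jnp.roll` to swap the two components of the space derivative is the assumption made here about what the roll does.

```python
import jax
import jax.numpy as jnp

D = 0.02


def u(t, x):
    # t and x are scalars; returns the vector (u1, u2)
    coeff = 1.0 / jnp.sqrt(4.0 * jnp.pi * D)
    right = coeff * jnp.exp(-((x - t - 1.0) ** 2) / (4.0 * D))  # u1 + u2
    left = coeff * jnp.exp(-((x + t - 1.0) ** 2) / (4.0 * D))   # u1 - u2
    return jnp.stack(((right + left) / 2.0, (right - left) / 2.0))


dt_u = jax.jacfwd(u, argnums=0)  # (d_t u1, d_t u2)
dx_u = jax.jacfwd(u, argnums=1)  # (d_x u1, d_x u2)


def residual(t, x):
    # rolling by one swaps the two components: (d_x u2, d_x u1)
    return dt_u(t, x) + jnp.roll(dx_u(t, x), 1)


ts = jnp.linspace(0.0, 0.5, 7)
xs = jnp.linspace(0.5, 1.5, 7)
values = jax.vmap(residual)(ts, xs)  # shape (7, 2), vanishes up to floating point error
```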
Let us implement the linearized Euler model as:
[38]:
class LinearizedEuler(AbstractPhysicalModel):
    def __init__(
        self,
        main_domain: VolumetricDomain,
        time_domain: tuple[float, float],
        f_rhs: Callable | None = None,
        bc: str = "weak",
        f_bc_rhs: Callable | None = None,
        ic: str = "weak",
        f_ic_rhs: Callable | None = None,
    ):
        super().__init__(main_domain=main_domain, time_domain=time_domain)
        self.physical_residuals: PHYSICAL_RESIDUALS_TYPE = {
            self.main_domain.get_label(): LinearizedEulerResidual(
                domain=main_domain,
                time_domain=time_domain,
                f_rhs=f_rhs,
            ),
        }
        if bc == "weak":
            for boundary in self.boundaries:
                self.physical_residuals[boundary] = DirichletResidual(
                    domain=self.boundaries[boundary],
                    time_domain=time_domain,
                    size=2,
                    model_type="t_x",
                    f_rhs=f_bc_rhs,
                )
        if ic == "weak":
            label = "ic " + self.main_domain.get_label()
            self.physical_residuals[label] = ProjectionResidual(
                domain=main_domain,
                time_domain=(self.time_domain[0],),
                size=2,
                model_type="t_x",
                f_rhs=f_ic_rhs,
            )
and instantiate it:
[39]:
model = LinearizedEuler(
main_domain=domain_x, time_domain=domain_t, f_ic_rhs=initial_solution
)
Define and train a PINN for our physical model:¶
Finally, we define and train a PINN to approximate the solution of the considered problem.
When creating the approximation space, we specify that the variable associated with the NN is a ParamVecFunction of size 2:
[40]:
key = jax.random.PRNGKey(0)
nn = MLP(in_size=2, out_size=2, hidden_sizes=[32, 32], key=key)
space = ApproximationSpace(
    {"x": 1},  # t is implicitly of dim 1
    [(nn, "vec", 2)],  # here we specify that the variable associated to nn is a ParamVecFunction of size 2
model_type="t_x"
)
pinn = Projector(model, space, sampler)
key, pinn = pinn.project(
key, space, 50, 1000, 2000, 2000
)
plot_abstract_approx_space(
pinn.space,
domain_x,
time_domain=domain_t,
components=[0, 1],
loss=pinn.losses,
residual=pinn.model,
solution=analytic_solution,
error=analytic_solution,
derivatives=["ux", "ut"],
title="solving LinearizedEuler with TemporalPinns",
)
plt.show()
Training: 100%|||||||||||||||||||| 50/50[00:10<00:00] , loss: 3.9e+00 -> 6.4e-07
1-dimensional unknown function of space in non-homogeneous medium¶
We consider in this section an electrostatic problem which consists in solving a Laplacian problem on a domain composed of two nested rectangles: \(\Omega := [-1, 1]\times[-1, 1]\) and \(\Omega_{i} := [-1/2, 1/2]\times[-1/4, 1/4]\).
\(\Omega_{o} := \Omega \setminus \Omega_{i}\) and \(\Omega_{i}\) represent two different materials.
The problem is:
with boundary conditions on \(\partial \Omega_{o}\):
with a discontinuity of \(\frac{\partial}{\partial x}u\) and \(\frac{\partial}{\partial y}u\) on \(\partial \Omega_{i}\):
Domains and samplers¶
We first define the geometric domains \(\Omega_{o}\) and \(\Omega_{i}\) as sub-domains of a scimba_jax domain representing \(\Omega\):
[41]:
bounds_x = [(-1.0, 1.0), (-1.0, 1.0)]
bounds_ix = [(-0.5, 0.5), (-0.25, 0.25)]
omega = Square2D(bounds_x, is_main_domain=True, label_str="outer")
omega_i = Square2D(bounds_ix, is_main_domain=False, label_str="inner")
omega.add_subdomain(omega_i)
The boundary elements of omega and omega_i are respectively called:
[42]:
_ = tuple(print(bc.get_label()) for bc in omega.full_bc_domain())
bc south
bc east
bc north
bc west
[43]:
_ = tuple(print(bc.get_label()) for bc in omega_i.full_bc_domain())
bc inner south
bc inner east
bc inner north
bc inner west
We group these boundary elements in a dictionary attached to omega:
[44]:
omega.set_boundaries_dict(
{
"bc SN": ["bc south", "bc north"],
"bc E": ["bc east"],
"bc W": ["bc west"],
"bc inner_SN": ["bc inner south", "bc inner north"],
"bc inner_WE": ["bc inner west", "bc inner east"],
}
)
and define a sampler:
[45]:
sampler = TensorizedSampler(
[DomainSampler(omega)], bc=True, model_type="x",
)
Let us try the sampler:
[46]:
key = jax.random.PRNGKey(0)
key, dict_of_samples = sampler.sample(key, 1000)
def plot_dict_of_samples_2d(
    dict_of_samples: dict[str, tuple[jnp.ndarray, ...]],
    fig=None, index=111
):
    if fig is None:
        fig = plt.figure()
    axe = fig.add_subplot(index)
    for label in dict_of_samples:
        x, y = dict_of_samples[label][0][:, 0], dict_of_samples[label][0][:, 1]
        axe.scatter(x, y, s=2, label=label)
    axe.legend()
    axe.axis('equal')
    return fig

#dict_of_samples = dict_of_samples | sampler.bc_sample(key, 5000)
plot_dict_of_samples_2d(dict_of_samples)
plt.show()
Notice that the samples labeled “outer” are points in \(\Omega_o\) and not in \(\Omega\).
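To make this concrete, points of \(\Omega_o\) can be obtained by sampling \(\Omega\) and discarding the points that fall in \(\Omega_i\). The following rejection-sampling sketch in plain JAX is an illustration only; `sample_outer` is a hypothetical function and this is not necessarily how the scimba_jax samplers proceed.

```python
import jax
import jax.numpy as jnp


def sample_outer(key, n):
    """Uniform points of [-1, 1]^2 lying outside [-1/2, 1/2] x [-1/4, 1/4]."""
    pts = jax.random.uniform(key, (n, 2), minval=-1.0, maxval=1.0)
    in_inner = (jnp.abs(pts[:, 0]) <= 0.5) & (jnp.abs(pts[:, 1]) <= 0.25)
    return pts[~in_inner]  # boolean masking: the number of kept points varies


key = jax.random.PRNGKey(0)
outer_pts = sample_outer(key, 1000)
```

Since the inner rectangle covers one eighth of the area of \(\Omega\), roughly seven eighths of the drawn points are kept on average.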
To know more about scimba_jax domains, follow the tutorial on domains and samplers.
Neural Network and Unknown function¶
The main pitfall that one might face when approximating a solution of this problem with a neural network is the discontinuity of the derivatives of the unknown \(u\) at the interface of \(\Omega_{o}\) and \(\Omega_{i}\).
To tackle it, we will approximate \(u\) with a function with 2-dimensional output, say
\[u_\theta: {\bf x} \mapsto (u^o_{\theta}({\bf x}), u^i_{\theta}({\bf x}))\]
where \(u^o_{\theta}\) approximates \(u\) on \(\Omega_o\) and \(u^i_{\theta}\) approximates \(u\) on \(\Omega_i\).
Let us create an approximation space:
[47]:
key = jax.random.PRNGKey(0)
nn = MLP(in_size=2, out_size=2, hidden_sizes=[16, 16, 16], key=key)
space = ApproximationSpace({"x": 2}, [(nn, "vec", 2)], model_type="x")
Residuals¶
The interior residual becomes a system of equations:
We define a dedicated residual class for this scaled Laplacian residual that will be instantiated two times, once for each equation:
[48]:
class BiScaledLaplacian(InteriorResidual):
    def __init__(
        self,
        scale: float,
        component: int,
        domain: VolumetricDomain,
        f_rhs: Callable | None = None,
    ):
        super().__init__(domain=domain, size=1, model_type="x", f_rhs=f_rhs)
        self.scale = scale
        self.component = component

    def construct_residual(self, *vars):
        rho = vars[0]
        assert isinstance(rho, ParamVecFunction)
        # select component
        rho_c = rho.component(self.component)
        # compute laplacian
        lap_c = rho_c.laplacian_x()
        # scale by factor
        return -self.scale * lap_c
Boundary conditions rewrite:
The first equation is a Neumann condition while the second and the third are Dirichlet conditions.
We group the remaining ones in two systems of two equations.
The Neumann condition:
[49]:
from scimba_jax.nonlinear_approximation.model_class.funcparam_field import ParamFieldFunction
class BiNeumannResidual(BoundaryResidual):
    def __init__(
        self,
        component: int,
        domain: dict[str, SurfacicDomain],
        f_rhs: Callable | None = None,
    ):
        super().__init__(
            domain=domain,
            size=1,
            model_type="x",
            f_rhs=f_rhs,
            time_domain=None,
        )
        self.component = component

    def construct_residual(self, *vars):
        rho = vars[0]
        assert isinstance(rho, ParamVecFunction)
        rho_index = rho.component(self.component)
        grad_x_rho = rho_index.gradient_x()  # grad_x_rho is a ParamFieldFunction
        n_func = ParamFieldFunction(
            rho.dims,
            lambda *args: args[rho.argnums["n"]],
            rho.f_type,
        )  # n_func is a ParamFieldFunction that takes value "n", i.e. the normals
        grad_n_rho = grad_x_rho.dot(n_func)  # dimensions match
        return grad_n_rho
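The quantity built by this residual is the normal derivative \(\nabla u \cdot n\). A minimal plain JAX sketch, with a hypothetical scalar field `u` and hand-written normals (no scimba_jax class involved):

```python
import jax
import jax.numpy as jnp


def u(x):
    # illustrative scalar field u(x, y) = x^2 + 3 y, x has shape (2,)
    return x[0] ** 2 + 3.0 * x[1]


def normal_derivative(fn):
    """Return the function (x, n) -> grad fn(x) . n."""
    grad_fn = jax.grad(fn)
    return lambda x, n: jnp.dot(grad_fn(x), n)


dn_u = jax.vmap(normal_derivative(u))  # batched over points and normals

# two points on the edge y = 1 of [-1, 1]^2, with outward normal (0, 1)
xs = jnp.array([[-0.5, 1.0], [0.5, 1.0]])
ns = jnp.array([[0.0, 1.0], [0.0, 1.0]])
values = dn_u(xs, ns)
```

For this field the gradient is \((2x, 3)\), so the normal derivative on that edge equals 3 at both points.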
The Dirichlet condition:
[50]:
class BiDirichletResidual(BoundaryResidual):
    def __init__(
        self,
        component: int,
        domain: dict[str, SurfacicDomain],
        f_rhs: Callable | None = None,
    ):
        super().__init__(
            domain=domain,
            size=1,
            model_type="x",
            f_rhs=f_rhs,
            time_domain=None,
        )
        self.component = component

    def construct_residual(self, *vars):
        rho = vars[0]
        assert isinstance(rho, ParamVecFunction)
        rho_index = rho.component(self.component)
        return rho_index
And the interface condition:
[51]:
class InterfaceResidual(BoundaryResidual):
    def __init__(
        self,
        factor: float,
        component: int,
        domain: dict[str, SurfacicDomain],
        f_rhs: Callable | None = None,
    ):
        super().__init__(domain=domain, size=2, model_type="x", f_rhs=f_rhs)
        self.factor = factor
        self.component = component

    def construct_residual(self, *vars):
        rho = vars[0]
        assert isinstance(rho, ParamVecFunction)
        # split outputs
        rho_i = rho.component(0)
        rho_o = rho.component(1)
        # compute partial derivative with respect to symbol
        rho_i_index = rho_i.partial_derivative_x(component_index=self.component)
        rho_o_index = rho_o.partial_derivative_x(component_index=self.component)
        # scale by factor
        equality = rho_i - rho_o
        discontinuity = rho_i_index - self.factor * rho_o_index
        return ParamVecFunction.cat([equality, discontinuity])
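The structure of this two-component residual (equality of the two outputs, and a jump condition on one partial derivative) can be sketched in plain JAX. The field `u_vec`, the constant `FACTOR` and the evaluation point are hypothetical, chosen so that both conditions hold; this does not use scimba_jax.

```python
import jax
import jax.numpy as jnp

FACTOR = 3.3  # illustrative value, same as e_r in the model of this section


def u_vec(x):
    # illustrative two-component field; components 0 and 1 play the roles
    # of rho_i and rho_o in the class above
    return jnp.stack((FACTOR * x[0], x[0]))


def interface_residual(x, axis):
    jac = jax.jacfwd(u_vec)(x)  # shape (2, 2): jac[c, k] = d u_c / d x_k
    values = u_vec(x)
    equality = values[0] - values[1]
    jump = jac[0, axis] - FACTOR * jac[1, axis]
    return jnp.stack((equality, jump))


# a point where both components coincide
x_on_interface = jnp.array([0.0, 0.1])
res = interface_residual(x_on_interface, 0)  # shape (2,)
```

Both entries of `res` vanish here: the two components agree at \(x = 0\), and the derivative of the first component is `FACTOR` times that of the second.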
Physical model¶
We group right-hand sides of residuals in dictionaries:
[52]:
f_rhs = {
"inner": lambda *args: 0.0,
"outer": lambda *args: 0.0,
}
f_bc_rhs = {
"bc SN": lambda *args: 0.0,
"bc E": lambda *args: -1.0,
"bc W": lambda *args: 1.0,
"bc inner_SN": lambda *args: 0.0,
"bc inner_WE": lambda *args: 0.0,
}
And define the physical model corresponding to our bi-material electrostatic problem:
[53]:
class BiMaterialElectrostatic(AbstractPhysicalModel):
    """BiMaterialElectrostatic model."""

    def __init__(self, main_domain, f_rhs, f_bc_rhs):
        super().__init__(main_domain=main_domain)
        self.e_r = 3.3
        # inner residuals
        self.physical_residuals = {
            "outer": BiScaledLaplacian(
                scale=1.0, component=0, domain=main_domain, f_rhs=f_rhs["outer"]
            ),
            "inner": BiScaledLaplacian(
                scale=5.0, component=1, domain=main_domain, f_rhs=f_rhs["inner"]
            ),
        }
        # boundary residuals
        self.physical_residuals["bc SN"] = BiNeumannResidual(
            component=0,
            domain=self.boundaries["bc SN"],
            f_rhs=f_bc_rhs["bc SN"],
        )
        self.physical_residuals["bc E"] = BiDirichletResidual(
            component=0,
            domain=self.boundaries["bc E"],
            f_rhs=f_bc_rhs["bc E"],
        )
        self.physical_residuals["bc W"] = BiDirichletResidual(
            component=0,
            domain=self.boundaries["bc W"],
            f_rhs=f_bc_rhs["bc W"],
        )
        self.physical_residuals["bc inner_SN"] = InterfaceResidual(
            factor=self.e_r,
            component=1,
            domain=self.boundaries["bc inner_SN"],
            f_rhs=f_bc_rhs["bc inner_SN"],
        )
        self.physical_residuals["bc inner_WE"] = InterfaceResidual(
            factor=self.e_r,
            component=0,
            domain=self.boundaries["bc inner_WE"],
            f_rhs=f_bc_rhs["bc inner_WE"],
        )
Let us instantiate it:
[54]:
model = BiMaterialElectrostatic(omega, f_rhs, f_bc_rhs)
Define and train a PINN for our physical model:¶
Finally, we define and train a PINN to approximate the solution of the considered problem.
[55]:
N_EPOCHS = 400
N_COLLOC = 2000
N_BC_COLLOC = 6000
pinn = Projector(model, space, sampler, matrix_regularization=5e-4)
key, pinn = pinn.project(
key, space, N_EPOCHS, N_COLLOC, N_BC_COLLOC
)
plot_abstract_approx_space(
pinn.space,
omega,
domain_mu,
loss=pinn.losses,
residual=pinn.model,
derivatives=["ux"],
components=[{"outer": 0, "inner": 1}], # to assemble components in a single plot
draw_contours=True,
n_drawn_contours=20,
title="Solving Bimaterial Electrostatic",
loss_groups=["bc"], # to group losses corresponding to boundary conditions
)
plt.show()
Training: 100%|||||||||||||||||| 400/400[00:33<00:00] , loss: 1.7e+01 -> 1.3e-02