Saving and loading PINNs

Scimba PINNs can be saved to and loaded from files, which makes it possible to:

  • save time and energy by not training the same PINN again;

  • use a PINN trained on another machine or device.

Before describing these utilities, let us define a function that creates a PINN approximating the solution of the Laplacian problem from the previous tutorial.

[1]:
import torch

import scimba_torch
from scimba_torch.approximation_space.nn_space import NNxSpace
from scimba_torch.domain.meshless_domain.domain_2d import Square2D
from scimba_torch.integration.monte_carlo import DomainSampler, TensorizedSampler
from scimba_torch.integration.monte_carlo_parameters import UniformParametricSampler
from scimba_torch.numerical_solvers.elliptic_pde.pinns import NaturalGradientPinnsElliptic
from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
from scimba_torch.physical_models.elliptic_pde.laplacians import Laplacian2DDirichletStrongForm
from scimba_torch.utils.scimba_tensors import LabelTensor

torch.manual_seed(0)

x_bounds = [(0.0, 1), (0.0, 1)]
mu_bounds = [(1.0, 2.0)]



def f_rhs(xs: LabelTensor, ms: LabelTensor) -> torch.Tensor:
    x, y = xs.get_components()
    mu = ms.get_components()
    pi = torch.pi
    return 2 * (2.0 * mu * pi) ** 2 * torch.sin(2.0 * pi * x) * torch.sin(2.0 * pi * y)

def f_bc(xs: LabelTensor, ms: LabelTensor) -> torch.Tensor:
    x, _ = xs.get_components()
    return x * 0.0

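def my_pinn(layer_sizes=[64], verbose=False):
    # spatial domain (unit square) and parameter domain for mu
    domain_x = Square2D(x_bounds, is_main_domain=True)
    domain_mu = mu_bounds

    # sample collocation points jointly in space and in the parameter mu
    sampler = TensorizedSampler([DomainSampler(domain_x), UniformParametricSampler(domain_mu)])

    # approximation space: an MLP with 1 unknown (u) and 1 parameter (mu);
    # layer_sizes gives the hidden layer widths and must match when reloading
    space = NNxSpace(1, 1, GenericMLP, domain_x, sampler, layer_sizes=layer_sizes)

    # Laplacian problem with right-hand side f_rhs and Dirichlet data f_bc
    pde = Laplacian2DDirichletStrongForm(space, f=f_rhs, g=f_bc)

    # natural-gradient PINN, boundary condition imposed weakly
    pinn = NaturalGradientPinnsElliptic(pde, bc_type="weak")

    return pinn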

We now construct and train a PINN approximating the solution of our Laplacian problem:

[2]:
pinn = my_pinn()
pinn.solve(epochs=200, n_collocation=1000, n_bc_collocation=1000, verbose=False)
print("loss after optimization: ", pinn.best_loss)
loss after optimization:  0.0024674155487880212

save and load methods

This training was done on a GPU; let us save the trained PINN with its save method:

[3]:
scimba_torch.print_torch_setting()

pinn.save("tutorial")
torch device: cuda:0
torch floating point format: torch.float64
cuda devices:        1
cuda current device: 0
cuda device name:    Tesla V100-PCIE-16GB

This creates the file tutorial.pt in the directory ~/.scimba/scimba_torch (where ~ is the home directory of the current user). We will see below how to change this location.
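To check which PINNs have already been saved, one can look inside this directory with the standard pathlib module. This is a minimal sketch assuming the default location described above; the constant SCIMBA_DIR and the helper list_saved_pinns are hypothetical names, not part of Scimba.

```python
from pathlib import Path

# Default save directory (assumed from the location described above)
SCIMBA_DIR = Path.home() / ".scimba" / "scimba_torch"


def list_saved_pinns(directory: Path = SCIMBA_DIR) -> list:
    """Hypothetical helper: names (without the .pt extension) of the files saved in `directory`."""
    if not directory.is_dir():
        return []  # nothing has been saved there yet
    return sorted(p.stem for p in directory.glob("*.pt"))


# Once the cell above has run, this list should contain "tutorial"
print(list_saved_pinns())
```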

Let us first show how to load this PINN, possibly on another device or machine. Here we switch to the CPU and single precision:

[4]:
torch.set_default_device("cpu")
torch.set_default_dtype(torch.float32)

scimba_torch.print_torch_setting()
torch device: cpu
torch floating point format: torch.float32
cuda devices:        1
cuda current device: 0
cuda device name:    Tesla V100-PCIE-16GB

We first create a new PINN, then load it from the file just created, and check that the best loss is preserved.

[5]:
pinn2 = my_pinn()
print("loss before load: ", pinn2.best_loss)

pinn2.load("tutorial")
print("loss after load: ", pinn2.best_loss)
loss before load:  inf
loss after load:  0.0024674155487880212

It is important to load the file into a PINN that has the same network architecture as the saved one; here the hidden layers differ:

[6]:
pinn3 = my_pinn(layer_sizes=[32,32])
pinn3.load("tutorial")
/home/u2/imbach/new_scimba/src/scimba_torch/numerical_solvers/abstract_projector.py:746: RuntimeWarning: loading state dict in file /home/u2/imbach/.scimba/scimba_torch/tutorial.pt: something went wrong; maybe the nn has not the same size
  warn(
[6]:
False

Note the return value False, which is also returned when the input file does not exist:

[7]:
pinn3.load("not_existing_file")
/home/u2/imbach/new_scimba/src/scimba_torch/numerical_solvers/abstract_projector.py:728: RuntimeWarning: trying to load state dict in file /home/u2/imbach/.scimba/scimba_torch/not_existing_file.pt: file does not exists; do nothing
  warn(
[7]:
False
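Since load returns False both when the file is missing and when the architecture does not match, its return value can drive a simple caching pattern: train only when no usable saved state exists. The sketch below is a hypothetical helper, not part of Scimba. It only relies on the load, solve and save methods and the best_loss attribute used in this tutorial, and a small stand-in class (also hypothetical) replaces a real PINN so that the snippet runs on its own.

```python
def load_or_train(pinn, name: str, **solve_kwargs):
    """Hypothetical helper: reuse a saved PINN if possible, otherwise train and save it."""
    if pinn.load(name):  # True only if the file exists and could be loaded
        return pinn
    pinn.solve(**solve_kwargs)  # train from scratch
    pinn.save(name)             # cache the result for the next run
    return pinn


class FakePinn:
    """Stand-in mimicking the load/solve/save interface shown in this tutorial."""

    def __init__(self, store):
        self.store = store              # dict playing the role of the save directory
        self.best_loss = float("inf")   # same initial value as an untrained PINN

    def load(self, name):
        if name not in self.store:
            return False
        self.best_loss = self.store[name]
        return True

    def solve(self, epochs, **kwargs):
        self.best_loss = 1.0 / epochs   # dummy "training"

    def save(self, name):
        self.store[name] = self.best_loss


store = {}
first = load_or_train(FakePinn(store), "tutorial", epochs=200)   # trains, then saves
second = load_or_train(FakePinn(store), "tutorial", epochs=200)  # loads the saved state
print(first.best_loss, second.best_loss)  # 0.005 0.005
```

With the PINN of this tutorial, the call would look like load_or_train(my_pinn(), "tutorial", epochs=200, n_collocation=1000, n_bc_collocation=1000, verbose=False).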

One can now, for instance, plot the approximation given by the loaded PINN:

[8]:
import matplotlib.pyplot as plt
from scimba_torch.plots.plots_nd import plot_abstract_approx_spaces

plot_abstract_approx_spaces(
    pinn2.space,  # the approximation space
    Square2D(x_bounds, is_main_domain=True),  # the spatial domain
    mu_bounds,  # the parameter's domain
    loss=pinn2.losses,  # optional plot of the loss: the losses
    draw_contours=True,  # plotting isolevel lines
    n_drawn_contours=20,  # number of isolevel lines,
    title=(
        "Approximating the solution of "
        r"$- \mu\Delta u = 2(2\mu\pi)^2 \sin(2\pi x) \sin(2\pi y)$"
        " with Scimba Pinns"
    ),
)
plt.show()
[Figure: approximation computed by the loaded PINN pinn2, with isolevel lines and the loss plot]

One can also continue the training of a PINN; we first switch back to our GPU and double precision.

[9]:
torch.set_default_dtype(torch.float64)
torch.set_default_device("cuda")

scimba_torch.print_torch_setting()

pinn2 = my_pinn()
pinn2.load("tutorial")  # restore the state saved earlier, so that training continues from it
pinn2.solve(epochs=100, n_collocation=1000, n_bc_collocation=1000, verbose=False)

plot_abstract_approx_spaces(
    pinn2.space,  # the approximation space
    Square2D(x_bounds, is_main_domain=True),  # the spatial domain
    mu_bounds,  # the parameter's domain
    loss=pinn2.losses,  # optional plot of the loss: the losses
    draw_contours=True,  # plotting isolevel lines
    n_drawn_contours=20,  # number of isolevel lines,
    title=(
        "Approximating the solution of "
        r"$- \mu\Delta u = 2(2\mu\pi)^2 \sin(2\pi x) \sin(2\pi y)$"
        " with Scimba Pinns"
    ),
)
plt.show()
torch device: cuda:0
torch floating point format: torch.float64
cuda devices:        1
cuda current device: 0
cuda device name:    Tesla V100-PCIE-16GB
[Figure: approximation computed by pinn2 after 100 training epochs on the GPU, with isolevel lines and the loss plot]
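Continuing the training amounts to loading a saved state, calling solve again and saving the result; saving under a suffix (the second argument of save, described in the next section) keeps the original file intact. The helper resume_training and the recording stand-in below are hypothetical, for illustration only, and the sketch assumes that solve starts from the current network weights rather than reinitializing them.

```python
def resume_training(pinn, name: str, suffix: str = "retrained", **solve_kwargs):
    """Hypothetical helper: continue training a saved PINN and save it under a suffix."""
    if not pinn.load(name):
        raise FileNotFoundError(f"could not load a PINN saved as {name!r}")
    pinn.solve(**solve_kwargs)  # further epochs (assumed to start from the loaded weights)
    pinn.save(name, suffix)     # e.g. tutorial_retrained.pt, leaving tutorial.pt untouched
    return pinn


class RecordingPinn:
    """Stand-in that only records the calls made on it."""

    def __init__(self, saved_names):
        self.saved_names = set(saved_names)
        self.calls = []

    def load(self, name):
        self.calls.append(("load", name))
        return name in self.saved_names

    def solve(self, **kwargs):
        self.calls.append(("solve", kwargs.get("epochs")))

    def save(self, name, suffix=""):
        self.calls.append(("save", name, suffix))


demo = resume_training(RecordingPinn({"tutorial"}), "tutorial", epochs=100)
print(demo.calls)  # [('load', 'tutorial'), ('solve', 100), ('save', 'tutorial', 'retrained')]
```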

save and load location

When Scimba is verbose, the save and load methods print the file they use:

[10]:
scimba_torch.set_verbosity(True)

pinn2.save("tutorial")

pinn2.load("tutorial")

/////////////// Scimba 1.0.0 ////////////////
torch device: cuda:0
torch floating point format: torch.float64
cuda devices:        1
cuda current device: 0
cuda device name:    Tesla V100-PCIE-16GB


>> saving state dict in file /home/u2/imbach/.scimba/scimba_torch/tutorial.pt
>> loading state dict in file /home/u2/imbach/.scimba/scimba_torch/tutorial.pt
[10]:
True

By default, PINNs are saved to and loaded from the file ~/.scimba/scimba_torch/YOUR_NAME.pt, where YOUR_NAME is the name passed as the first argument of the save and load methods.

YOUR_NAME can also be the name of a script, for instance "tutorial.py", or the Python variable __file__; in this case the extension is ignored.

[11]:
pinn2.save("tutorial.py")
pinn2.load("tutorial")
>> saving state dict in file /home/u2/imbach/.scimba/scimba_torch/tutorial.pt
>> loading state dict in file /home/u2/imbach/.scimba/scimba_torch/tutorial.pt
[11]:
True

One can also add a suffix to the file name, passed as the second argument and appended after an underscore:

[12]:
pinn2.save("tutorial", "retrained")
pinn2.load("tutorial", "retrained")
>> saving state dict in file /home/u2/imbach/.scimba/scimba_torch/tutorial_retrained.pt
>> loading state dict in file /home/u2/imbach/.scimba/scimba_torch/tutorial_retrained.pt
[12]:
True

and change the base directory, which replaces the home directory ~ in the default location:

[13]:
pinn2.save("tutorial", "retrained", path="~/saved_PINNS")
pinn2.load("tutorial", "retrained", path="~/saved_PINNS")
>> saving state dict in file ~/saved_PINNS/.scimba/scimba_torch/tutorial_retrained.pt
>> loading state dict in file ~/saved_PINNS/.scimba/scimba_torch/tutorial_retrained.pt
[13]:
True

or set the whole destination directory, using folder_name to replace the default .scimba/scimba_torch subdirectory:

[14]:
pinn2.save("tutorial", "retrained", path="~", folder_name="scimba_PINNs")
pinn2.load("tutorial", "retrained", path="~", folder_name="scimba_PINNs")
>> saving state dict in file ~/scimba_PINNs/tutorial_retrained.pt
>> loading state dict in file ~/scimba_PINNs/tutorial_retrained.pt
[14]:
True
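The locations printed above follow a simple pattern: the extension of the name is dropped, the suffix is appended after an underscore, path replaces the home directory, and folder_name replaces the default .scimba/scimba_torch subdirectory. The function below is a hypothetical reconstruction of these rules from the logged paths only. It is not Scimba's implementation and may differ in corner cases; in particular, keeping only the base name of a full path such as __file__ is an assumption.

```python
from pathlib import Path
from typing import Optional


def scimba_file_path(name: str, suffix: str = "", path: Optional[str] = None,
                     folder_name: str = ".scimba/scimba_torch") -> str:
    """Hypothetical reconstruction of the file paths printed by save/load above."""
    stem = Path(name).stem           # "tutorial.py" -> "tutorial": extension ignored
    if suffix:
        stem = f"{stem}_{suffix}"    # suffix appended after an underscore
    base = Path.home() if path is None else Path(path)  # default: home directory
    return str(base / folder_name / f"{stem}.pt")


print(scimba_file_path("tutorial.py"))
print(scimba_file_path("tutorial", "retrained", path="~/saved_PINNS"))
print(scimba_file_path("tutorial", "retrained", path="~", folder_name="scimba_PINNs"))
```

The last two calls reproduce the paths logged for cells [13] and [14].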