Source code for scimba_torch.optimizers.line_search

r"""Line search functions.

Examples:
**Examples of the three line search functions**

    .. code-block:: python

        import torch

        from scimba_torch.approximation_space.nn_space import NNxSpace
        from scimba_torch.domain.meshless_domain.domain_2d import Disk2D, Square2D
        from scimba_torch.integration.monte_carlo import (
            DomainSampler,
            TensorizedSampler,
        )
        from scimba_torch.integration.monte_carlo_parameters import (
            UniformParametricSampler,
        )
        from scimba_torch.neural_nets.coordinates_based_nets.mlp import GenericMLP
        from scimba_torch.numerical_solvers.elliptic_pde.deep_ritz import (
            DeepRitzElliptic,
        )
        from scimba_torch.numerical_solvers.elliptic_pde.pinns import PinnsElliptic
        from scimba_torch.optimizers.line_search import (
            backtracking_armijo_line_search,
            backtracking_armijo_line_search_with_loss_theta_grad_loss_theta,
            logarithmic_grid_line_search,
        )
        from scimba_torch.optimizers.losses import GenericLosses
        from scimba_torch.optimizers.optimizers_data import OptimizerData
        from scimba_torch.physical_models.elliptic_pde.laplacians import (
            Laplacian2DDirichletRitzForm,
            Laplacian2DDirichletStrongForm,
        )
        from scimba_torch.utils.scimba_tensors import LabelTensor

        print(" ######################################################## ")
        print(" # line_search with a pinn with weak boundary condition # ")
        print(" ######################################################## ")

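        # source term f(x, mu) = 8 * pi^2 * mu * sin(2*pi*x1) * sin(2*pi*x2)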
        def f_rhs(x: LabelTensor, mu: LabelTensor) -> torch.Tensor:
            x1, x2 = x.get_components()
            mu1 = mu.get_components()
            return (
                mu1
                * 8.0
                * torch.pi
                * torch.pi
                * torch.sin(2.0 * torch.pi * x1)
                * torch.sin(2.0 * torch.pi * x2)
            )

        def f_bc(x: LabelTensor, mu: LabelTensor) -> torch.Tensor:
            x1, _ = x.get_components()
            return x1 * 0.0

        domain_x = Square2D([(0.0, 1.0), (0.0, 1.0)], is_main_domain=True)
        sampler = TensorizedSampler(
            [DomainSampler(domain_x), UniformParametricSampler([(1.0, 2.0)])]
        )
        space = NNxSpace(
            1,
            1,
            GenericMLP,
            domain_x,
            sampler,
            layer_sizes=[60] * 3,
        )
        pde = Laplacian2DDirichletStrongForm(space, f=f_rhs, g=f_bc)
        losses = GenericLosses(
            [
                ("residual", torch.nn.MSELoss(), 1.0),
                ("bc", torch.nn.MSELoss(), 40.0),
            ],
        )
        opt_1 = {
            "name": "adam",
            "optimizer_args": {"lr": 2.5e-2, "betas": (0.9, 0.999)},
        }
        opt = OptimizerData(opt_1)
        pinns = PinnsElliptic(pde, bc_type="weak", optimizers=opt, losses=losses)
        n_collocation = 2000
        n_bc_collocation = 1500
        # get current parameters of the nn
        params_vect = pinns.space.get_dof(flag_scope="all", flag_format="tensor")
        # get func and derivative
        Lpinn, GradLpinn = pinns.get_loss_grad_loss(
            n_collocation=n_collocation, n_bc_collocation=n_bc_collocation
        )
        loss = Lpinn(params_vect)
        print("loss at theta = initial parameters: ", loss)
        theta = params_vect.clone().detach().requires_grad_(False)
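        # Lpinn and GradLpinn evaluate on a fixed sample of collocation points,
        # so repeated evaluations at the same theta are reproducible: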
        loss = Lpinn(theta)
        gradltheta = GradLpinn(theta)
        loss2 = Lpinn(theta)
        gradltheta2 = GradLpinn(theta)
        assert torch.equal(loss, loss2)
        assert torch.equal(gradltheta, gradltheta2)
        # perform a line search along gradltheta
        print("Lpinn(theta): ", Lpinn(theta))
        eta = backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(
            Lpinn,
            theta,
            loss,
            gradltheta,
            gradltheta,
            alpha=0.2,
            beta=0.5,
            n_step_max=1000,
        )
        print("eta with Armijo : ", eta)
        print("Lpinn(theta): ", Lpinn(theta))
        print("Lpinn(theta - eta * dsearch): ", Lpinn(theta - eta * gradltheta))
        assert torch.all(Lpinn(theta) > Lpinn(theta - eta * gradltheta))
        print("\n")

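        # alternative: pick eta from a logarithmic grid of m trial steps over interval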
        eta = logarithmic_grid_line_search(
            Lpinn, theta, gradltheta, m=10, interval=[0.0, 1.0]
        )
        print("eta with logarithmic grid : ", eta)
        print("Lpinn(theta): ", Lpinn(theta))
        print("Lpinn(theta - eta * dsearch): ", Lpinn(theta - eta * gradltheta))
        assert torch.all(Lpinn(theta) > Lpinn(theta - eta * gradltheta))
        print("\n")

        # get func and derivative with new sampling points
        loss = Lpinn(theta)
        Lpinn, GradLpinn = pinns.get_loss_grad_loss(
            n_collocation=n_collocation, n_bc_collocation=n_bc_collocation
        )
        assert not torch.equal(Lpinn(theta), loss)
        # update theta
        theta = theta - eta * gradltheta
        loss = Lpinn(theta)
        gradltheta = GradLpinn(theta)
        # perform a line search along gradltheta
        print("Lpinn(theta): ", Lpinn(theta))
        eta = backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(
            Lpinn,
            theta,
            loss,
            gradltheta,
            gradltheta,
            alpha=0.2,
            beta=0.5,
            n_step_max=1000,
        )
        print("eta with Armijo : ", eta)
        print("Lpinn(theta): ", Lpinn(theta))
        print("Lpinn(theta - eta * dsearch): ", Lpinn(theta - eta * gradltheta))
        assert torch.all(Lpinn(theta) > Lpinn(theta - eta * gradltheta))
        print("\n")
        # get func and derivative with new sampling points
        loss = Lpinn(theta)
        Lpinn, GradLpinn = pinns.get_loss_grad_loss(
            n_collocation=n_collocation, n_bc_collocation=n_bc_collocation
        )
        assert not torch.equal(Lpinn(theta), loss)
        # update theta
        theta = theta - eta * gradltheta
        loss = Lpinn(theta)
        gradltheta = GradLpinn(theta)
        # perform a line search along gradltheta
        print("Lpinn(theta): ", Lpinn(theta))
        eta = backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(
            Lpinn,
            theta,
            loss,
            gradltheta,
            gradltheta,
            alpha=0.2,
            beta=0.5,
            n_step_max=1000,
        )
        print("eta with Armijo : ", eta)
        print("Lpinn(theta): ", Lpinn(theta))
        print("Lpinn(theta - eta * dsearch): ", Lpinn(theta - eta * gradltheta))
        assert torch.all(Lpinn(theta) > Lpinn(theta - eta * gradltheta))
        print("\n")

        print(" ############################################################ ")
        print(" # line_search with a deep_ritz with weak boundary condition # ")
        print(" ############################################################ ")

        domain_x = Disk2D(torch.tensor([0.0, 0.0]), radius=1.0, is_main_domain=True)
        sampler = TensorizedSampler(
            [DomainSampler(domain_x), UniformParametricSampler([(1.0, 1.0001)])]
        )
        space = NNxSpace(
            1,
            1,
            GenericMLP,
            domain_x,
            sampler,
            layer_sizes=[30] * 3,
        )
        pde = Laplacian2DDirichletRitzForm(space, f=f_rhs, g=f_bc)
        losses = GenericLosses(
            [
                ("residual", torch.nn.MSELoss(), 1.0),
                ("bc", torch.nn.MSELoss(), 40.0),
            ],
        )
        opt_1 = {
            "name": "adam",
            "optimizer_args": {"lr": 2.5e-2, "betas": (0.9, 0.999)},
        }
        opt = OptimizerData(opt_1)
        ritz = DeepRitzElliptic(pde, bc_type="weak", optimizers=opt, losses=losses)
        n_collocation = 2000
        n_bc_collocation = 1500

        # get current parameters of the nn
        params_vect = ritz.space.get_dof(flag_scope="all", flag_format="tensor")
        # get func and derivative
        Lritz, GradLritz = ritz.get_loss_grad_loss(
            n_collocation=n_collocation, n_bc_collocation=n_bc_collocation
        )
        loss = Lritz(params_vect)
        print("loss at theta = initial parameters: ", loss)
        theta = params_vect.clone().detach().requires_grad_(False)
        loss = Lritz(theta)
        gradltheta = GradLritz(theta)
        loss2 = Lritz(theta)
        gradltheta2 = GradLritz(theta)
        assert torch.equal(loss, loss2)
        assert torch.equal(gradltheta, gradltheta2)
        # perform a line search along gradltheta
        print("Lritz(theta): ", Lritz(theta))
        eta = backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(
            Lritz,
            theta,
            loss,
            gradltheta,
            gradltheta,
            alpha=0.2,
            beta=0.5,
            n_step_max=1000,
        )
        print("eta with Armijo : ", eta)
        print("Lritz(theta): ", Lritz(theta))
        print("Lritz(theta - eta * dsearch): ", Lritz(theta - eta * gradltheta))
        assert torch.all(Lritz(theta) > Lritz(theta - eta * gradltheta))
        print("\n")

        eta = logarithmic_grid_line_search(
            Lritz, theta, gradltheta, m=10, interval=[0.0, 1.0]
        )
        print("eta with logarithmic grid : ", eta)
        print("Lritz(theta): ", Lritz(theta))
        print("Lritz(theta - eta * dsearch): ", Lritz(theta - eta * gradltheta))
        # assert torch.all(Lritz(theta) > Lritz(theta - eta * gradltheta))
        print("\n")

        # get func and derivative with new sampling points
        loss = Lritz(theta)
        Lritz, GradLritz = ritz.get_loss_grad_loss(
            n_collocation=n_collocation, n_bc_collocation=n_bc_collocation
        )
        assert not torch.equal(Lritz(theta), loss)
        # update theta
        theta = theta - eta * gradltheta
        loss = Lritz(theta)
        gradltheta = GradLritz(theta)
        # perform a line search along gradltheta
        print("Lritz(theta): ", Lritz(theta))
        eta = backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(
            Lritz,
            theta,
            loss,
            gradltheta,
            gradltheta,
            alpha=0.2,
            beta=0.5,
            n_step_max=1000,
        )
        print("eta with Armijo : ", eta)
        print("Lritz(theta): ", Lritz(theta))
        print("Lritz(theta - eta * dsearch): ", Lritz(theta - eta * gradltheta))
        assert torch.all(Lritz(theta) > Lritz(theta - eta * gradltheta))
        print("\n")

        eta = logarithmic_grid_line_search(
            Lritz, theta, gradltheta, m=10, interval=[0.0, 1.0]
        )
        print("eta with logarithmic grid : ", eta)
        print("Lritz(theta): ", Lritz(theta))
        print("Lritz(theta - eta * dsearch): ", Lritz(theta - eta * gradltheta))
        # assert torch.all(Lritz(theta) > Lritz(theta - eta * gradltheta))
        print("\n")

        # get func and derivative with new sampling points
        loss = Lritz(theta)
        Lritz, GradLritz = ritz.get_loss_grad_loss(
            n_collocation=n_collocation, n_bc_collocation=n_bc_collocation
        )
        assert not torch.equal(Lritz(theta), loss)
        # update theta
        theta = theta - eta * gradltheta
        loss = Lritz(theta)
        gradltheta = GradLritz(theta)
        # perform a line search along gradltheta
        print("Lritz(theta): ", Lritz(theta))
        eta = backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(
            Lritz,
            theta,
            loss,
            gradltheta,
            gradltheta,
            alpha=0.2,
            beta=0.5,
            n_step_max=1000,
        )
        print("eta with Armijo : ", eta)
        print("Lritz(theta): ", Lritz(theta))
        print("Lritz(theta - eta * dsearch): ", Lritz(theta - eta * gradltheta))
        assert torch.all(Lritz(theta) > Lritz(theta - eta * gradltheta))
        print("\n")

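**A minimal standalone sketch**

The examples above exercise the line search inside full PINN and Deep Ritz
solvers. The same Armijo backtracking search (defined below in this module)
can also be tried on a plain quadratic loss, using only ``torch`` and an
analytic gradient; the values here are purely illustrative:

    .. code-block:: python

        import torch

        from scimba_torch.optimizers.line_search import (
            backtracking_armijo_line_search_with_loss_theta_grad_loss_theta,
        )

        def quadratic_loss(theta: torch.Tensor) -> torch.Tensor:
            # L(theta) = 0.5 * ||theta||^2, whose gradient is theta itself
            return 0.5 * torch.dot(theta, theta)

        theta = torch.tensor([3.0, -4.0])
        loss_theta = quadratic_loss(theta)
        grad_theta = theta.clone()  # analytic gradient of the quadratic loss

        # search along the gradient direction (passed twice: gradient and dsearch)
        eta = backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(
            quadratic_loss,
            theta,
            loss_theta,
            grad_theta,
            grad_theta,
            alpha=0.2,
            beta=0.5,
            n_step_max=100,
        )
        assert quadratic_loss(theta - eta * grad_theta) < loss_theta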
"""

from typing import Callable

import torch


def backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(
    loss: Callable[[torch.Tensor], torch.Tensor],
    theta: torch.Tensor,  # shape (p,)
    loss_theta: torch.Tensor,  # shape (1,)
    grad_loss_theta: torch.Tensor,  # shape (p,)
    dsearch: torch.Tensor,  # shape (p,)
    alpha: float = 0.01,
    beta: float = 0.5,
    n_step_max: int = 10,
    **kwargs,
) -> torch.Tensor:
    """Backtracking line search based on the Armijo condition.

    Starting from ``eta = 1``, the step is shrunk by the factor ``beta`` until
    the sufficient-decrease condition
    ``loss(theta - eta * dsearch) <= loss_theta - alpha * eta * dL``, with
    ``dL = <grad_loss_theta, dsearch>``, holds or ``n_step_max`` backtracking
    steps have been performed.

    Args:
        loss: The loss function.
        theta: The current parameters of the loss.
        loss_theta: The loss at theta.
        grad_loss_theta: The gradient of the loss at theta.
        dsearch: The search direction.
        alpha: The sufficient-decrease parameter of the Armijo condition.
        beta: The backtracking factor used to shrink the step size.
        n_step_max: The maximum number of steps in the backtracking algorithm.
        **kwargs: Arbitrary keyword arguments.

    Returns:
        A step size eta satisfying the Armijo condition along the search
        direction from theta, or the last trial step if n_step_max is reached.
    """
    eta = torch.tensor(1.0)
    # directional derivative of the loss at theta along dsearch
    dL = torch.dot(grad_loss_theta, dsearch)
    nbsteps = 0
    while (loss(theta - eta * dsearch) > loss_theta - alpha * eta * dL) and (
        nbsteps < n_step_max
    ):
        eta *= beta
        nbsteps += 1
    return eta