Source code for scimba_torch.optimizers.ssbroyden

"""An implementation of Self Scaled Broyden optimizer.

_cubic_interpolate and _strong_wolfe have been copied pasted from
torch v2.9.1, in torch.optim.lbfgs.py
"""

import math
from typing import Callable, Union

import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, ParamsT


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.0


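# Illustrative sanity check of _cubic_interpolate (a hypothetical demo helper,
# not part of the optimizer): for the quadratic f(x) = x**2 sampled at x = -1
# and x = 2, the fitted cubic degenerates to the quadratic, so the
# interpolated minimizer is recovered exactly at x = 0.
def _demo_cubic_interpolate():
    x1, x2 = torch.tensor(-1.0), torch.tensor(2.0)
    f1, g1 = x1**2, 2 * x1  # f(-1) = 1, f'(-1) = -2
    f2, g2 = x2**2, 2 * x2  # f(2) = 4,  f'(2) = 4
    t = _cubic_interpolate(x1, f1, g1, x2, f2, g2)
    assert torch.isclose(t, torch.tensor(0.0))
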
def _strong_wolfe(
    obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
        )

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)  # type: ignore[possibly-undefined]
    while not done and ls_iter < max_ls:
        # line-search bracket is so small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:  # type: ignore[possibly-undefined]
            break

        # compute new trial value
        t = _cubic_interpolate(
            bracket[0],
            bracket_f[0],
            bracket_gtd[0],  # type: ignore[possibly-undefined]
            bracket[1],
            bracket_f[1],
            bracket_gtd[1],
        )

        # test that we are making sufficient progress:
        # in case `t` is so close to boundary, we mark that we are making
        # insufficient progress, and if
        #   + we have made insufficient progress in the last step, or
        #   + `t` is at one of the boundary,
        # we will move `t` to a position which is `0.1 * len(bracket)`
        # away from the nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]  # type: ignore[possibly-undefined]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)  # type: ignore[possibly-undefined]
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    t = bracket[low_pos]  # type: ignore[possibly-undefined]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]  # type: ignore[possibly-undefined]

    if not isinstance(t, torch.Tensor):
        t = torch.tensor(float(t))
    return f_new, g_new, t, ls_func_evals

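# Illustrative use of _strong_wolfe (a hypothetical demo helper, not part of
# the optimizer): minimize f(w) = (w - 3)**2 / 2 starting from w = 0 along the
# steepest descent direction. obj_func must return the loss and the flat
# gradient evaluated at x + t * d.
def _demo_strong_wolfe():
    x = torch.tensor([0.0])

    def obj_func(x, t, d):
        w = x + t * d
        return float(0.5 * (w - 3.0) ** 2), w - 3.0

    f, g = obj_func(x, 0.0, torch.zeros(1))
    d = -g  # steepest descent direction
    f_new, g_new, t, n_evals = _strong_wolfe(obj_func, x, 1.0, d, f, g, g.dot(d))
    # With a unit initial step, t = 1 lands exactly on the minimizer w = 3.
    assert torch.isclose(t, torch.tensor(1.0)) and f_new == 0.0
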

class SSBroyden(Optimizer):
    """Implements the SSBroyden algorithm.

    Implementation of Urbán, J. F., Stefanou, P., & Pons, J. A. (2025).
    Unveiling the optimization process of physics informed neural networks:
    How accurate and competitive can PINNs be? Journal of Computational
    Physics, 523, 113656.

    Args:
        params: iterable of parameters to optimize. Parameters must be real.
        lr: learning rate (default: 1)
        tolerance_grad: does not update if the max norm of the gradient is
            smaller than this.
        method: either "ssbroyden" or "ssbfgs"

    Raises:
        ValueError: if lr is not a scalar, lr <= 0, tolerance_grad <= 0,
            per-parameter options (parameter groups) are used, or method is
            not in ["ssbfgs", "ssbroyden"].
    """

    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1.0,
        tolerance_grad: float = 1e-10,
        method: str = "ssbfgs",
    ):
        if isinstance(lr, float):
            lr = torch.tensor(lr)
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 < lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 < tolerance_grad:
            raise ValueError(f"Invalid tolerance on gradient: {tolerance_grad}")

        defaults = {"lr": lr, "tolerance_grad": tolerance_grad, "method": method}
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError(
                "SS Broyden/BFGS doesn't support per-parameter options"
                " (parameter groups)"
            )
        if method not in ["ssbfgs", "ssbroyden"]:
            raise ValueError("method should be either ssbroyden or ssbfgs")

        self._params = self.param_groups[0]["params"]
        self._numel_cache = None

        # number of parameters:
        nbparams = self._numel()
        state = self.state[self._params[0]]
        state["k"] = 0
        state["Hk"] = torch.eye(nbparams, dtype=torch.get_default_dtype())

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = sum(
                2 * p.numel() if torch.is_complex(p) else p.numel()
                for p in self._params
            )
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            if torch.is_complex(view):
                view = torch.view_as_real(view).view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            if torch.is_complex(p):
                p = torch.view_as_real(p)
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _flatten(self, params):
        views = []
        for p in params:
            if p.is_sparse:
                view = p.to_dense().view(-1)
            else:
                view = p.view(-1)
            if torch.is_complex(view):
                view = torch.view_as_real(view).view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad
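    # The inverse-Hessian update implemented in step() follows Urbán et al.
    # (2025):
    #
    #   H_{k+1} = (1/tau_k) * (H_k - H_k y_k y_k^T H_k / (y_k^T H_k y_k)
    #             + phi_k * v_k v_k^T) + s_k s_k^T / (y_k^T s_k)
    #
    # with v_k = sqrt(y_k^T H_k y_k) * (s_k / (y_k^T s_k) - H_k y_k / (y_k^T H_k y_k)),
    # where tau_k is the self-scaling factor and phi_k selects the member of
    # the Broyden family (phi_k = 1 gives a self-scaled BFGS update).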
    @torch.no_grad()
    def step(self, closure: Callable) -> torch.Tensor:  # type: ignore[override]
        """Perform a single optimization step.

        Args:
            closure: A closure that reevaluates the model and returns the loss.

        Returns:
            the initial loss
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = float(group["lr"])
        tolerance_grad = group["tolerance_grad"]
        method = group["method"]

        # NOTE: SSBroyden has only global state, but we register it as state
        # for the first param, because this helps with casting in
        # load_state_dict
        state = self.state[self._params[0]]

        x_init = self._clone_param()
        theta_k = self._flatten(x_init)

        def obj_func(x, t, d):
            return self._directional_evaluate(closure, x, t, d)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        grad_k = self._gather_flat_grad()
        opt_cond = grad_k.abs().max() <= tolerance_grad
        if opt_cond:
            return orig_loss

        # descent direction
        prec_grad = state["Hk"] @ grad_k
        prec_grad = prec_grad.neg()

        # directional derivative
        gtd = grad_k.dot(prec_grad)

        # step size from a strong Wolfe line search
        loss, grad_kp1, alpha_k, ls_func_evals = _strong_wolfe(
            obj_func, x_init, lr, prec_grad, loss, grad_k, gtd
        )

        # check that there are no nans or infs
        if (
            math.isnan(loss)
            or math.isinf(loss)
            or torch.isnan(alpha_k)
            or torch.isinf(alpha_k)
            or torch.any(torch.isnan(grad_kp1))
            or torch.any(torch.isinf(grad_kp1))
        ):
            orig_loss = closure()
            return orig_loss

        # update the parameters
        self._add_grad(alpha_k, prec_grad)

        # compute some values for the next iteration:
        theta_kp1 = self._flatten(self._clone_param())
        s_k = theta_kp1 - theta_k
        y_k = grad_kp1 - grad_k
        Hkyk = state["Hk"] @ y_k
        yk_dot_Hkyk = y_k @ Hkyk
        yk_dot_sk = y_k @ s_k
        v_k = torch.sqrt(yk_dot_Hkyk) * (s_k / yk_dot_sk - Hkyk / yk_dot_Hkyk)

        # self-scaling for the "ssbfgs" method (phi_k = 1: scaled BFGS update)
        tau_k = min(1.0, -yk_dot_sk / (alpha_k * (s_k @ grad_k)))
        phi_k = 1.0

        if method == "ssbroyden":
            b_k = -alpha_k * (s_k @ grad_k) / yk_dot_sk
            h_k = yk_dot_Hkyk / yk_dot_sk
            a_k = h_k * b_k - 1.0
            c_k = torch.sqrt(a_k / (a_k + 1.0))
            rhom_k = min(1.0, h_k * (1 - c_k))
            thetam_k = (rhom_k - 1) / a_k
            thetap_k = 1.0 / rhom_k
            # note: theta_k is reused here as the Broyden-family parameter;
            # the flattened parameter vector is no longer needed at this point
            theta_k = max(thetam_k, min(thetap_k, (1.0 - b_k) / b_k))
            sigma_k = 1 + a_k * theta_k
            n = self._numel()
            sigma_k_pow = sigma_k ** (-1 / (n - 1))
            if theta_k > 0:
                tau_k = tau_k * min(sigma_k_pow, 1.0 / theta_k)
            else:
                tau_k = min(tau_k * sigma_k_pow, sigma_k)
            phi_k = (1 - theta_k) / (1.0 + a_k * theta_k)

        # inverse-Hessian update
        temp1 = (Hkyk[:, None] @ Hkyk[None, :]) / yk_dot_Hkyk
        temp2 = phi_k * (v_k[:, None] @ v_k[None, :])
        temp3 = (s_k[:, None] @ s_k[None, :]) / yk_dot_sk
        H_kp1 = (1 / tau_k) * (state["Hk"] - temp1 + temp2) + temp3

        if torch.any(torch.isnan(H_kp1)):
            orig_loss = closure()
            return orig_loss

        state["Hk"] = H_kp1
        state["k"] += 1

        return orig_loss
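
# Minimal usage sketch (a hypothetical example, not part of the module): fit a
# small linear model with SSBroyden. As with torch.optim.LBFGS, step() takes a
# closure that zeroes the gradients, recomputes the loss, and calls backward().
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(2, 1)
    xs = torch.randn(16, 2)
    ys = xs.sum(dim=1, keepdim=True)
    optimizer = SSBroyden(model.parameters(), lr=1.0, method="ssbfgs")

    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(xs), ys)
        loss.backward()
        return loss

    for _ in range(20):
        loss = optimizer.step(closure)
    print(f"final loss: {float(loss):.3e}")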