"""An implementation of Self Scaled Broyden optimizer.
_cubic_interpolate and _strong_wolfe have been copied pasted from
torch v2.9.1, in torch.optim.lbfgs.py
"""
import math
from typing import Callable, Union
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, ParamsT
def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
# ported from https://github.com/torch/optim/blob/master/polyinterp.lua
# Compute bounds of interpolation area
if bounds is not None:
xmin_bound, xmax_bound = bounds
else:
xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)
# Code for most common case: cubic interpolation of 2 points
# w/ function and derivative values for both
# Solution in this case (where x2 is the farthest point):
# d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
# d2 = sqrt(d1^2 - g1*g2);
# min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
# t_new = min(max(min_pos,xmin_bound),xmax_bound);
d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
d2_square = d1**2 - g1 * g2
if d2_square >= 0:
d2 = d2_square.sqrt()
if x1 <= x2:
min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
else:
min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
return min(max(min_pos, xmin_bound), xmax_bound)
else:
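        # no real cubic minimizer (negative discriminant): fall back to the
        # midpoint of the interpolation interval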
return (xmin_bound + xmax_bound) / 2.0
def _strong_wolfe(
obj_func, x, t, d, f, g, gtd, c1=1e-4, c2=0.9, tolerance_change=1e-9, max_ls=25
):
# ported from https://github.com/torch/optim/blob/master/lswolfe.lua
d_norm = d.abs().max()
g = g.clone(memory_format=torch.contiguous_format)
# evaluate objective and gradient using initial step
f_new, g_new = obj_func(x, t, d)
ls_func_evals = 1
gtd_new = g_new.dot(d)
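    # strong Wolfe conditions checked below (gtd < 0 for a descent direction):
    #   sufficient decrease: f(x + t*d) <= f + c1 * t * gtd
    #   curvature:           |g(x + t*d) . d| <= -c2 * gtd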
# bracket an interval containing a point satisfying the Wolfe criteria
t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
done = False
ls_iter = 0
while ls_iter < max_ls:
# check conditions
if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
bracket = [t_prev, t]
bracket_f = [f_prev, f_new]
bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
bracket_gtd = [gtd_prev, gtd_new]
break
if abs(gtd_new) <= -c2 * gtd:
bracket = [t]
bracket_f = [f_new]
bracket_g = [g_new]
done = True
break
if gtd_new >= 0:
bracket = [t_prev, t]
bracket_f = [f_prev, f_new]
bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
bracket_gtd = [gtd_prev, gtd_new]
break
# interpolate
min_step = t + 0.01 * (t - t_prev)
max_step = t * 10
tmp = t
t = _cubic_interpolate(
t_prev, f_prev, gtd_prev, t, f_new, gtd_new, bounds=(min_step, max_step)
)
# next step
t_prev = tmp
f_prev = f_new
g_prev = g_new.clone(memory_format=torch.contiguous_format)
gtd_prev = gtd_new
f_new, g_new = obj_func(x, t, d)
ls_func_evals += 1
gtd_new = g_new.dot(d)
ls_iter += 1
# reached max number of iterations?
if ls_iter == max_ls:
bracket = [0, t]
bracket_f = [f, f_new]
bracket_g = [g, g_new]
# zoom phase: we now have a point satisfying the criteria, or
# a bracket around it. We refine the bracket until we find the
# exact point satisfying the criteria
insuf_progress = False
# find high and low points in bracket
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0) # type: ignore[possibly-undefined]
while not done and ls_iter < max_ls:
# line-search bracket is so small
if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change: # type: ignore[possibly-undefined]
break
# compute new trial value
t = _cubic_interpolate(
bracket[0],
bracket_f[0],
bracket_gtd[0], # type: ignore[possibly-undefined]
bracket[1],
bracket_f[1],
bracket_gtd[1],
)
# test that we are making sufficient progress:
# in case `t` is so close to boundary, we mark that we are making
# insufficient progress, and if
# + we have made insufficient progress in the last step, or
# + `t` is at one of the boundary,
# we will move `t` to a position which is `0.1 * len(bracket)`
# away from the nearest boundary point.
eps = 0.1 * (max(bracket) - min(bracket))
if min(max(bracket) - t, t - min(bracket)) < eps:
# interpolation close to boundary
if insuf_progress or t >= max(bracket) or t <= min(bracket):
# evaluate at 0.1 away from boundary
if abs(t - max(bracket)) < abs(t - min(bracket)):
t = max(bracket) - eps
else:
t = min(bracket) + eps
insuf_progress = False
else:
insuf_progress = True
else:
insuf_progress = False
# Evaluate new point
f_new, g_new = obj_func(x, t, d)
ls_func_evals += 1
gtd_new = g_new.dot(d)
ls_iter += 1
if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
# Armijo condition not satisfied or not lower than lowest point
bracket[high_pos] = t
bracket_f[high_pos] = f_new
bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
bracket_gtd[high_pos] = gtd_new
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
else:
if abs(gtd_new) <= -c2 * gtd:
# Wolfe conditions satisfied
done = True
elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
# old high becomes new low
bracket[high_pos] = bracket[low_pos]
bracket_f[high_pos] = bracket_f[low_pos]
bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined]
bracket_gtd[high_pos] = bracket_gtd[low_pos]
# new point becomes new low
bracket[low_pos] = t
bracket_f[low_pos] = f_new
bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
bracket_gtd[low_pos] = gtd_new
# return stuff
t = bracket[low_pos] # type: ignore[possibly-undefined]
f_new = bracket_f[low_pos]
g_new = bracket_g[low_pos] # type: ignore[possibly-undefined]
if not isinstance(t, torch.Tensor):
t = torch.tensor(float(t))
return f_new, g_new, t, ls_func_evals
class SSBroyden(Optimizer):
"""Implements SSBroyden algorithm.
Implementation of
Urbán, J. F., Stefanou, P., & Pons, J. A. (2025).
Unveiling the optimization process of physics informed neural networks:
How accurate and competitive can PINNs be?.
Journal of Computational Physics, 523, 113656.
Args:
params: iterable of parameters to optimize. Parameters must be real.
lr: learning rate (default: 1)
tolerance_grad: does not update if max norm of grad smaller that this.
method: either "ssbroyden" or "ssbfgs"
Raises:
ValueError:
lr is not scalar
lr is <= 0.
tolerance grad is <= 0.
SS Broyden/BFGS doesn't support per-parameter options
method is not in ["ssbfgs", "ssbroyden"]
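
    Example:
        A minimal usage sketch, assuming user-defined model, inputs and
        targets (not part of this module)::

            optimizer = SSBroyden(model.parameters(), lr=1.0, method="ssbfgs")

            def closure():
                optimizer.zero_grad()
                loss = ((model(inputs) - targets) ** 2).mean()
                loss.backward()
                return loss

            loss = optimizer.step(closure)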
"""
def __init__(
self,
params: ParamsT,
lr: Union[float, Tensor] = 1.0,
tolerance_grad: float = 1e-10,
method: str = "ssbfgs",
):
if isinstance(lr, float):
lr = torch.tensor(lr)
if isinstance(lr, Tensor) and lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
if not 0.0 < lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 < tolerance_grad:
raise ValueError(f"Invalid tolerance on gradient: {tolerance_grad}")
defaults = {"lr": lr, "tolerance_grad": tolerance_grad, "method": method}
super().__init__(params, defaults)
if len(self.param_groups) != 1:
raise ValueError(
"SS Broyden/BFGS doesn't support per-parameter options"
" (parameter groups)"
)
if method not in ["ssbfgs", "ssbroyden"]:
raise ValueError("method should be either ssbroyden or ssbfgs")
self._params = self.param_groups[0]["params"]
self._numel_cache = None
# number of parameters:
nbparams = self._numel()
state = self.state[self._params[0]]
state["k"] = 0
state["Hk"] = torch.eye(nbparams, dtype=torch.get_default_dtype())
def _numel(self):
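        # total number of real degrees of freedom; complex parameters count
        # twice because they are later flattened with torch.view_as_real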
if self._numel_cache is None:
self._numel_cache = sum(
2 * p.numel() if torch.is_complex(p) else p.numel()
for p in self._params
)
return self._numel_cache
def _gather_flat_grad(self):
views = []
for p in self._params:
if p.grad is None:
view = p.new(p.numel()).zero_()
elif p.grad.is_sparse:
view = p.grad.to_dense().view(-1)
else:
view = p.grad.view(-1)
if torch.is_complex(view):
view = torch.view_as_real(view).view(-1)
views.append(view)
return torch.cat(views, 0)
def _add_grad(self, step_size, update):
offset = 0
for p in self._params:
if torch.is_complex(p):
p = torch.view_as_real(p)
numel = p.numel()
# view as to avoid deprecated pointwise semantics
p.add_(update[offset : offset + numel].view_as(p), alpha=step_size)
offset += numel
assert offset == self._numel()
def _clone_param(self):
return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
def _flatten(self, params):
views = []
for p in params:
if p.is_sparse:
view = p.to_dense().view(-1)
else:
view = p.view(-1)
if torch.is_complex(view):
view = torch.view_as_real(view).view(-1)
views.append(view)
return torch.cat(views, 0)
def _set_param(self, params_data):
for p, pdata in zip(self._params, params_data):
p.copy_(pdata)
def _directional_evaluate(self, closure, x, t, d):
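        # evaluate the loss and flat gradient at the trial point (params + t * d),
        # then restore the parameters to x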
self._add_grad(t, d)
loss = float(closure())
flat_grad = self._gather_flat_grad()
self._set_param(x)
return loss, flat_grad
@torch.no_grad()
def step(self, closure: Callable) -> torch.Tensor: # type: ignore[override]
"""Perform a single optimization step.
Args:
closure: A closure that reevaluates the model
and returns the loss.
Returns:
the initial loss
"""
assert len(self.param_groups) == 1
# Make sure the closure is always called with grad enabled
closure = torch.enable_grad()(closure)
group = self.param_groups[0]
lr = float(group["lr"])
tolerance_grad = group["tolerance_grad"]
method = group["method"]
# NOTE: SSBroyden has only global state, but we register it as state for
# the first param, because this helps with casting in load_state_dict
state = self.state[self._params[0]]
x_init = self._clone_param()
theta_k = self._flatten(x_init)
def obj_func(x, t, d):
return self._directional_evaluate(closure, x, t, d)
# evaluate initial f(x) and df/dx
orig_loss = closure()
loss = float(orig_loss)
grad_k = self._gather_flat_grad()
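        # stop early if the gradient infinity-norm is already below tolerance_grad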
opt_cond = grad_k.abs().max() <= tolerance_grad
# print("grad_k: ", grad_k, ", opt_cond: ", opt_cond)
if opt_cond:
return orig_loss
        # quasi-Newton descent direction: d_k = -H_k * grad_k
prec_grad = state["Hk"] @ grad_k
prec_grad = prec_grad.neg()
# directional derivative
gtd = grad_k.dot(prec_grad)
        # strong Wolfe line search for the step size alpha_k
loss, grad_kp1, alpha_k, ls_func_evals = _strong_wolfe(
obj_func, x_init, lr, prec_grad, loss, grad_k, gtd
)
# print("alpha_k: ", alpha_k)
# opt_cond = grad_kp1.abs().max() <= tolerance_grad
# if opt_cond:
# return orig_loss
# print("grad_kp1: ", grad_kp1, ", opt_cond: ", opt_cond)
# assert isinstance(grad_kp1, torch.Tensor)
# print("alpha_k: ", alpha_k)
# assert isinstance(alpha_k, torch.Tensor)
# check that there are no nans or infs
if (
math.isnan(loss)
or math.isinf(loss)
or torch.isnan(alpha_k)
or torch.isinf(alpha_k)
or torch.any(torch.isnan(grad_kp1))
or torch.any(torch.isinf(grad_kp1))
):
orig_loss = closure()
return orig_loss
        # apply the accepted step to the parameters: theta <- theta + alpha_k * prec_grad
self._add_grad(alpha_k, prec_grad)
# compute some values for next turn:
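        # s_k is the parameter step and y_k the gradient change; a successful
        # Wolfe line search ensures the curvature condition y_k . s_k > 0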
theta_kp1 = self._flatten(self._clone_param())
s_k = theta_kp1 - theta_k
y_k = grad_kp1 - grad_k
Hkyk = state["Hk"] @ y_k
yk_dot_Hkyk = y_k @ Hkyk
yk_dot_sk = y_k @ s_k
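        # v_k is the (scaled) Broyden-class vector:
        #   v_k = sqrt(y_k' H_k y_k) * (s_k / (y_k' s_k) - H_k y_k / (y_k' H_k y_k))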
v_k = torch.sqrt(yk_dot_Hkyk) * (s_k / (yk_dot_sk) - Hkyk / yk_dot_Hkyk)
        # default scaling for method "ssbfgs": phi_k = 1, i.e. a scaled BFGS-type update
tau_k = min(1.0, -yk_dot_sk / (alpha_k * (s_k @ grad_k)))
phi_k = 1.0
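        # for method "ssbroyden", tau_k and phi_k are instead set adaptively
        # below (see the reference in the class docstring)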
if method == "ssbroyden":
b_k = -alpha_k * (s_k @ grad_k) / yk_dot_sk
h_k = yk_dot_Hkyk / yk_dot_sk
a_k = h_k * b_k - 1.0
c_k = torch.sqrt(a_k / (a_k + 1.0))
rhom_k = min(1.0, h_k * (1 - c_k))
thetam_k = (rhom_k - 1) / a_k
thetap_k = 1.0 / rhom_k
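            # note: theta_k below is a scalar update parameter; it reuses the
            # name of the flattened parameter vector theta_k, which is no
            # longer needed once s_k has been formed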
theta_k = max(thetam_k, min(thetap_k, (1.0 - b_k) / b_k))
sigma_k = 1 + a_k * theta_k
n = self._numel()
sigma_k_pow = sigma_k ** (-1 / (n - 1))
if theta_k > 0:
tau_k = tau_k * min(sigma_k_pow, 1.0 / theta_k)
else:
tau_k = min(tau_k * sigma_k_pow, sigma_k)
phi_k = (1 - theta_k) / (1.0 + a_k * theta_k)
# print("tau_k: ", tau_k)
temp1 = (Hkyk[:, None] @ Hkyk[None, :]) / yk_dot_Hkyk
temp2 = phi_k * (v_k[:, None] @ v_k[None, :])
temp3 = (s_k[:, None] @ s_k[None, :]) / yk_dot_sk
H_kp1 = (1 / tau_k) * (state["Hk"] - temp1 + temp2) + temp3
# print("H_kp1: ", H_kp1)
if torch.any(torch.isnan(H_kp1)):
orig_loss = closure()
return orig_loss
state["Hk"] = H_kp1
state["k"] += 1
return orig_loss