"""A module to handle several optimizers."""
import copy
import operator # for xor
import warnings
from typing import Any, Callable
import torch
from torch.optim.optimizer import ParamsT
from scimba_torch.optimizers.scimba_optimizers import (
AbstractScimbaOptimizer,
ScimbaAdam,
ScimbaCustomOptomizer,
ScimbaLBFGS,
ScimbaSGD,
)
SCIMBA_OPTIMIZERS = {"adam": ScimbaAdam, "lbfgs": ScimbaLBFGS, "sgd": ScimbaSGD}
SWITCH_AT_EPOCH_DEFAULT = False
SWITCH_AT_EPOCH_N_DEFAULT = 5000
SWITCH_AT_EPOCH_RATIO_DEFAULT = 0.7
SWITCH_AT_PLATEAU_DEFAULT = False
SWITCH_AT_PLATEAU_N1_N2_DEFAULT = (50, 10)
SWITCH_AT_PLATEAU_RATIO_DEFAULT = 500.0
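# A "next" optimizer can be activated after a fixed epoch count
# ("switch_at_epoch"), after a fixed fraction of the total epochs
# ("switch_at_epoch_ratio"), or once the loss plateaus ("switch_at_plateau",
# "switch_at_plateau_ratio"); the constants above are the default thresholds.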
class OptimizerData:
r"""A class to manage multiple optimizers and their activation criteria.
Args:
*args: Variable length argument list of optimizer configurations.
            Each input dictionary must have one of the forms:
            { "class": value (a subclass of AbstractScimbaOptimizer), keys: value }
            { "name": value (one of "adam", "lbfgs" or "sgd"), keys: value }
            where the key-value pairs can be:
            "optimizer_args": a dictionary of arguments for the optimizer,
            "scheduler": a subclass of torch.optim.lr_scheduler.LRScheduler,
            "scheduler_args": a dictionary of arguments for the scheduler,
            "switch_at_epoch": a bool or an int, default False;
                if True, the default value 5000 is used,
            "switch_at_epoch_ratio": a bool or a float, default 0.7;
                if True, the default value is used,
            "switch_at_plateau": a bool or a list/tuple of two ints, default False;
                if True, the default (50, 10) is used,
            "switch_at_plateau_ratio": a float r, default 500.0;
                triggers the plateau tests only once current_loss < init_loss / r
**kwargs: Arbitrary keyword arguments.
Examples:
        >>> from scimba_torch.optimizers.scimba_optimizers import (
        ...     ScimbaMomentum,
        ... )
        >>> opt_1 = {
        ...     "name": "adam",
        ...     "optimizer_args": {"lr": 1e-3, "betas": (0.9, 0.999)},
        ... }
        >>> opt_2 = {"class": ScimbaMomentum, "switch_at_epoch": 500}
        >>> opt_3 = {
        ...     "name": "lbfgs",
        ...     "switch_at_epoch_ratio": 0.7,
        ...     "switch_at_plateau": [500, 20],
        ...     "switch_at_plateau_ratio": 3.0,
        ... }
>>> optimizers = OptimizerData(opt_1, opt_2, opt_3)
"""
activated_optimizer: list[AbstractScimbaOptimizer]
"""A list containing the current optimizer; empty if none."""
def __init__(self, *args: dict[str, Any], **kwargs):
self.activated_optimizer = []
self.optimizers: list[dict[str, Any]] = [] #: List of optimizers.
self.next_optimizer: int = 0 #: Index of the next optimizer to be activated.
        # in case the lr is modified by a line search, remember the initial lr
# self.default_lr: float = kwargs.get("default_lr", 1e-2)
for opt in args:
self._check_dictionary_and_append_to_optimizers_list(opt)
if len(self.optimizers) == 0: # default optimizers
self.optimizers.append({"name": "adam"})
self.optimizers.append({"name": "lbfgs"})
def _check_dictionary_and_append_to_optimizers_list(self, opt: dict[str, Any]):
"""Checks the input configuration and appends it to the optimizers list.
Args:
opt: Optimizer configuration dictionary.
Raises:
KeyError: If the dictionary does not contain
exactly one of "name" or "class".
ValueError: If the dictionary contains invalid values for any of the keys.
"""
        # Check that exactly one of "name" and "class" is present
if not operator.xor(("name" in opt), ("class" in opt)):
raise KeyError(
f"Cannot create a {self.__class__.__name__} from dict: {opt}"
)
# Check types and values of entries
for key in opt:
if key == "name":
opt_name = opt["name"]
if opt_name not in SCIMBA_OPTIMIZERS:
                    raise ValueError(
                        f"Cannot create a {self.__class__.__name__} "
                        f"from name: {opt_name}"
                    )
elif key == "class":
opt_class = opt["class"]
if not issubclass(opt_class, ScimbaCustomOptomizer):
                    raise ValueError(
                        f"Cannot create a {self.__class__.__name__} "
                        f"from class: {opt_class}; it must be a subclass "
                        "of ScimbaCustomOptomizer"
                    )
elif key == "optimizer_args":
opt_args = opt[key]
if not isinstance(opt_args, dict):
raise ValueError(
f"Cannot create a {self.__class__.__name__} from given {key}"
)
elif key == "scheduler":
opt_args = opt[key]
if not issubclass(opt_args, torch.optim.lr_scheduler.LRScheduler):
                    raise ValueError(
                        f"Cannot create a {self.__class__.__name__} "
                        f"from given {key}; it must be a subclass of "
                        "torch.optim.lr_scheduler.LRScheduler"
                    )
elif key == "scheduler_args":
opt_args = opt[key]
if not isinstance(opt_args, dict):
raise ValueError(
f"Cannot create a {self.__class__.__name__} from given {key}"
)
elif key == "switch_at_epoch":
opt_args = opt[key]
                if not isinstance(opt_args, (bool, int)):
                    raise ValueError(
                        f"In {self.__class__.__name__}.init: {key} "
                        "must be either a bool or an int"
                    )
elif key == "switch_at_epoch_ratio":
opt_args = opt[key]
                if not isinstance(opt_args, (bool, float)):
                    raise ValueError(
                        f"In {self.__class__.__name__}.init: {key} "
                        "must be either a bool or a float"
                    )
elif key == "switch_at_plateau":
opt_args = opt[key]
                if not isinstance(opt_args, (bool, list, tuple)):
                    raise ValueError(
                        f"In {self.__class__.__name__}.init: {key} "
                        "must be either a bool or a list/tuple of two ints"
                    )
elif key == "switch_at_plateau_ratio":
opt_args = opt[key]
if not isinstance(opt_args, float):
raise ValueError(
f"In {self.__class__.__name__}.init: {key} must be a float"
)
else:
warnings.warn(
f"In {self.__class__.__name__}.init: unrecognized option {key}",
UserWarning,
)
self.optimizers.append(opt.copy())
def step(self, closure: Callable[[], float]) -> None:
"""Performs an optimization step using the currently activated optimizer.
Args:
closure: A closure that reevaluates the model and returns the loss.
"""
if len(self.activated_optimizer) == 1:
self.activated_optimizer[0].optimizer_step(closure)
def set_lr(self, lr: float) -> None:
"""Set learning rate of activated optimizer.
Args:
lr: The new learning rate.
"""
if len(self.activated_optimizer) == 1:
for group in self.activated_optimizer[0].param_groups:
group["lr"] = lr
def zero_grad(self) -> None:
"""Zeros the gradients of the currently activated optimizer."""
self.activated_optimizer[0].zero_grad()
def test_activate_next_optimizer(
self,
loss_history: list[float],
loss_value: float,
init_loss: float,
epoch: int,
epochs: int,
) -> bool:
"""Tests whether the next opt. should be activated based on the given criteria.
Args:
loss_history: History of loss values.
loss_value: Current loss value.
init_loss: Initial loss value.
epoch: Current epoch.
epochs: Total number of epochs.
Returns:
True if the next optimizer should be activated, False otherwise.
"""
if self.next_optimizer >= len(self.optimizers):
return False
next_opt = self.optimizers[self.next_optimizer]
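        # Note: the bool checks below must come before treating values as
        # ints/floats, since bool is a subclass of int in Python.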
switch_if_epoch = next_opt.get(
"switch_at_epoch", SWITCH_AT_EPOCH_DEFAULT
) # default value = False
if isinstance(switch_if_epoch, bool):
n = SWITCH_AT_EPOCH_N_DEFAULT
else:
n = switch_if_epoch
switch_if_epoch = True
switch_if_epoch_ratio = True
switch_at_epoch_ratio = next_opt.get(
"switch_at_epoch_ratio", SWITCH_AT_EPOCH_RATIO_DEFAULT
)
if isinstance(switch_at_epoch_ratio, bool):
switch_if_epoch_ratio = switch_at_epoch_ratio
switch_at_epoch_ratio = SWITCH_AT_EPOCH_RATIO_DEFAULT
switch_if_plateau = next_opt.get("switch_at_plateau", SWITCH_AT_PLATEAU_DEFAULT)
if isinstance(switch_if_plateau, bool):
n1, n2 = SWITCH_AT_PLATEAU_N1_N2_DEFAULT
else:
n1, n2 = switch_if_plateau
switch_if_plateau = True
switch_at_plateau_ratio = next_opt.get(
"switch_at_plateau_ratio", SWITCH_AT_PLATEAU_RATIO_DEFAULT
)
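        # switch when the epoch threshold, the epoch-ratio threshold,
        # or the plateau test is met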
if (switch_if_epoch) and (epoch >= n):
return True
if (switch_if_epoch_ratio) and (epoch / epochs > switch_at_epoch_ratio):
return True
        if switch_if_plateau:
            # Plateau test: once the loss has dropped below init_loss / ratio,
            # switch if the last n2 losses sum to more than the n2 losses
            # recorded n1 epochs ago (i.e. the loss is no longer decreasing).
            if (loss_value < (init_loss / switch_at_plateau_ratio)) and (
                sum(loss_history[-n2:]) - sum(loss_history[-n1 : -n1 + n2]) > 0
            ):
                return True
return False
def activate_next_optimizer(
self, parameters: ParamsT, verbose: bool = False
) -> None:
"""Activates the next optimizer in the list.
Args:
parameters: Parameters to be optimized.
verbose: whether to print activation message or not.
"""
if self.next_optimizer >= len(self.optimizers):
            warnings.warn(
                "no optimizer left to activate; nothing will happen",
                RuntimeWarning,
            )
return
opt = self.optimizers[self.next_optimizer]
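        # resolve the optimizer class, either from its registered name
        # or directly from the "class" entry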
if "name" in opt:
opt_name = opt["name"]
opt_class = SCIMBA_OPTIMIZERS[opt_name]
else:
opt_class = opt["class"]
# Prepare the arguments
arguments = {}
if "optimizer_args" in opt:
arguments["optimizer_args"] = opt["optimizer_args"].copy()
if "scheduler" in opt:
arguments["scheduler"] = opt["scheduler"]
if "scheduler_args" in opt:
arguments["scheduler_args"] = opt["scheduler_args"].copy()
if len(self.activated_optimizer) == 1:
self.activated_optimizer[0] = opt_class(parameters, **arguments)
else:
self.activated_optimizer.append(opt_class(parameters, **arguments))
if verbose:
print(f"activating optimizer {opt_class.__name__}")
self.next_optimizer += 1
def activate_first_optimizer(
self, parameters: ParamsT, verbose: bool = False
) -> None:
"""Activates the first optimizer in the list.
Args:
parameters: Parameters to be optimized.
verbose: whether to print activation message or not.
"""
# case where optimizers have already been activated: do nothing
if len(self.activated_optimizer):
return
        # first try to activate all the optimizers, to report
        # possible errors at the beginning of the training!
        # save the index, in case self was loaded from a file
        next_optimizer_save = self.next_optimizer
while self.next_optimizer < len(self.optimizers):
self.activate_next_optimizer(parameters, verbose=False)
self.activated_optimizer = []
self.next_optimizer = next_optimizer_save
# activate the first optimizer
self.activate_next_optimizer(parameters, verbose)
def test_and_activate_next_optimizer(
self,
parameters: ParamsT,
loss_history: list[float],
loss_value: float,
init_loss: float,
epoch: int,
epochs: int,
) -> None:
"""Tests whether next optimizer should be activated; activates it.
Args:
parameters: Parameters to be optimized.
loss_history: History of loss values.
loss_value: Current loss value.
init_loss: Initial loss value.
epoch: Current epoch.
epochs: Total number of epochs.
"""
if self.test_activate_next_optimizer(
loss_history, loss_value, init_loss, epoch, epochs
):
self.activate_next_optimizer(parameters)
def get_opt_gradients(self) -> torch.Tensor:
"""Gets the gradients of the currently activated optimizer.
Returns:
            Tensor of all gradients, concatenated as a single column.
"""
grads = torch.tensor([])
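        # torch.cat accepts the empty 1-D tensor as the initial accumulator;
        # each gradient is flattened into a column before concatenation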
for p in self.activated_optimizer[0].param_groups[0]["params"]:
if p.grad is not None:
grads = torch.cat((grads, p.grad.flatten()[:, None]), 0)
return grads
def update_best_optimizer(self) -> None:
"""Updates the best state of the currently activated optimizer."""
self.activated_optimizer[0].update_best_optimizer()
def dict_for_save(self) -> dict:
"""Returns a dictionary containing the best state of the current optimizer.
Returns:
dictionary containing the best state of the optimizer.
"""
res = self.activated_optimizer[0].dict_for_save()
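        # save the index so that load_from_dict can re-activate the same optimizer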
res["next_optimizer"] = self.next_optimizer
return res
def load_from_dict(self, parameters: ParamsT, checkpoint: dict) -> None:
"""Loads the optimizer and scheduler states from a checkpoint.
Args:
parameters: Parameters to be optimized.
checkpoint: dictionary containing the optimizer and scheduler states.
Raises:
ValueError: when there is no active optimizer to load in.
"""
        # rewind the index so that activate_next_optimizer re-activates
        # the optimizer that was active when the checkpoint was saved
        self.next_optimizer = checkpoint["next_optimizer"] - 1
self.activate_next_optimizer(parameters)
if len(self.activated_optimizer) < 1:
raise ValueError("there is no active optimizer to load in!")
self.activated_optimizer[0].load(checkpoint)
def __str__(self) -> str:
"""Returns a string representation of the optimizers.
Returns:
str: String representation of the optimizers.
"""
ret = "optimizers: ["
for opt in self.optimizers:
try:
opt_name = opt["name"]
ret = ret + opt_name + ", "
except KeyError:
opt_class = opt["class"]
ret = ret + opt_class.__name__ + ", "
ret = ret + "]"
return ret
if __name__ == "__main__": # pragma: no cover
import math
from scimba_torch.optimizers.scimba_optimizers import ScimbaMomentum
opt_1 = {
"name": "adam",
"optimizer_args": {"lr": 1e-3, "betas": (0.9, 0.999)},
}
opt_2 = {"class": ScimbaMomentum, "switch_at_epoch": 500}
opt_3 = {
"name": "lbfgs",
"switch_at_epoch_ratio": 0.7,
"switch_at_plateau": [500, 20],
"switch_at_plateau_ratio": 3.0,
}
optimizers = OptimizerData(opt_1, opt_2, opt_3)
print("optimizers: ", optimizers)
class SimpleNN(torch.nn.Module):
"""For test."""
def __init__(self):
            super().__init__()
self.fc1 = torch.nn.Linear(10, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""For test.
Args:
x: For test.
Returns:
For test.
"""
return self.fc1(x)
net = SimpleNN()
opt_1 = {"name": "adam", "optimizer_args": {"lr": 1e-3, "betas": (0.9, 0.999)}}
opt_2 = {
"name": "adam",
"optimizer_args": {"lrTEST": 1e-3, "betasTEST": (0.9, 0.999)},
} # wrong list of arguments
try:
optimizers2 = OptimizerData(opt_1, opt_2)
optimizers2.activate_first_optimizer(list(net.parameters()))
except TypeError as error:
print(error)
    input_tensor = torch.randn(10000, 10)  # Batch of 10000 samples, each of size 10
    target_tensor = torch.sum(input_tensor, dim=1)[:, None]  # Targets, shape (10000, 1)
loss = [torch.tensor(float("+inf"))]
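    # the loss is kept in a one-element list so that closure() can update it in place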
loss_func = torch.nn.MSELoss()
opt = optimizers
opt.activate_first_optimizer(list(net.parameters()))
def closure():
"""For test.
Returns:
For test.
"""
opt.zero_grad()
# Forward pass
output_tensor = net(input_tensor)
loss[0] = loss_func(output_tensor, target_tensor)
loss[0].backward(retain_graph=True)
return loss[0].item()
init_loss = closure()
# grads = opt.get_opt_gradients()
# print("get_opt_gradients: ", grads)
loss_history = [init_loss]
best_loss = init_loss
best_net = copy.deepcopy(net.state_dict())
epochs = 1000
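    # basic training loop: step the active optimizer, keep track of the best
    # state, and switch to the next optimizer when one of its criteria fires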
for epoch in range(epochs):
opt.step(closure)
if math.isinf(loss[0].item()) or math.isnan(loss[0].item()):
loss[0] = torch.tensor(best_loss)
net.load_state_dict(best_net)
if loss[0].item() < best_loss:
best_loss = loss[0].item()
best_net = copy.deepcopy(net.state_dict())
opt.update_best_optimizer()
loss_history.append(loss[0].item())
if epoch % 100 == 0:
print("epoch: ", epoch, "loss: ", loss[0].item())
if opt.test_activate_next_optimizer(
loss_history, loss[0].item(), init_loss, epoch, epochs
):
print("activate next opt! epoch = ", epoch)
opt.activate_next_optimizer(list(net.parameters()))
# opt.test_and_activate_next_optimizer(\
# list(net.parameters()), loss_history,\
# loss[0].item(), init_loss, epoch, epochs )
net.load_state_dict(best_net)
closure()
print("loss after training: ", loss[0].item())
print("net( torch.ones( 10 ) ) : ", net(torch.ones(10)))
# grads = opt.get_opt_gradients()
# print("get_opt_gradients: ", grads)