Source code for scimba_torch.optimizers.optimizers_data

"""A module to handle several optimizers.

Examples:
**Optimizers usage**

    .. code-block:: python

        import copy
        import math

        import torch

        from scimba_torch.optimizers.optimizers_data import OptimizerData
        from scimba_torch.optimizers.scimba_optimizers import ScimbaMomentum

        opt_1 = {
            "name": "adam",
            "optimizer_args": {"lr": 1e-3, "betas": (0.9, 0.999)},
        }

        opt_2 = {"class": ScimbaMomentum, "switch_at_epoch": 500}

        opt_3 = {
            "name": "lbfgs",
            "switch_at_epoch_ratio": 0.7,
            "switch_at_plateau": [500, 20],
            "switch_at_plateau_ratio": 3.0,
        }

        optimizers = OptimizerData(opt_1, opt_2, opt_3)

        print("optimizers: ", optimizers)

        class SimpleNN(torch.nn.Module):

            def __init__(self):
                super(SimpleNN, self).__init__()
                self.fc1 = torch.nn.Linear(10, 1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.fc1(x)

        net = SimpleNN()

        opt_1 = {"name": "adam", "optimizer_args": {"lr": 1e-3, "betas": (0.9, 0.999)}}
        opt_2 = {
            "name": "adam",
            "optimizer_args": {"lrTEST": 1e-3, "betasTEST": (0.9, 0.999)},
        }  # wrong list of arguments

        try:
            optimizers2 = OptimizerData(opt_1, opt_2)
            optimizers2.activate_first_optimizer(list(net.parameters()))
        except TypeError as error:
            print(error)

        input_tensor = torch.randn(10000, 10)  # 10000 samples, each of size 10
        target_tensor = torch.sum(input_tensor, dim=1)[
            :, None
        ]  # Target tensor with batch size 10000

        loss = [torch.tensor(float("+inf"))]
        loss_func = torch.nn.MSELoss()

        opt = optimizers
        opt.activate_first_optimizer(list(net.parameters()))

        def closure():
            opt.zero_grad()
            output_tensor = net(input_tensor)
            loss[0] = loss_func(output_tensor, target_tensor)
            loss[0].backward(retain_graph=True)
            return loss[0].item()

        init_loss = closure()

        grads = opt.get_opt_gradients()
        print("get_opt_gradients: ", grads)

        loss_history = [init_loss]
        best_loss = init_loss
        best_net = copy.deepcopy(net.state_dict())

        epochs = 1000
        for epoch in range(epochs):
            opt.step(closure)

            if math.isinf(loss[0].item()) or math.isnan(loss[0].item()):
                loss[0] = torch.tensor(best_loss)
                net.load_state_dict(best_net)

            if loss[0].item() < best_loss:
                best_loss = loss[0].item()
                best_net = copy.deepcopy(net.state_dict())
                opt.update_best_optimizer()

            loss_history.append(loss[0].item())

            if epoch % 100 == 0:
                print("epoch: ", epoch, "loss: ", loss[0].item())

            if opt.test_activate_next_optimizer(
                loss_history, loss[0].item(), init_loss, epoch, epochs
            ):
                print("activate next opt! epoch = ", epoch)
                opt.activate_next_optimizer(list(net.parameters()))

            opt.test_and_activate_next_optimizer(
                list(net.parameters()),
                loss_history,
                loss[0].item(),
                init_loss,
                epoch,
                epochs,
            )

        net.load_state_dict(best_net)
        closure()
        print("loss after training: ", loss[0].item())
        print("net( torch.ones( 10 ) ) : ", net(torch.ones(10)))

        grads = opt.get_opt_gradients()
        print("get_opt_gradients: ", grads)
"""

import operator  # for xor
import warnings
from typing import Any, Callable

import torch
from torch.optim.optimizer import ParamsT

from scimba_torch.optimizers.scimba_optimizers import (
    AbstractScimbaOptimizer,
    ScimbaAdam,
    ScimbaCustomOptomizer,
    ScimbaLBFGS,
    ScimbaSGD,
    ScimbaSSBFGS,
    ScimbaSSBroyden,
)

# Shape of a single optimizer configuration dictionary (see OptimizerData).
OPT_DICT_TYPE = dict[str, Any]

# Mapping from the "name" entry of a configuration dictionary to the
# corresponding Scimba optimizer class.
SCIMBA_OPTIMIZERS = {
    "adam": ScimbaAdam,
    "lbfgs": ScimbaLBFGS,
    "ssbfgs": ScimbaSSBFGS,
    "ssbroyden": ScimbaSSBroyden,
    "sgd": ScimbaSGD,
}

# Defaults for the optimizer-switching criteria read by
# OptimizerData.test_activate_next_optimizer.
SWITCH_AT_EPOCH_DEFAULT = False  # do not switch at a fixed epoch by default
SWITCH_AT_EPOCH_N_DEFAULT = 5000  # epoch used when "switch_at_epoch" is True
SWITCH_AT_EPOCH_RATIO_DEFAULT = 0.7  # switch once epoch/epochs exceeds this
SWITCH_AT_PLATEAU_DEFAULT = False  # do not switch on a loss plateau by default
SWITCH_AT_PLATEAU_N1_N2_DEFAULT = (50, 10)  # (window offset, window size) for plateau test
SWITCH_AT_PLATEAU_RATIO_DEFAULT = 500.0  # plateau test active once loss < init_loss / ratio


class OptimizerData:
    r"""A class to manage multiple optimizers and their activation criteria.

    Args:
        *args: Variable length argument list of optimizer configurations.
            Input dictionary must have one of the form:

            { "class": value (a subclass of AbstractScimbaOptimizer), keys: value }

            { "name": value (either "adam" or "lbfgs"), keys: value }

            where pairs keys value can be:

            "optimizer_args": a dictionary of arguments for the optimizer,

            "scheduler": a subclass of torch.optim.lr_scheduler.LRScheduler,

            "scheduler_args": a dictionary of arguments for the scheduler

            "switch_at_epoch": a bool or an int, default False;
            if True then default value 5000 is used

            "switch_at_epoch_ratio": a bool or a float, default 0.7;
            if True then default value is used

            "switch_at_plateau": a bool or a tuple of two int, default False;
            if True then default (50, 10) is used

            "switch_at_plateau_ratio": a float r, default value 500.;
            triggers the plateau tests if current_loss < init_loss/r

        **kwargs: Arbitrary keyword arguments.

    Examples:
        >>> from scimba_torch.optimizers.scimba_optimizers\
        ...     import ScimbaMomentum
        >>> opt_1 = {\
        ...     "name": "adam",\
        ...     "optimizer_args": {"lr": 1e-3, "betas": (0.9, 0.999)},\
        ... }
        >>> opt_2 = {"class": ScimbaMomentum, "switch_at_epoch": 500}
        >>> opt_3 = {\
        ...     "name": "lbfgs",\
        ...     "switch_at_epoch_ratio": 0.7,\
        ...     "switch_at_plateau": [500, 20],\
        ...     "switch_at_plateau_ratio": 3.0,\
        ... }
        >>> optimizers = OptimizerData(opt_1, opt_2, opt_3)
    """

    activated_optimizer: list[AbstractScimbaOptimizer]
    """A list containing the current optimizer; empty if none."""

    def __init__(self, *args: OPT_DICT_TYPE, **kwargs):
        self.activated_optimizer = []  # current optimizer (at most one element)
        self.optimizers: list[OPT_DICT_TYPE] = []  # validated configurations
        self.next_optimizer: int = 0  # index of the next optimizer to activate

        # in case of lr modified by linesearch, need to remember initial lr
        # self.default_lr: float = kwargs.get("default_lr", 1e-2)

        for opt in args:
            self._check_dictionary_and_append_to_optimizers_list(opt)

        if len(self.optimizers) == 0:
            # no configuration given: fall back to adam followed by lbfgs
            self.optimizers.append({"name": "adam"})
            self.optimizers.append({"name": "lbfgs"})

    def _check_dictionary_and_append_to_optimizers_list(self, opt: OPT_DICT_TYPE):
        """Checks the input configuration and appends it to the optimizers list.

        Args:
            opt: Optimizer configuration dictionary.

        Raises:
            KeyError: If the dictionary does not contain exactly one of
                "name" or "class".
            ValueError: If the dictionary contains invalid values for any of
                the keys.
        """
        # Exactly one of "name" / "class" must be present (xor).
        if not operator.xor(("name" in opt), ("class" in opt)):
            raise KeyError(
                f"Cannot create a {self.__class__.__name__} from dict: {opt}"
            )

        # Check types and values of entries.
        for key in opt:
            if key == "name":
                opt_name = opt["name"]
                if opt_name not in SCIMBA_OPTIMIZERS:
                    raise ValueError(
                        f"Cannot create a {self.__class__.__name__} "
                        f"from name: {opt_name}"
                    )
            elif key == "class":
                opt_class = opt["class"]
                if not issubclass(opt_class, ScimbaCustomOptomizer):
                    raise ValueError(
                        f"Cannot create a {self.__class__.__name__} "
                        f"from class: {opt_class}; it must be a subclass "
                        f"of ScimbaCustomOptomizer"
                    )
            elif key == "optimizer_args":
                opt_args = opt[key]
                if not isinstance(opt_args, dict):
                    raise ValueError(
                        f"Cannot create a {self.__class__.__name__} from given {key}"
                    )
            elif key == "scheduler":
                opt_args = opt[key]
                if not issubclass(opt_args, torch.optim.lr_scheduler.LRScheduler):
                    raise ValueError(
                        f"Cannot create a {self.__class__.__name__} "
                        f"from given {key}; it must be a subclass of "
                        f"torch.optim.lr_scheduler.LRScheduler"
                    )
            elif key == "scheduler_args":
                opt_args = opt[key]
                if not isinstance(opt_args, dict):
                    raise ValueError(
                        f"Cannot create a {self.__class__.__name__} from given {key}"
                    )
            elif key == "switch_at_epoch":
                opt_args = opt[key]
                if not isinstance(opt_args, (bool, int)):
                    raise ValueError(
                        f"In {self.__class__.__name__}.init: {key} "
                        f"must be either a bool or an int"
                    )
            elif key == "switch_at_epoch_ratio":
                opt_args = opt[key]
                if not isinstance(opt_args, (bool, float)):
                    raise ValueError(
                        f"In {self.__class__.__name__}.init: {key} "
                        f"must be either a bool or a float"
                    )
            elif key == "switch_at_plateau":
                opt_args = opt[key]
                # Accept tuples as well as lists: the documented form is a
                # tuple of two ints and the default itself is a tuple.
                if not isinstance(opt_args, (bool, list, tuple)):
                    raise ValueError(
                        f"In {self.__class__.__name__}.init: {key} "
                        f"must be either a bool or tuple of two ints"
                    )
            elif key == "switch_at_plateau_ratio":
                opt_args = opt[key]
                if not isinstance(opt_args, float):
                    raise ValueError(
                        f"In {self.__class__.__name__}.init: {key} must be a float"
                    )
            else:
                warnings.warn(
                    f"In {self.__class__.__name__}.init: unrecognized option {key}",
                    UserWarning,
                )

        # Copy so later caller-side mutation cannot corrupt the stored config.
        self.optimizers.append(opt.copy())

    def step(self, closure: Callable[[], float]) -> None:
        """Performs an optimization step using the currently activated optimizer.

        Args:
            closure: A closure that reevaluates the model and returns the loss.
        """
        if len(self.activated_optimizer) == 1:
            self.activated_optimizer[0].optimizer_step(closure)

    def set_lr(self, lr: float) -> None:
        """Set learning rate of activated optimizer.

        Args:
            lr: The new learning rate.
        """
        if len(self.activated_optimizer) == 1:
            for group in self.activated_optimizer[0].param_groups:
                group["lr"] = lr

    def zero_grad(self) -> None:
        """Zeros the gradients of the currently activated optimizer."""
        self.activated_optimizer[0].zero_grad()

    def test_activate_next_optimizer(
        self,
        loss_history: list[float],
        loss_value: float,
        init_loss: float,
        epoch: int,
        epochs: int,
    ) -> bool:
        """Tests whether the next opt. should be activated based on the given criteria.

        Args:
            loss_history: History of loss values.
            loss_value: Current loss value.
            init_loss: Initial loss value.
            epoch: Current epoch.
            epochs: Total number of epochs.

        Returns:
            True if the next optimizer should be activated, False otherwise.
        """
        if self.next_optimizer >= len(self.optimizers):
            return False

        next_opt = self.optimizers[self.next_optimizer]

        # "switch_at_epoch": disabled by default; an int value both enables
        # the criterion and sets the epoch threshold.
        switch_if_epoch = next_opt.get("switch_at_epoch", SWITCH_AT_EPOCH_DEFAULT)
        if isinstance(switch_if_epoch, bool):
            n = SWITCH_AT_EPOCH_N_DEFAULT
        else:
            n = switch_if_epoch
            switch_if_epoch = True

        # "switch_at_epoch_ratio": enabled by default (the default value is a
        # float, so the bool branch below does not disable it); a bool value
        # enables/disables it with the default ratio.
        switch_if_epoch_ratio = True
        switch_at_epoch_ratio = next_opt.get(
            "switch_at_epoch_ratio", SWITCH_AT_EPOCH_RATIO_DEFAULT
        )
        if isinstance(switch_at_epoch_ratio, bool):
            switch_if_epoch_ratio = switch_at_epoch_ratio
            switch_at_epoch_ratio = SWITCH_AT_EPOCH_RATIO_DEFAULT

        # "switch_at_plateau": disabled by default; a (n1, n2) pair both
        # enables the criterion and sets the comparison windows.
        switch_if_plateau = next_opt.get("switch_at_plateau", SWITCH_AT_PLATEAU_DEFAULT)
        if isinstance(switch_if_plateau, bool):
            n1, n2 = SWITCH_AT_PLATEAU_N1_N2_DEFAULT
        else:
            n1, n2 = switch_if_plateau
            switch_if_plateau = True

        switch_at_plateau_ratio = next_opt.get(
            "switch_at_plateau_ratio", SWITCH_AT_PLATEAU_RATIO_DEFAULT
        )

        if (switch_if_epoch) and (epoch >= n):
            return True
        if (switch_if_epoch_ratio) and (epoch / epochs > switch_at_epoch_ratio):
            return True
        if switch_if_plateau:
            # Plateau: loss has dropped enough w.r.t. init_loss AND the sum of
            # recent losses is larger than the sum of earlier losses.
            # NOTE(review): the windows have n2 - 1 and n2 elements
            # respectively ([-n2:-1] vs [-n1:-n1+n2]) — confirm this
            # asymmetry is intended.
            if (loss_value < (init_loss / switch_at_plateau_ratio)) and (
                sum(loss_history[-n2:-1]) - sum(loss_history[-n1 : -n1 + n2]) > 0
            ):
                return True
        return False

    def activate_next_optimizer(
        self, parameters: ParamsT, verbose: bool = False
    ) -> None:
        """Activates the next optimizer in the list.

        Args:
            parameters: Parameters to be optimized.
            verbose: whether to print activation message or not.
        """
        if self.next_optimizer >= len(self.optimizers):
            warnings.warn(
                "trying to overflow list of optimizers - nothing will happen",
                RuntimeWarning,
            )
            return

        opt = self.optimizers[self.next_optimizer]
        if "name" in opt:
            opt_name = opt["name"]
            opt_class = SCIMBA_OPTIMIZERS[opt_name]
        else:
            opt_class = opt["class"]

        # Prepare the arguments; dictionaries are copied so the optimizer
        # cannot mutate the stored configuration.
        arguments = {}
        if "optimizer_args" in opt:
            arguments["optimizer_args"] = opt["optimizer_args"].copy()
        if "scheduler" in opt:
            arguments["scheduler"] = opt["scheduler"]
        if "scheduler_args" in opt:
            arguments["scheduler_args"] = opt["scheduler_args"].copy()

        # Replace the current optimizer if one is active, otherwise append.
        if len(self.activated_optimizer) == 1:
            self.activated_optimizer[0] = opt_class(parameters, **arguments)
        else:
            self.activated_optimizer.append(opt_class(parameters, **arguments))
        if verbose:
            print(f"activating optimizer {opt_class.__name__}")
        self.next_optimizer += 1

    def activate_first_optimizer(
        self, parameters: ParamsT, verbose: bool = False
    ) -> None:
        """Activates the first optimizer in the list.

        Args:
            parameters: Parameters to be optimized.
            verbose: whether to print activation message or not.
        """
        # case where optimizers have already been activated: do nothing
        if len(self.activated_optimizer):
            return

        # try first to activate all the optimizers to report
        # possible errors at the begining of the training!
        next_optimizer_save = (
            self.next_optimizer
        )  # in case where self was loaded from a file
        while self.next_optimizer < len(self.optimizers):
            self.activate_next_optimizer(parameters, verbose=False)
        self.activated_optimizer = []
        self.next_optimizer = next_optimizer_save

        # activate the first optimizer
        self.activate_next_optimizer(parameters, verbose)

    def test_and_activate_next_optimizer(
        self,
        parameters: ParamsT,
        loss_history: list[float],
        loss_value: float,
        init_loss: float,
        epoch: int,
        epochs: int,
    ) -> None:
        """Tests whether next optimizer should be activated; activates it.

        Args:
            parameters: Parameters to be optimized.
            loss_history: History of loss values.
            loss_value: Current loss value.
            init_loss: Initial loss value.
            epoch: Current epoch.
            epochs: Total number of epochs.
        """
        if self.test_activate_next_optimizer(
            loss_history, loss_value, init_loss, epoch, epochs
        ):
            self.activate_next_optimizer(parameters)

    def get_opt_gradients(self) -> torch.Tensor:
        """Gets the gradients of the currently activated optimizer.

        Returns:
            Flattened tensor of gradients (one column, stacked over the
            parameters of the first parameter group).
        """
        grads = torch.tensor([])
        for p in self.activated_optimizer[0].param_groups[0]["params"]:
            if p.grad is not None:
                grads = torch.cat((grads, p.grad.flatten()[:, None]), 0)
        return grads

    def update_best_optimizer(self) -> None:
        """Updates the best state of the currently activated optimizer."""
        self.activated_optimizer[0].update_best_optimizer()

    def dict_for_save(self) -> dict:
        """Returns a dictionary containing the best state of the current optimizer.

        Returns:
            dictionary containing the best state of the optimizer.
        """
        res = self.activated_optimizer[0].dict_for_save()
        res["next_optimizer"] = self.next_optimizer
        return res

    def load_from_dict(self, parameters: ParamsT, checkpoint: dict) -> None:
        """Loads the optimizer and scheduler states from a checkpoint.

        Args:
            parameters: Parameters to be optimized.
            checkpoint: dictionary containing the optimizer and scheduler states.

        Raises:
            ValueError: when there is no active optimizer to load in.
        """
        # Step back one slot so activate_next_optimizer re-creates the
        # optimizer that was active when the checkpoint was written.
        self.next_optimizer = checkpoint["next_optimizer"]
        self.next_optimizer = self.next_optimizer - 1
        self.activate_next_optimizer(parameters)
        if len(self.activated_optimizer) < 1:
            raise ValueError("there is no active optimizer to load in!")
        self.activated_optimizer[0].load(checkpoint)

    def __str__(self) -> str:
        """Returns a string representation of the optimizers.

        Returns:
            str: String representation of the optimizers.
        """
        ret = "optimizers: ["
        for opt in self.optimizers:
            try:
                opt_name = opt["name"]
                ret = ret + opt_name + ", "
            except KeyError:
                opt_class = opt["class"]
                ret = ret + opt_class.__name__ + ", "
        ret = ret + "]"
        return ret