"""A module defining scimba optimizers."""
import copy
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable
import torch
from torch.optim.optimizer import ParamsT
class NoScheduler:
"""A placeholder class to indicate the absence of a scheduler."""
class AbstractScimbaOptimizer(torch.optim.Optimizer, ABC):
"""Abstract base class for Scimba optimizers with optional learning rate scheduler.
Args:
params: Iterable of parameters to optimize or dicts defining parameter groups.
optimizer_args: Additional arguments for the optimizer. Defaults to {}.
scheduler: Learning rate scheduler class. Defaults to NoScheduler.
scheduler_args: Additional arguments for the scheduler. Defaults to {}.
**kwargs: Arbitrary keyword arguments.
Raises:
        ValueError: if scheduler is not a subclass of
            torch.optim.lr_scheduler.LRScheduler.
Attributes:
scheduler_exists: Flag indicating if a scheduler is set.
scheduler: list containing the scheduler.
best_optimizer: dictionary containing the best state of the optimizer.
best_scheduler: list containing the best state of the scheduler.
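
    Example:
        A minimal construction sketch using the concrete subclass ScimbaAdam
        defined below; ``model`` is a placeholder ``torch.nn.Module``::

            optimizer = ScimbaAdam(
                model.parameters(),
                optimizer_args={"lr": 1e-3},
                scheduler=torch.optim.lr_scheduler.StepLR,
                scheduler_args={"step_size": 50, "gamma": 0.9},
            )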
"""
scheduler_exists: bool
scheduler: list[torch.optim.lr_scheduler.LRScheduler]
best_optimizer: dict
best_scheduler: list[dict]
def __init__(
self,
params: ParamsT,
optimizer_args: dict[str, Any] = {},
scheduler: type = NoScheduler,
scheduler_args: dict[str, Any] = {},
**kwargs,
) -> None:
# initializes torch.optim.Optimizer part from params and optimizer_args
super().__init__(params, **optimizer_args)
# initializes scheduler if given
        self.scheduler_exists = scheduler is not NoScheduler
if self.scheduler_exists:
if not issubclass(scheduler, torch.optim.lr_scheduler.LRScheduler):
                raise ValueError(
                    f"Cannot create a {self.__class__.__name__} from scheduler "
                    f"class {scheduler}; it must be a subclass of "
                    "torch.optim.lr_scheduler.LRScheduler"
                )
self.scheduler = [scheduler(self, **scheduler_args)]
else:
self.scheduler = []
self.best_optimizer = copy.deepcopy(self.state_dict())
if self.scheduler_exists:
self.best_scheduler = [copy.deepcopy(self.scheduler[0].state_dict())]
else:
self.best_scheduler = []
def optimizer_step(self, closure: Callable[[], float]) -> None:
"""Performs an optimization step and updates the scheduler if it exists.
Args:
closure: A closure that reevaluates the model and returns the loss.
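
        Example:
            A typical closure; ``optimizer``, ``model``, ``inputs``, ``targets``
            and ``loss_fn`` are placeholder names::

                def closure() -> float:
                    optimizer.zero_grad()
                    loss = loss_fn(model(inputs), targets)
                    loss.backward()
                    return loss.item()

                optimizer.optimizer_step(closure)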
"""
self.inner_step(closure)
if self.scheduler_exists:
self.scheduler[0].step()
@abstractmethod
def inner_step(self, closure: Callable[[], float]) -> None: # pragma: no cover
"""Abstract method for performing the inner optimization step.
Args:
closure: A closure that reevaluates the model and returns the loss.
"""
pass
def update_best_optimizer(self) -> None:
"""Updates the best optimizer state."""
self.best_optimizer = copy.deepcopy(self.state_dict())
if self.scheduler_exists:
self.best_scheduler[0] = copy.deepcopy(self.scheduler[0].state_dict())
def dict_for_save(self) -> dict:
"""Returns a dictionary containing the best optimizer and scheduler states.
Returns:
dict: dictionary containing the best optimizer and scheduler states.
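
        Example:
            A checkpointing sketch; the file name is arbitrary::

                torch.save(optimizer.dict_for_save(), "checkpoint.pt")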
"""
res = {"optimizer_state_dict": self.best_optimizer}
if self.scheduler_exists:
res["scheduler_state_dict"] = self.best_scheduler[0]
return res
def load(self, checkpoint: dict) -> None:
"""Loads the optimizer and scheduler states from a checkpoint.
Args:
checkpoint: dictionary containing the optimizer and scheduler states.
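
        Example:
            A sketch reloading a checkpoint written from ``dict_for_save``;
            the file name is arbitrary::

                optimizer.load(torch.load("checkpoint.pt"))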
"""
try:
self.load_state_dict(checkpoint["optimizer_state_dict"])
if self.scheduler_exists:
self.scheduler[0].load_state_dict(checkpoint["scheduler_state_dict"])
except KeyError:
print("optimizer was not loaded from file: training needed")
class ScimbaAdam(AbstractScimbaOptimizer, torch.optim.Adam):
"""Scimba wrapper for Adam optimizer with optional learning rate scheduler.
Args:
params: Iterable of parameters to optimize or dicts defining parameter groups.
optimizer_args: Additional arguments for the Adam optimizer. Defaults to {}.
scheduler: Learning rate scheduler class.
Defaults to torch.optim.lr_scheduler.StepLR.
scheduler_args: Additional arguments for the scheduler. Defaults to {}.
**kwargs: Arbitrary keyword arguments.
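
    Example:
        A construction sketch with explicit Adam and scheduler arguments;
        ``model`` is a placeholder::

            optimizer = ScimbaAdam(
                model.parameters(),
                optimizer_args={"lr": 1e-3, "betas": (0.9, 0.999)},
                scheduler_args={"step_size": 20, "gamma": 0.99},
            )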
"""
def __init__(
self,
params: ParamsT,
optimizer_args: dict[str, Any] = {},
scheduler: type = torch.optim.lr_scheduler.StepLR,
scheduler_args: dict[str, Any] = {},
**kwargs,
) -> None:
if scheduler == torch.optim.lr_scheduler.StepLR:
scheduler_args.setdefault("gamma", 0.99)
scheduler_args.setdefault("step_size", 20)
super().__init__(params, optimizer_args, scheduler, scheduler_args)
def inner_step(self, closure: Callable[[], float]) -> None:
"""Performs the inner optimization step for ScimbaAdam.
Args:
closure: A closure that reevaluates the model and returns the loss.
"""
closure()
self.step()
class ScimbaSGD(AbstractScimbaOptimizer, torch.optim.SGD):
"""Scimba wrapper for SGD optimizer with optional learning rate scheduler.
Args:
params: Iterable of parameters to optimize or dicts defining parameter groups.
        optimizer_args: Additional arguments for the SGD optimizer. Defaults to {}.
scheduler: Learning rate scheduler class.
Defaults to torch.optim.lr_scheduler.StepLR.
scheduler_args: Additional arguments for the scheduler. Defaults to {}.
**kwargs: Arbitrary keyword arguments.
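
    Example:
        A construction sketch passing SGD-specific arguments; ``model`` is a
        placeholder::

            optimizer = ScimbaSGD(
                model.parameters(),
                optimizer_args={"lr": 1e-2, "momentum": 0.9},
            )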
"""
def __init__(
self,
params: ParamsT,
optimizer_args: dict[str, Any] = {},
scheduler: type = torch.optim.lr_scheduler.StepLR,
scheduler_args: dict[str, Any] = {},
**kwargs,
) -> None:
if scheduler == torch.optim.lr_scheduler.StepLR:
scheduler_args.setdefault("gamma", 0.99)
scheduler_args.setdefault("step_size", 20)
super().__init__(params, optimizer_args, scheduler, scheduler_args)
def inner_step(self, closure: Callable[[], float]) -> None:
"""Performs the inner optimization step for ScimbaAdam.
Args:
closure: A closure that reevaluates the model and returns the loss.
"""
closure()
self.step()
class ScimbaLBFGS(AbstractScimbaOptimizer, torch.optim.LBFGS):
"""Scimba wrapper for LBFGS optimizer with optional learning rate scheduler.
Args:
params: Iterable of parameters to optimize or dicts defining parameter groups.
optimizer_args: Additional arguments for the LBFGS optimizer. Defaults to {}.
**kwargs: Arbitrary keyword arguments.
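
    Example:
        A construction sketch; LBFGS re-evaluates the closure internally, so the
        closure is passed straight to ``step``. ``model`` and ``closure`` are
        placeholders (see optimizer_step for a closure sketch)::

            optimizer = ScimbaLBFGS(
                model.parameters(),
                optimizer_args={"max_iter": 10, "history_size": 20},
            )
            optimizer.optimizer_step(closure)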
"""
def __init__(self, params: ParamsT, optimizer_args: dict[str, Any] = {}, **kwargs):
optimizer_args.setdefault("history_size", 15)
optimizer_args.setdefault("max_iter", 5)
optimizer_args.setdefault("line_search_fn", "strong_wolfe")
super().__init__(params, optimizer_args)
def inner_step(self, closure: Callable[[], float]) -> None:
"""Performs the inner optimization step for ScimbaLBFGS.
Args:
closure: A closure that reevaluates the model and returns the loss.
"""
self.step(closure)
class ScimbaCustomOptomizer(AbstractScimbaOptimizer, ABC):
"""An abstract class of which user defined optimizer must inherit."""
@abstractmethod
def step(self, closure: Callable[[], float]):
"""To be implemented in subclasses: applies one step of optimizer.
Args:
closure: A closure that reevaluates the model and returns the loss.
"""
class ScimbaMomentum(ScimbaCustomOptomizer):
"""Custom Momentum optimizer with scheduler.
    Serves as an example of a custom optimizer built on AbstractScimbaOptimizer.
Args:
params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: learning rate. Defaults to 1e-3.
        momentum: momentum factor. Defaults to 0.0.
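
    The update rule applied in ``step`` to each parameter ``p``, with momentum
    buffer ``m``, learning rate ``lr`` and momentum factor ``momentum``, is::

        m <- momentum * m - lr * grad(p)
        p <- p + m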
"""
def __init__(self, params: ParamsT, lr: float = 1e-3, momentum: float = 0.0):
super().__init__(
params,
optimizer_args={"defaults": {"lr": lr}},
scheduler=torch.optim.lr_scheduler.StepLR,
scheduler_args={"gamma": 0.99, "step_size": 10},
)
self.momentum = momentum
self.state = defaultdict(dict)
for group in self.param_groups:
for p in group["params"]:
self.state[p] = dict(mom=torch.zeros_like(p.data))
# this step method must be implemented in order for the scheduler to work properly
def step(self, closure: Callable[[], float] | None = None):
"""Re-implements the step method.
Args:
closure: A closure that reevaluates the model and returns the loss.
"""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # update the per-parameter momentum buffer, then apply the step
                mom = self.state[p]["mom"]
                mom = self.momentum * mom - group["lr"] * p.grad.data
                self.state[p]["mom"] = mom
                p.data += mom
def inner_step(self, closure: Callable[[], float]) -> None:
"""The inner step method.
Args:
closure: A closure that reevaluates the model and returns the loss.
"""
closure()
self.step()
if __name__ == "__main__": # pragma: no cover
import math
class SimpleNN(torch.nn.Module):
"""For test."""
def __init__(self):
            super().__init__()
self.fc1 = torch.nn.Linear(10, 1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""For test.
Args:
x: For test.
Returns:
For test.
"""
return self.fc1(x)
net = SimpleNN()
class DummyScheduler:
"""For test."""
try:
opt_test = ScimbaAdam(list(net.parameters()), scheduler=DummyScheduler)
except ValueError as error:
print(error)
# Batch of 10000 samples, each of size 10
input_tensor = torch.randn(10000, 10)
# learn sum function
# Target tensor with batch size 10000
target_tensor = torch.sum(input_tensor, dim=1)[:, None]
loss = [torch.tensor(float("+inf"))]
best_loss = float("+inf")
best_net = copy.deepcopy(net.state_dict())
opt = ScimbaAdam(list(net.parameters()))
# opt = ScimbaMomentum(list(net.parameters()))
loss_func = torch.nn.MSELoss()
# opt.zero_grad()
def closure() -> float:
"""For test.
Returns:
For test.
"""
opt.zero_grad()
# Forward pass
output_tensor = net(input_tensor)
loss[0] = loss_func(output_tensor, target_tensor)
loss[0].backward(retain_graph=True)
return loss[0].item()
# perform one step
opt.optimizer_step(closure)
epochs = 1000
for epoch in range(epochs):
opt.optimizer_step(closure)
if math.isinf(loss[0].item()) or math.isnan(loss[0].item()):
loss[0] = torch.tensor(best_loss)
net.load_state_dict(best_net)
if loss[0].item() < best_loss:
best_loss = loss[0].item()
best_net = copy.deepcopy(net.state_dict())
opt.update_best_optimizer()
if epoch % 100 == 0:
print("epoch: ", epoch, "loss: ", loss[0].item())
print("lr: ", opt.param_groups[0]["lr"])
net.load_state_dict(best_net)
closure()
print("loss after training: ", loss[0].item())
print("\n")
# print("dict_for_save: ", opt.dict_for_save())
print("state_dict : ", opt.state_dict())
print("\n")
optt = ScimbaAdam(list(net.parameters()))
# print("dict_for_save: ", optt.dict_for_save())
print("state_dict : ", optt.state_dict())
print("\n")
optt.load(opt.dict_for_save())
# print("dict_for_save: ", optt.dict_for_save())
print("state_dict : ", optt.state_dict())
print("\n")
# print( "==", optt.state_dict() == opt.state_dict())
opt2 = ScimbaLBFGS(list(net.parameters()))
def closure() -> float:
"""For test.
Returns:
For test.
"""
opt2.zero_grad()
# Forward pass
output_tensor = net(input_tensor)
loss[0] = loss_func(output_tensor, target_tensor)
loss[0].backward(retain_graph=True)
return loss[0].item()
epochs = 1000
for epoch in range(epochs):
opt2.optimizer_step(closure)
if math.isinf(loss[0].item()) or math.isnan(loss[0].item()):
loss[0] = torch.tensor(best_loss)
net.load_state_dict(best_net)
if loss[0].item() < best_loss:
best_loss = loss[0].item()
best_net = copy.deepcopy(net.state_dict())
opt2.update_best_optimizer()
if epoch % 100 == 0:
print("epoch: ", epoch, "loss: ", loss[0].item())
net.load_state_dict(best_net)
closure()
print("loss after training: ", loss[0].item())
print("net( torch.ones( 10 ) ) : ", net(torch.ones(10)))