scimba_torch.optimizers.line_search
Line search functions.
Functions
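| Function | Description |
| backtracking_armijo_line_search | Line search algorithm based on the Armijo condition. |
| backtracking_armijo_line_search_with_loss_theta_grad_loss_theta | Line search algorithm based on the Armijo condition. |
| logarithmic_grid_line_search | Line search algorithm based on a logarithmic grid. |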
- logarithmic_grid_line_search(loss, theta, dsearch, m=10, interval=[0.0, 1.0], log_basis=2.0, **kwargs)
Line search algorithm based on a logarithmic grid.
- Parameters:
  - loss (Callable[[Tensor], Tensor]) – The loss function.
  - theta (Tensor) – The current parameters of the loss.
  - dsearch (Tensor) – The search direction.
  - m (int) – The number of points in the logarithmic grid.
  - interval (list[float]) – The interval [start, end] covered by the logarithmic grid.
  - log_basis (float) – The base used to generate the logarithmic grid.
  - **kwargs – Arbitrary keyword arguments.
- Return type:
  Tensor
- Returns:
  A step size eta approximately minimizing the loss along the search direction dsearch from theta.
- Raises:
  ValueError – If log_basis <= 0.
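This page does not say how the grid is built. The following is a minimal sketch of the idea, assuming (hypothetically) that the candidate step sizes are lo + (hi - lo) * log_basis**(-k) for k = 0, ..., m - 1 with [lo, hi] = interval, and that the candidate with the smallest loss along dsearch is returned. The name logarithmic_grid_line_search_sketch is illustrative only; it is not the scimba_torch implementation.

```python
import torch
from typing import Callable


def logarithmic_grid_line_search_sketch(
    loss: Callable[[torch.Tensor], torch.Tensor],
    theta: torch.Tensor,
    dsearch: torch.Tensor,
    m: int = 10,
    interval: tuple[float, float] = (0.0, 1.0),
    log_basis: float = 2.0,
) -> torch.Tensor:
    """Return the grid step eta with the smallest loss(theta + eta * dsearch)."""
    if log_basis <= 0:
        raise ValueError("log_basis must be positive")
    lo, hi = interval
    # Hypothetical grid: eta_k = lo + (hi - lo) * log_basis**(-k), k = 0..m-1.
    # For log_basis > 1 the points shrink geometrically from hi toward lo.
    exponents = torch.arange(m, dtype=theta.dtype)
    etas = lo + (hi - lo) * log_basis ** (-exponents)
    # Evaluate the loss at every candidate point along the search direction.
    losses = torch.stack([loss(theta + eta * dsearch) for eta in etas])
    return etas[torch.argmin(losses)]


# Usage: along dsearch = (1, 1) this loss equals 2 * (eta - 0.5)**2.
theta = torch.zeros(2)
dsearch = torch.ones(2)
eta = logarithmic_grid_line_search_sketch(
    lambda t: ((t - 0.5) ** 2).sum(), theta, dsearch
)
print(eta.item())  # 0.5
```

With log_basis > 1 such a grid is dense near lo, so small steps are probed finely while the full step hi is still tried; each call costs m loss evaluations and needs no gradient.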
- backtracking_armijo_line_search_with_loss_theta_grad_loss_theta(loss, theta, loss_theta, grad_loss_theta, dsearch, alpha=0.01, beta=0.5, n_step_max=10, **kwargs)
Line search algorithm based on the Armijo condition, taking the loss value and its gradient at theta as precomputed inputs.
- Parameters:
  - loss (Callable[[Tensor], Tensor]) – The loss function.
  - theta (Tensor) – The current parameters of the loss.
  - loss_theta (Tensor) – The loss at theta.
  - grad_loss_theta (Tensor) – The gradient of the loss at theta.
  - dsearch (Tensor) – The search direction.
  - alpha (float) – The Armijo condition parameter.
  - beta (float) – The Armijo condition parameter.
  - n_step_max (int) – The maximum number of steps in the backtracking algorithm.
  - **kwargs – Arbitrary keyword arguments.
- Return type:
  Tensor
- Returns:
  A step size eta approximately minimizing the loss along the search direction dsearch from theta.
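The docstring does not spell out the update rule. Below is a minimal sketch of conventional backtracking under the Armijo condition, assuming (not confirmed by this page) that alpha is the sufficient-decrease constant, beta is the factor by which the step shrinks after each failed test, and that the search starts at eta = 1. The name armijo_backtracking_sketch and the fallback after n_step_max failed tests are hypothetical; this is not scimba_torch's implementation.

```python
import torch
from typing import Callable


def armijo_backtracking_sketch(
    loss: Callable[[torch.Tensor], torch.Tensor],
    theta: torch.Tensor,
    loss_theta: torch.Tensor,
    grad_loss_theta: torch.Tensor,
    dsearch: torch.Tensor,
    alpha: float = 0.01,
    beta: float = 0.5,
    n_step_max: int = 10,
) -> torch.Tensor:
    """Shrink eta from 1 until the Armijo sufficient-decrease test holds."""
    # Directional derivative of the loss at theta along dsearch (reuses grad_loss_theta).
    slope = (grad_loss_theta * dsearch).sum()
    eta = torch.tensor(1.0, dtype=theta.dtype)  # assumed initial trial step
    for _ in range(n_step_max):
        # Armijo test: loss(theta + eta * d) <= loss(theta) + alpha * eta * slope.
        if loss(theta + eta * dsearch) <= loss_theta + alpha * eta * slope:
            return eta
        eta = beta * eta  # backtrack: shrink the step by the factor beta
    # Assumed fallback when the test never holds: return the last shrunk step.
    return eta


def squared_norm(t: torch.Tensor) -> torch.Tensor:
    return (t ** 2).sum()


# Usage on ||theta||^2 with the steepest-descent direction -grad.
theta = torch.tensor([1.0])
grad = 2 * theta  # gradient of ||theta||^2 at theta
eta = armijo_backtracking_sketch(squared_norm, theta, squared_norm(theta), grad, -grad)
print(eta.item())  # 0.5
```

Presumably the point of this signature is that a caller which already computed the loss and gradient at theta during the optimizer step does not pay for them again; only the trial points theta + eta * dsearch are evaluated inside the loop.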
- backtracking_armijo_line_search(loss, grad_loss, theta, dsearch, alpha=0.1, beta=0.5, n_step_max=10, **kwargs)
Line search algorithm based on the Armijo condition, taking the gradient of the loss as a function grad_loss.
- Parameters:
  - loss (Callable[[Tensor], Tensor]) – The loss function.
  - grad_loss (Callable[[Tensor], Tensor]) – The gradient function of the loss function.
  - theta (Tensor) – The current parameters of the loss.
  - dsearch (Tensor) – The search direction.
  - alpha (float) – The Armijo condition parameter.
  - beta (float) – The Armijo condition parameter.
  - n_step_max (int) – The maximum number of steps in the backtracking algorithm.
  - **kwargs – Arbitrary keyword arguments.
- Return type:
  Tensor
- Returns:
  A step size eta approximately minimizing the loss along the search direction dsearch from theta.
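Because this variant receives the gradient as a function rather than as precomputed values, it presumably evaluates loss(theta) and grad_loss(theta) itself before backtracking. The sketch below shows that flow together with one way to obtain grad_loss from loss via torch.autograd.grad. Both make_grad_loss and armijo_with_grad_fn_sketch are hypothetical names, not part of scimba_torch; the initial step eta = 1, the roles of alpha (sufficient-decrease constant) and beta (shrink factor), and the behaviour after n_step_max failed tests are assumptions.

```python
import torch
from typing import Callable

Fn = Callable[[torch.Tensor], torch.Tensor]


def make_grad_loss(loss: Fn) -> Fn:
    """Hypothetical helper: turn a scalar loss into a gradient function via autograd."""

    def grad_loss(theta: torch.Tensor) -> torch.Tensor:
        t = theta.detach().requires_grad_(True)  # fresh leaf tensor for autograd
        (g,) = torch.autograd.grad(loss(t), t)
        return g

    return grad_loss


def armijo_with_grad_fn_sketch(
    loss: Fn,
    grad_loss: Fn,
    theta: torch.Tensor,
    dsearch: torch.Tensor,
    alpha: float = 0.1,
    beta: float = 0.5,
    n_step_max: int = 10,
) -> torch.Tensor:
    """Backtracking Armijo search that evaluates loss and gradient at theta itself."""
    loss_theta = loss(theta)
    slope = (grad_loss(theta) * dsearch).sum()  # directional derivative along dsearch
    eta = 1.0  # assumed initial trial step
    for _ in range(n_step_max):
        if loss(theta + eta * dsearch) <= loss_theta + alpha * eta * slope:
            break  # sufficient decrease reached, keep this eta
        eta *= beta  # shrink the step and retry
    return torch.tensor(eta, dtype=theta.dtype)


def squared_norm(t: torch.Tensor) -> torch.Tensor:
    return (t ** 2).sum()


grad_fn = make_grad_loss(squared_norm)
theta = torch.tensor([1.0])
eta = armijo_with_grad_fn_sketch(squared_norm, grad_fn, theta, -grad_fn(theta))
print(eta.item())  # 0.5
```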