scimba_torch.optimizers.losses

A module to handle losses.

Examples: Loss usage

import torch

from scimba_torch.optimizers.losses import GenericLosses
from scimba_torch.optimizers.optimizers_data import OptimizerData

def loss_func(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return torch.sum(torch.abs(b - a))

# create a GenericLosses with only a residual
losses = GenericLosses()
print(losses)

# create a GenericLosses with a residual and a bc with custom function
losses2 = GenericLosses(
    [("residual", torch.nn.MSELoss(), 0.5), ("bc", loss_func, 0.8)]
)
print(losses2)

# eval losses
print(
    "eval losses2[residual]: ",
    losses2("residual", torch.ones((2, 3)) + 1e-6, torch.ones((2, 3))),
)
print(
    "eval losses2[bc]: ",
    losses2("bc", torch.ones((2, 3)) + 1e-6, torch.ones((2, 3))),
)

# eval and update losses
losses2.call_and_update(
    "residual",
    torch.ones((2, 3), requires_grad=True) + 1e-6,
    torch.ones((2, 3), requires_grad=True),
)
losses2.call_and_update("bc", torch.ones((2, 3)) + 1e-6, torch.ones((2, 3)))
print("losses2 residual: ", losses2.get_loss("residual"))
print("losses2 bc : ", losses2.get_loss("bc"))

# update histories
losses2.update_histories()
print("losses2 residual loss_history: ", losses2.get_history("residual"))
print("losses2 bc loss_history: ", losses2.get_history("bc"))

# test compute full loss
opt = OptimizerData()
epo = 0
loss = losses2.compute_full_loss(opt, epo)
print("losses2 full loss: ", loss, ", ", losses2.get_full_loss())

# errors
try:
    losses3 = GenericLosses([])
except ValueError as error:
    print(error)

try:
    losses2.call_and_update(
        "test",
        torch.ones((2, 3)) + 1e-6,
        torch.ones((2, 3)),
    )
except KeyError as error:
    print(error)

# test adaptive weights
class SimpleNN(torch.nn.Module):
    """A small network for testing."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)
        self.fc2 = torch.nn.Linear(10, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

opt_1 = {
    "name": "adam",
    "optimizer_args": {"lr": 0.01},
    "scheduler_args": {"gamma": 0.9, "step_size": 10},
}

net = SimpleNN()
optimizer_data = OptimizerData(
    opt_1,
    {
        "name": "lbfgs",
        "switch_at_epoch_ratio": 0.7,
        "switch_at_plateau": [50, 10]
    },
)
optimizer_data.activate_first_optimizer(list(net.parameters()))

# Create dummy input and target tensors
input_tensor = torch.randn(5, 10)  # Batch of 5 samples, each of size 10
target_tensor = torch.randn(5, 1)  # Target tensor with batch size 5

# Forward pass
output_tensor = net(input_tensor)

losses3 = GenericLosses(
    [("residual", torch.nn.MSELoss(), 0.5), ("bc", loss_func, 0.8)],
    adaptive_weights="annealing",
)

losses3.call_and_update("residual", output_tensor, target_tensor)
losses3.call_and_update("bc", output_tensor, target_tensor)

loss = losses3.compute_full_loss(optimizer_data, 10)

print("losses3 full loss: ", loss, ", ", losses3.get_full_loss())

print("losses3 residual coeff        : ", losses3.get_coeff("residual"))
print("losses3 residual coeff history: ", losses3.get_coeff_history("residual"))
print("losses3 bc       coeff        : ", losses3.get_coeff("bc"))
print("losses3 bc       coeff history: ", losses3.get_coeff_history("bc"))

losses3.update_histories()

input_tensor = torch.randn(5, 10)  # Batch of 5 samples, each of size 10
# Forward pass
output_tensor = net(input_tensor)
losses3.call_and_update("residual", output_tensor, target_tensor)
losses3.call_and_update("bc", output_tensor, target_tensor)
losses3.compute_full_loss(optimizer_data, 10)
losses3.update_histories()

print("losses3 dict_for_save: ", losses3.dict_for_save())

Classes

DataLoss(args, vals[, loss_function])

A class to handle data losses.

GenericLoss(loss_function, coeff)

A class for a loss with a coefficient and history.

GenericLosses([losses])

A class to handle several losses: residual, boundary conditions, etc.

MassLoss([size_average, reduce, reduction])

Custom loss function for the difference in mass between input and target tensors.

class MassLoss(size_average=None, reduce=None, reduction='mean')[source]

Bases: _Loss

Custom loss function for the difference in mass between input and target tensors.

This loss returns either the mean or sum of the element-wise difference between input and target, depending on the reduction parameter.

Parameters:
  • size_average (bool | None) – Deprecated (unused). Included for API compatibility.

  • reduce (bool | None) – Deprecated (unused). Included for API compatibility.

  • reduction (str) – Specifies the reduction to apply to the output. Must be ‘mean’ (default) or ‘sum’.

Example

>>> loss = MassLoss(reduction='sum')
>>> input = torch.tensor([1.0, 2.0, 3.0])
>>> target = torch.tensor([0.5, 1.5, 2.5])
>>> output = loss(input, target)
>>> print(output)
tensor(1.5000)
forward(input, target)[source]

Computes the mass loss between input and target tensors.

Parameters:
  • input (Tensor) – The predicted values.

  • target (Tensor) – The ground-truth values.

Return type:

Tensor

Returns:

The scalar loss value (mean or sum of differences).

class DataLoss(args, vals, loss_function=MSELoss())[source]

Bases: object

A class to handle data losses.

Parameters:
  • args (tuple[Tensor, ...]) – the points

  • vals (Tensor) – the values at points

  • loss_function (Union[Callable[[Tensor, Tensor], Tensor], _Loss]) – the loss_function

class GenericLoss(loss_function, coeff)[source]

Bases: object

A class for a loss with a coefficient and history.

Parameters:
  • loss_function (Union[Callable[[Tensor, Tensor], Tensor], _Loss]) – The loss function.

  • coeff (float) – A coefficient that scales the computed loss value.

func

The loss function.

coeff

The coefficient applied to the loss.

coeff_history: list[float]

The history of coefficient values.

loss

The current loss value.

weighted_loss

The current weighted loss value.

loss_history: list[float]

The history of losses.
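The interplay of these attributes can be mirrored by a minimal pure-Python sketch. `MiniLoss` and the plain floats below are illustrative stand-ins (not part of scimba_torch) for GenericLoss and tensors:

```python
import math

class MiniLoss:
    """Illustrative stand-in for GenericLoss: a loss function with a
    coefficient, a weighted value, and histories (hypothetical sketch)."""

    def __init__(self, func, coeff):
        self.func = func                # the loss function
        self.coeff = coeff             # scaling coefficient
        self.loss = math.inf           # current loss value
        self.weighted_loss = math.inf  # coeff * loss
        self.loss_history = []         # history of losses
        self.coeff_history = []        # history of coefficients

    def call_and_update(self, a, b):
        # Evaluate the loss function and refresh loss / weighted_loss.
        self.loss = self.func(a, b)
        self.weighted_loss = self.coeff * self.loss
        return self.loss

    def update_history(self, loss_factor=1.0):
        # Append the (optionally scaled) current loss and the coefficient.
        self.loss_history.append(float(self.loss) * loss_factor)
        self.coeff_history.append(self.coeff)

mae = MiniLoss(lambda a, b: abs(a - b), coeff=0.5)
mae.call_and_update(3.0, 1.0)  # loss = 2.0, weighted_loss = 1.0
mae.update_history()
```

The real class operates on `torch.Tensor` values and initializes losses to infinity via `init_loss()`, but the update pattern is the same.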

get_loss()[source]

Returns the current loss value.

Return type:

Tensor

Returns:

The current loss value.

get_weighted_loss()[source]

Returns the current weighted loss value.

Return type:

Tensor

Returns:

The current weighted loss value (coeff * loss).

get_loss_history()[source]

Returns the history of computed loss values.

Return type:

list[float]

Returns:

A list of loss values (in float).

get_coeff()[source]

Returns the current coefficient value.

Return type:

float

Returns:

The current coefficient value.

get_coeff_history()[source]

Returns the history of coefficient values.

Return type:

list[float]

Returns:

A list of coefficient values representing the history of coefficients used.

init_loss()[source]

Resets the loss and weighted loss to infinity.

Return type:

None

update_loss(value)[source]

Updates the current loss value and recalculates the weighted loss.

Parameters:

value (Tensor) – The new loss value to be set.

update_history(loss_factor=1.0)[source]

Appends the current loss (optionally scaled by a factor) to the loss history.

Parameters:

loss_factor (float) – A factor by which to scale the loss before adding it to the history. Defaults to 1.0.

Return type:

None

set_history(history)[source]

Sets the history of loss values to the provided list of floats.

Parameters:

history (list[float]) – A list of float values representing the new loss history.

Return type:

None

update_coeff(coeff)[source]

Updates the coefficient value and recalculates the weighted loss.

Parameters:

coeff (float) – The new coefficient value to be set.

Return type:

None

call_and_update(a, b)[source]

Calls the loss function, updates the loss, and returns the updated loss.

Parameters:
  • a (Tensor) – The first input tensor.

  • b (Tensor) – The second input tensor.

Return type:

Tensor

Returns:

The updated loss value.

class GenericLosses(losses=None, **kwargs)[source]

Bases: object

A class to handle several losses: residual, boundary conditions, etc.

A class that manages multiple instances of GenericLoss and computes the full loss as a combination of all individual losses.

Parameters:
  • losses (Sequence[tuple[str, Union[Callable[[Tensor, Tensor], Tensor], _Loss], float | int]] | None) – A list of tuples; each tuple contains a loss name, a callable loss function, and a coefficient. Default is None.

  • **kwargs – Additional keyword arguments:

    "adaptive_weights": the method for adaptive weighting of losses; currently only "annealing" is supported.

    "principal_weights": the name of the reference loss for adapting weights.

    "epochs_adapt": the number of epochs between adaptive weight updates.

    "alpha_lr_annealing": the learning rate annealing factor for adaptive weighting.

Raises:
  • ValueError – If the input list is empty.

  • TypeError – If the input list contains elements with incorrect types.

losses_dict: dict[str, GenericLoss]

A dictionary mapping loss names to GenericLoss instances.

loss: Tensor

The current full loss value, which is the sum of all weighted losses.

loss_history: list[float]

A list storing the history of computed full loss values.

adaptive_weights: str | None

The method for adaptive weighting of losses. Default is None.

principal_weights: str | None

The name of the principal loss used for adaptive weighting.

epochs_adapt: int

The number of epochs between adaptive weight updates. Default is 10.

alpha_lr_annealing: float

The learning rate annealing factor for adaptive weighting. Default is 0.9.
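Learning-rate annealing schemes for loss weighting typically move each coefficient toward a target ratio (often derived from gradient statistics of the principal loss versus the auxiliary loss) via an exponential moving average controlled by alpha. The sketch below shows only that moving-average step; `anneal_coeff` and the fixed `ratio` are illustrative assumptions, not the actual scimba_torch update, which may compute the ratio differently:

```python
def anneal_coeff(coeff, ratio, alpha=0.9):
    """Exponential-moving-average step used in learning-rate annealing:
    blend the old coefficient with a target ratio (hypothetical sketch;
    a real update would derive `ratio` from gradient magnitudes)."""
    return alpha * coeff + (1.0 - alpha) * ratio

# Example: the target ratio says this loss should be weighted 4x;
# each call moves the coefficient a fraction (1 - alpha) of the way there.
coeff = 1.0
for _ in range(3):
    coeff = anneal_coeff(coeff, ratio=4.0, alpha=0.9)
# coeff drifts from 1.0 toward 4.0: 1.3, 1.57, 1.813, ...
```

A larger alpha makes the weights change more slowly, which is why updates are applied only every `epochs_adapt` epochs.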

get_full_loss()[source]

Returns the current full loss value.

Return type:

Tensor

Returns:

The current full loss value.

get_loss(key)[source]

Returns the current loss value for a specific loss function.

Parameters:

key (str) – The name of the loss function.

Return type:

Tensor

Returns:

The current loss value for the specified loss function.

Raises:

KeyError – If the key is not found in the losses dictionary.

get_history(key)[source]

Returns the history of computed loss values for a specific loss function.

Parameters:

key (str) – The name of the loss function.

Return type:

list[float]

Returns:

The history of computed loss values for the specified loss function.

Raises:

KeyError – If the key is not found in the losses dictionary.

get_coeff(key)[source]

Returns the current coefficient value for a specific loss function.

Parameters:

key (str) – The name of the loss function.

Return type:

float

Returns:

The current coefficient value for the specified loss function.

Raises:

KeyError – If the key is not found in the losses dictionary.

get_coeff_history(key)[source]

Returns the history of coefficient values for a specific loss function.

Parameters:

key (str) – The name of the loss function.

Return type:

list[float]

Returns:

The history of coefficient values for the specified loss function.

Raises:

KeyError – If the key is not found in the losses dictionary.

init_losses()[source]

Resets all loss values to infinity.

Return type:

None

init_loss(key)[source]

Resets the loss value for a specific loss function to infinity.

Parameters:

key (str) – The name of the loss function.

Raises:

KeyError – If the key is not found in the losses dictionary.

Return type:

None

update_loss(key, value)[source]

Updates the loss value for a specific loss function.

Parameters:
  • key (str) – The name of the loss function.

  • value (Tensor) – The new loss value to be set.

Raises:

KeyError – If the key is not found in the losses dictionary.

Return type:

None

update_histories(loss_factor=1.0)[source]

Appends the current loss (optionally scaled by a factor) to the loss history.

Parameters:

loss_factor (float) – A factor by which to scale the loss before adding it to the history. Defaults to 1.0.

Return type:

None

update_coeff(key, value)[source]

Updates the coefficient value for a specific loss function.

Parameters:
  • key (str) – The name of the loss function.

  • value (float) – The new coefficient value to be set.

Raises:

KeyError – If the key is not found in the losses dictionary.

Return type:

None

call_and_update(key, a, b)[source]

Calls the loss function, updates the loss, and returns the updated loss.

Parameters:
  • key (str) – The name of the loss function.

  • a (Tensor) – The first input tensor.

  • b (Tensor) – The second input tensor.

Return type:

Tensor

Returns:

The updated loss value.

Raises:

KeyError – If the key is not found in the losses dictionary.

compute_all_losses(left, right, update=True)[source]

Computes all losses.

Returns the combination of all the losses, possibly updates the loss values.

Parameters:
  • left (tuple[Tensor, ...]) – The left tensors.

  • right (tuple[Tensor, ...]) – The right tensors.

  • update (bool) – Whether to update the current loss.

Returns:

The computed full loss value.

Return type:

torch.Tensor

Raises:

ValueError – if left and right do not have the same length, or if their common length is not a divisor of the number of losses.
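The validation and grouping logic described above can be sketched in pure Python. Here `pair_tensors_with_losses` is a hypothetical helper (not part of scimba_torch), and the assumption that each tensor pair is reused for consecutive losses is illustrative; only the two ValueError conditions are taken from the documentation:

```python
def pair_tensors_with_losses(loss_names, left, right):
    """Validate tensor tuples against the losses, then distribute them.
    With n losses and k = len(left) pairs, k must divide n; each pair
    is reused for n // k consecutive losses (pairing order assumed)."""
    if len(left) != len(right):
        raise ValueError("left and right must have the same length")
    n, k = len(loss_names), len(left)
    if k == 0 or n % k != 0:
        raise ValueError("len(left) must be a divisor of the number of losses")
    group = n // k  # how many losses share one tensor pair
    return [(name, left[i // group], right[i // group])
            for i, name in enumerate(loss_names)]

# One tensor pair distributed across both losses ("L0"/"R0" stand in
# for tensors):
pairs = pair_tensors_with_losses(["residual", "bc"], ("L0",), ("R0",))
```

With `update=True`, the real method would additionally store each computed loss value via `update_loss` before summing the weighted losses.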

compute_full_loss_without_updating(left, right)[source]

Computes the full loss without updating the loss values.

Parameters:
  • left (tuple[Tensor, ...]) – The left tensors.

  • right (tuple[Tensor, ...]) – The right tensors.

Return type:

Tensor

Returns:

The computed full loss value.

compute_full_loss(optimizers, epoch)[source]

Computes the full loss as the combination of all the losses.

Parameters:
  • optimizers (OptimizerData) – The optimizer data object.

  • epoch (int) – The current epoch.

Return type:

Tensor

Returns:

The computed full loss value.

Raises:

ValueError – when adaptive_weights is not recognized.

dict_for_save()[source]

Returns a dictionary of best loss values for saving.

Return type:

dict[str, Tensor | list[float]]

Returns:

A dictionary containing the best loss value and loss history.

try_to_load(checkpoint, string)[source]

Tries to load a value from the checkpoint.

Parameters:
  • checkpoint (dict) – The checkpoint dictionary.

  • string (str) – The key to look for in the checkpoint.

Return type:

Any

Returns:

The loaded value if found, otherwise None.

load_from_dict(checkpoint)[source]

Loads the loss history from a checkpoint.

Parameters:

checkpoint (dict) – The checkpoint dictionary.

Return type:

None

plot(ax, **kwargs)[source]

Plots the loss history on the given axis.

Parameters:
  • ax (Axes) – The axis on which to plot the loss history.

  • **kwargs – Additional keyword arguments.

Return type:

Axes

Returns:

The axis with the plotted loss history.