4. Losses

4.1 elbo

This module contains the ELBO and KL divergence loss classes.

4.1.1 ELBOLoss(loss_function, num_samples=1, kl_weight=0.001)

Computes the Evidence Lower Bound (ELBO) loss, combining a reconstruction loss and KL divergence.

Initializes the ELBO loss with specified reconstruction loss function, sample count, and KL weight.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss_function | Module | Loss function for computing reconstruction loss. | required |
| num_samples | int | Number of samples for Monte Carlo approximation. | 1 |
| kl_weight | float | Scaling factor for the KL divergence component. | 0.001 |

Source code in illia/losses/torch/elbo.py
def __init__(
    self,
    loss_function: torch.nn.Module,
    num_samples: int = 1,
    kl_weight: float = 1e-3,
) -> None:
    """
    Initializes the ELBO loss with specified reconstruction loss
    function, sample count, and KL weight.

    Args:
        loss_function: Loss function for computing reconstruction
            loss.
        num_samples: Number of samples for Monte Carlo
            approximation.
        kl_weight: Scaling factor for the KL divergence component.
    """

    # Call super class constructor
    super().__init__()

    self.loss_function = loss_function
    self.num_samples = num_samples
    self.kl_weight = kl_weight
    self.kl_loss = KLDivergenceLoss(weight=kl_weight)

4.1.1.1 forward(outputs, targets, model)

Computes the ELBO loss, averaging over multiple samples.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| outputs | Tensor | Predicted values from the model. | required |
| targets | Tensor | True target values. | required |
| model | Module | PyTorch model containing Bayesian modules. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Average ELBO loss across samples. |
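
Read off the source that follows, the returned value can be written compactly. The symbols are introduced here for illustration only: $S$ is `num_samples`, $\ell$ is `loss_function`, $\lambda$ is `kl_weight`, and $\mathrm{KL}_m$ and $n_m$ are the cost and parameter count that `kl_cost()` returns for the $m$-th Bayesian module. This sketch assumes the KL sum is divided by the total parameter count, as the comment in `KLDivergenceLoss.forward` states.

```latex
\mathcal{L}(\hat{y}, y)
  = \frac{1}{S} \sum_{s=1}^{S}
    \left[
      \ell(\hat{y}, y)
      + \lambda \, \frac{\sum_{m} \mathrm{KL}_m^{(s)}}{\sum_{m} n_m}
    \right]
```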

Source code in illia/losses/torch/elbo.py
def forward(
    self, outputs: torch.Tensor, targets: torch.Tensor, model: torch.nn.Module
) -> torch.Tensor:
    """
    Computes the ELBO loss, averaging over multiple samples.

    Args:
        outputs: Predicted values from the model.
        targets: True target values.
        model: PyTorch model containing Bayesian modules.

    Returns:
        Average ELBO loss across samples.
    """

    loss_value = torch.tensor(
        0, device=next(model.parameters()).device, dtype=torch.float32
    )
    for _ in range(self.num_samples):
        loss_value += self.loss_function(outputs, targets) + self.kl_loss(model)

    loss_value /= self.num_samples

    return loss_value
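
The averaging performed by `forward` can be illustrated without PyTorch. The following is a minimal, framework-free sketch rather than the library's implementation: `elbo_estimate`, `reconstruction_loss`, and `kl_term` are hypothetical names, and plain floats stand in for tensors.

```python
from typing import Callable


def elbo_estimate(
    reconstruction_loss: Callable[[], float],
    kl_term: Callable[[], float],
    num_samples: int = 1,
) -> float:
    """Average (reconstruction loss + weighted KL term) over num_samples draws.

    Both callables are hypothetical stand-ins: reconstruction_loss plays the
    role of loss_function(outputs, targets), and kl_term plays the role of
    kl_loss(model), which already includes the kl_weight scaling.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    total = 0.0
    for _ in range(num_samples):
        total += reconstruction_loss() + kl_term()
    return total / num_samples


# Toy values chosen for illustration only.
kl_weight = 1e-3
loss = elbo_estimate(
    reconstruction_loss=lambda: 0.5,
    kl_term=lambda: kl_weight * 20.0,
    num_samples=4,
)
print(loss)
```

In the source shown above, `outputs` is computed before `forward` is called, so a deterministic `loss_function` returns the same value in every iteration; only the KL term can differ between iterations, and only if `kl_cost()` is itself stochastic.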

4.1.2 KLDivergenceLoss(reduction='mean', weight=1.0)

Computes the KL divergence loss for Bayesian modules within a model.

Initializes the KL Divergence Loss with specified reduction method and weight.

Parameters:

Name Type Description Default
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| reduction | Literal['mean'] | Method to reduce the loss, currently only "mean" is supported. | 'mean' |
| weight | float | Scaling factor for the KL divergence loss. | 1.0 |

Source code in illia/losses/torch/elbo.py
def __init__(
    self, reduction: Literal["mean"] = "mean", weight: float = 1.0
) -> None:
    """
    Initializes the KL Divergence Loss with specified reduction
    method and weight.

    Args:
        reduction: Method to reduce the loss, currently only "mean"
            is supported.
        weight: Scaling factor for the KL divergence loss.
    """

    # call super class constructor
    super().__init__()

    # Set parameters
    self.reduction = reduction
    self.weight = weight

4.1.2.1 forward(model)

Computes the KL divergence loss across all Bayesian modules in the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | PyTorch model containing Bayesian modules. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | KL divergence cost scaled by the specified weight. |

Source code in illia/losses/torch/elbo.py
def forward(self, model: torch.nn.Module) -> torch.Tensor:
    """
    Computes the KL divergence loss across all Bayesian modules in
    the model.

    Args:
        model: PyTorch model containing Bayesian modules.

    Returns:
        KL divergence cost scaled by the specified weight.
    """

    # Get device and dtype
    parameter: torch.nn.Parameter = next(model.parameters())
    device: torch.device = parameter.device
    dtype = parameter.dtype

    # Init kl cost and params
    kl_global_cost: torch.Tensor = torch.tensor(0, device=device, dtype=dtype)
    num_params_global: int = 0

    # Iter over modules
    for module in model.modules():
        if isinstance(module, BayesianModule):
            kl_cost, num_params = module.kl_cost()
            kl_global_cost += kl_cost
            num_params_global += num_params

    # Average by the total number of parameters across all Bayesian modules
    kl_global_cost /= num_params_global
    kl_global_cost *= self.weight

    return kl_global_cost
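
The aggregation that the comments in `forward` describe (sum the per-module KL costs, average by the total parameter count, then scale by `weight`) can be sketched without PyTorch. This is a minimal illustration, assuming each module reports a `(kl_cost, num_params)` pair as `module.kl_cost()` does in the source; the `aggregate_kl` helper, the zero-parameter guard, and the numbers are hypothetical.

```python
from typing import Iterable, Tuple


def aggregate_kl(
    module_costs: Iterable[Tuple[float, int]], weight: float = 1.0
) -> float:
    """Sum per-module KL costs, average by total parameter count, then scale."""
    kl_total = 0.0
    num_params_total = 0
    for kl_cost, num_params in module_costs:
        kl_total += kl_cost
        num_params_total += num_params

    # Guard against a model with no Bayesian modules. This is an added
    # safety check for the sketch, not something the source does.
    if num_params_total == 0:
        return 0.0

    return weight * kl_total / num_params_total


# Two hypothetical Bayesian layers, each given as (kl_cost, num_params).
costs = [(12.0, 100), (3.0, 50)]
kl = aggregate_kl(costs, weight=1e-3)
print(kl)
```

Dividing by the accumulated total rather than by the last module's count keeps the result independent of the order in which `model.modules()` yields the layers.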