2. Losses

This module implements the Kullback-Leibler (KL) divergence loss for Bayesian neural networks in JAX.

2.1 KLDivergenceLoss(reduction='mean', weight=1.0, **kwargs)

Computes the Kullback-Leibler divergence loss across all Bayesian modules.

Initializes the Kullback-Leibler divergence loss computation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `reduction` | `Literal['mean']` | Reduction method for the loss. | `'mean'` |
| `weight` | `float` | Scaling factor applied to the total KL loss. | `1.0` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None. |

Source code in illia/losses/jax/kl.py
def __init__(
    self,
    reduction: Literal["mean"] = "mean",
    weight: float = 1.0,
    **kwargs: Any,
) -> None:
    """
    Initializes the Kullback-Leibler divergence loss computation.

    Args:
        reduction: Reduction method for the loss.
        weight: Scaling factor applied to the total KL loss.

    Returns:
        None.
    """

    # Call super class constructor
    super().__init__(**kwargs)

    # Set attributes
    self.reduction = reduction
    self.weight = weight
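
As a quick usage sketch, constructing the loss only requires choosing the reduction and the weight. The import path below follows the source path shown above (illia/losses/jax/kl.py); adjust it if the class is re-exported elsewhere in your installation:

from illia.losses.jax.kl import KLDivergenceLoss

# "mean" is currently the only supported reduction; `weight` scales the averaged KL term.
kl_loss = KLDivergenceLoss(reduction="mean", weight=0.5)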

2.1.1 __call__(model)

Computes Kullback-Leibler divergence for all Bayesian modules in the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | Model containing Bayesian submodules. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Scaled Kullback-Leibler divergence loss as a scalar array. |

Source code in illia/losses/jax/kl.py
def __call__(self, model: nnx.Module) -> jax.Array:
    """
    Computes Kullback-Leibler divergence for all Bayesian
    modules in the model.

    Args:
        model: Model containing Bayesian submodules.

    Returns:
        Scaled Kullback-Leibler divergence loss as a scalar array.
    """

    # Init kl cost and params
    kl_global_cost: jax.Array = jnp.array(0.0)
    num_params_global: int = 0

    # Iter over modules
    for _, module in model.iter_modules():
        if isinstance(module, BayesianModule):
            kl_cost, num_params = module.kl_cost()
            kl_global_cost += kl_cost
            num_params_global += num_params

    # Average by the number of parameters
    kl_global_cost /= num_params_global
    kl_global_cost *= self.weight

    return kl_global_cost
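
A minimal end-to-end sketch of applying this loss to a model follows. The Bayesian layer used here (`BayesianLinear` from `illia.nn.jax`) and its constructor signature are placeholders assumed for illustration; any `nnx.Module` whose submodules include illia `BayesianModule` instances is handled the same way:

import jax
from flax import nnx

from illia.losses.jax.kl import KLDivergenceLoss
# Hypothetical import: substitute the Bayesian layer class your model actually uses.
from illia.nn.jax import BayesianLinear


class SmallBayesianNet(nnx.Module):
    def __init__(self, rngs: nnx.Rngs) -> None:
        # Constructor signature assumed to mirror nnx.Linear, purely for illustration.
        self.fc1 = BayesianLinear(4, 16, rngs=rngs)
        self.fc2 = BayesianLinear(16, 1, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        return self.fc2(jax.nn.relu(self.fc1(x)))


model = SmallBayesianNet(rngs=nnx.Rngs(0))
kl_loss = KLDivergenceLoss(weight=1.0)

# Scalar jax.Array: KL summed over all Bayesian modules, averaged over their
# total parameter count and scaled by `weight`.
kl_value = kl_loss(model)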

This module implements the Evidence Lower Bound (ELBO) loss for Bayesian neural networks in JAX.

2.2 ELBOLoss(loss_function, num_samples=1, kl_weight=0.001, **kwargs)

Computes the Evidence Lower Bound (ELBO) loss function for Bayesian neural networks.

This combines a reconstruction loss and a KL divergence term, estimated using Monte Carlo sampling.
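
Concretely, the value returned by __call__ (shown below) is the Monte Carlo average

$$
\mathcal{L}_{\text{ELBO}} = \frac{1}{S} \sum_{s=1}^{S} \Big[ \ell(\hat{y}, y) + w_{\text{KL}} \, \mathrm{KL}(\text{model}) \Big],
$$

where S is num_samples, ℓ is loss_function, w_KL is kl_weight, and KL(model) is the parameter-averaged Kullback-Leibler divergence computed by KLDivergenceLoss.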

Initializes the ELBO loss with sampling and KL scaling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss_function` | `Callable[[Array, Array], Array]` | Callable that computes the reconstruction loss. | required |
| `num_samples` | `int` | Number of MC samples for estimation. | `1` |
| `kl_weight` | `float` | Weight applied to the KL loss. | `0.001` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None. |

Source code in illia/losses/jax/elbo.py
def __init__(
    self,
    loss_function: Callable[[jax.Array, jax.Array], jax.Array],
    num_samples: int = 1,
    kl_weight: float = 1e-3,
    **kwargs: Any,
) -> None:
    """
    Initializes the ELBO loss with sampling and KL scaling.

    Args:
        loss_function: Callable that computes the reconstruction loss.
        num_samples: Number of MC samples for estimation.
        kl_weight: Weight applied to the KL loss.

    Returns:
        None.
    """

    # Call super class constructor
    super().__init__(**kwargs)

    # Set attributes
    self.loss_function = loss_function
    self.num_samples = num_samples
    self.kl_weight = kl_weight
    self.kl_loss = KLDivergenceLoss(weight=kl_weight)
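
As an instantiation sketch, any callable that maps (predictions, targets) to a scalar jax.Array can serve as the reconstruction loss; a plain mean-squared-error function is used here purely for illustration:

import jax
import jax.numpy as jnp

from illia.losses.jax.elbo import ELBOLoss


def mse(outputs: jax.Array, targets: jax.Array) -> jax.Array:
    # Reconstruction term: mean squared error reduced to a scalar.
    return jnp.mean((outputs - targets) ** 2)


# Average over 3 Monte Carlo samples; the KL term is scaled by 1e-3.
elbo = ELBOLoss(loss_function=mse, num_samples=3, kl_weight=1e-3)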

2.2.1 __call__(outputs, targets, model)

Compute the ELBO loss using Monte Carlo sampling and KL regularization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `outputs` | `Array` | Predictions generated by the model. | required |
| `targets` | `Array` | Ground truth values for training. | required |
| `model` | `Module` | Model containing Bayesian layers. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Scalar representing the average ELBO loss. |

Source code in illia/losses/jax/elbo.py
def __call__(
    self, outputs: jax.Array, targets: jax.Array, model: nnx.Module
) -> jax.Array:
    """
    Compute the ELBO loss using Monte Carlo sampling and KL regularization.

    Args:
        outputs: Predictions generated by the model.
        targets: Ground truth values for training.
        model: Model containing Bayesian layers.

    Returns:
        Scalar representing the average ELBO loss.
    """

    loss_value: jax.Array = jnp.array(0.0)
    for _ in range(self.num_samples):
        loss_value += self.loss_function(outputs, targets) + self.kl_loss(model)

    loss_value /= self.num_samples

    return loss_value
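
Putting the pieces together (reusing the hypothetical SmallBayesianNet model and the mse reconstruction loss from the sketches above), a single loss evaluation looks like this:

import jax.numpy as jnp

x = jnp.ones((8, 4))          # dummy batch of inputs
targets = jnp.zeros((8, 1))   # dummy regression targets

outputs = model(x)                    # forward pass through the Bayesian model
loss = elbo(outputs, targets, model)  # reconstruction loss + weighted KL, averaged over num_samples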