2. Losses
2.1 KLDivergenceLoss
Compute the Kullback-Leibler (KL) divergence across Bayesian modules. This loss sums the KL divergence from all Bayesian layers in the model; the sum can then be reduced by averaging and scaled by a weight factor.
Notes
Assumes the model contains submodules derived from `BayesianModule`.
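A minimal usage sketch. The import path is inferred from the source location below, and `build_bayesian_model` is a hypothetical stand-in for any constructor that returns a model built from illia's Bayesian JAX layers:

```python
from illia.losses.jax.kl import KLDivergenceLoss  # path inferred from the source location

# Hypothetical: any model whose submodules derive from BayesianModule.
model = build_bayesian_model()  # placeholder for your own model constructor

kl_loss = KLDivergenceLoss(reduction="mean", weight=1.0)
kl_value = kl_loss(model)  # jax.Array: KL summed over layers, averaged, and scaled
```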
Source code in illia/losses/jax/kl.py
2.1.1 `__call__(model)`
Compute KL divergence for all Bayesian modules in a model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model containing Bayesian submodules. | *required* |
Returns:

| Type | Description |
|---|---|
| `Array` | `jax.Array`: Weighted KL divergence loss. |
Notes
The KL loss is averaged over the number of parameters and scaled by the `weight` attribute.
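In symbols, a plausible reading of this reduction (not taken verbatim from the source), with $\mathrm{KL}_i$ the divergence contributed by the $i$-th Bayesian module, $P$ the total number of parameters, and $w$ the `weight` attribute:

$$
\mathcal{L}_{\mathrm{KL}} = w \cdot \frac{1}{P} \sum_{i} \mathrm{KL}_i
$$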
Source code in illia/losses/jax/kl.py
2.1.2 `__init__(reduction='mean', weight=1.0, **kwargs)`
Initialize the KL divergence loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `reduction` | `Literal['mean']` | Method used to reduce the KL loss. | `'mean'` |
| `weight` | `float` | Scaling factor for the KL divergence. | `1.0` |
| `**kwargs` | `Any` | Extra arguments passed to the base class. | `{}` |
Returns:

| Type | Description |
|---|---|
| `None` | None |
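For example, a construction sketch (import path inferred from the source location below; note that `'mean'` is the only reduction admitted by the type hint, so in practice only `weight` and base-class kwargs vary):

```python
from illia.losses.jax.kl import KLDivergenceLoss  # path inferred from the source location

# Halve the contribution of the averaged KL term relative to the default.
kl_loss = KLDivergenceLoss(reduction="mean", weight=0.5)
```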
Source code in illia/losses/jax/kl.py
2.2 ELBOLoss
Compute the Evidence Lower Bound (ELBO) loss for Bayesian networks. Combines a reconstruction loss with a KL divergence term, using Monte Carlo sampling to estimate the expected reconstruction loss over the stochastic layers.
Notes
The KL term is weighted by `kl_weight`. The model is assumed to contain Bayesian layers compatible with `KLDivergenceLoss`.
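A minimal usage sketch. The import path is inferred from the source location below; `mse` and the variables in the final call are hypothetical stand-ins:

```python
import jax
import jax.numpy as jnp

from illia.losses.jax.elbo import ELBOLoss  # path inferred from the source location


def mse(outputs: jax.Array, targets: jax.Array) -> jax.Array:
    # Any callable mapping (predictions, targets) to an Array can serve
    # as the reconstruction loss.
    return jnp.mean((outputs - targets) ** 2)


elbo = ELBOLoss(loss_function=mse, num_samples=5, kl_weight=1e-3)

# Hypothetical call: `model` holds Bayesian layers, `outputs` are its
# predictions for `targets`.
# loss_value = elbo(outputs, targets, model)
```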
Source code in illia/losses/jax/elbo.py
2.2.1 `__call__(outputs, targets, model)`
Compute the ELBO loss with Monte Carlo sampling and KL regularization.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `outputs` | `Array` | Predictions generated by the model. | *required* |
| `targets` | `Array` | Ground truth values for training. | *required* |
| `model` | `Module` | Model containing Bayesian layers. | *required* |
Returns:

| Type | Description |
|---|---|
| `Array` | `jax.Array`: Scalar ELBO loss averaged over samples. |
Notes
The loss is averaged over `num_samples` Monte Carlo draws.
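A plausible reading of this computation (not taken verbatim from the source), with $\ell$ the reconstruction `loss_function`, $\hat{y}^{(s)}$ the model's stochastic prediction on the $s$-th of $S$ = `num_samples` forward passes, and $\lambda$ = `kl_weight`:

$$
\mathcal{L}_{\mathrm{ELBO}} = \frac{1}{S} \sum_{s=1}^{S} \ell\big(\hat{y}^{(s)}, y\big) + \lambda \, \mathcal{L}_{\mathrm{KL}}
$$

In the common case where the KL term is computed analytically from the variational parameters rather than from sampled weights, its value is the same whether it sits inside or outside the Monte Carlo average.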
Source code in illia/losses/jax/elbo.py
2.2.2 `__init__(loss_function, num_samples=1, kl_weight=0.001, **kwargs)`
Initialize the ELBO loss with reconstruction and KL components.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `loss_function` | `Callable[[Array, Array], Array]` | Function to compute reconstruction loss. | *required* |
| `num_samples` | `int` | Number of Monte Carlo samples used for estimation. | `1` |
| `kl_weight` | `float` | Weight applied to the KL divergence term. | `0.001` |
| `**kwargs` | `Any` | Extra arguments passed to the base class. | `{}` |
Returns:

| Type | Description |
|---|---|
| `None` | None |
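As a construction sketch under the same assumptions (inferred import path, hypothetical `nll` helper), showing how the two knobs trade off: more samples lower the variance of the reconstruction estimate at the cost of extra forward passes, while `kl_weight` balances data fit against the prior:

```python
import jax
import jax.numpy as jnp

from illia.losses.jax.elbo import ELBOLoss  # path inferred from the source location


def nll(outputs: jax.Array, targets: jax.Array) -> jax.Array:
    # Hypothetical negative log-likelihood for outputs interpreted as
    # probabilities in (0, 1).
    return -jnp.mean(targets * jnp.log(outputs + 1e-8))


# More Monte Carlo samples reduce estimator variance but cost compute;
# a larger kl_weight pulls the variational posterior toward the prior.
elbo = ELBOLoss(loss_function=nll, num_samples=10, kl_weight=1e-3)
```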
Source code in illia/losses/jax/elbo.py