1. Distributions
Base class for building distribution modules using Flax and JAX.
Provides a standardized interface for sampling, computing log probabilities, and reporting the number of parameters in custom probabilistic layers.
Notes
This class is abstract and should not be instantiated directly. Subclasses must implement all abstract methods to specify distribution behavior.
1.1 DistributionModule
Abstract base for Flax-based probabilistic distribution modules.
Defines a required interface for sampling, computing log-probabilities, and retrieving parameter counts. Subclasses must implement all abstract methods to provide specific distribution logic.
Notes
Do not instantiate this class directly; it serves as a blueprint for derived classes.
1.1.1 num_params (abstract property)
Returns the total number of learnable parameters in the distribution.
Returns:
Type | Description |
---|---|
int | Integer representing the total number of learnable parameters. |
1.1.2 log_prob(x=None) (abstract method)
Computes the log-probability of an input sample. If no sample is provided, a new one is drawn internally from the current distribution before computing the log-probability.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Optional[Array] | Optional sample tensor to evaluate. | None |
Returns:
Type | Description |
---|---|
Array | Scalar tensor representing the computed log-probability. |
Notes
This method supports both user-supplied samples and internally generated ones for convenience when evaluating likelihoods.
Source code in illia/distributions/jax/base.py
1.1.3 sample(rngs=nnx.Rngs(0)) (abstract method)
Generates and returns a sample from the underlying distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rngs | Rngs | RNG container used for sampling. | Rngs(0) |
Returns:
Type | Description |
---|---|
Array | Sample array matching the shape and structure defined by the distribution parameters. |
Source code in illia/distributions/jax/base.py
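The contract described in this section (an abstract num_params property plus abstract sample and log_prob methods) can be illustrated with a minimal sketch. It assumes nothing about illia's internals: `DistributionInterface` and `PointMass` are hypothetical names, and plain `abc.ABC` stands in for the Flax module base so the example runs without Flax. It shows that the abstract class refuses instantiation and that a subclass becomes usable once every abstract member is implemented, including the x=None fallback described for log_prob.

```python
import math
from abc import ABC, abstractmethod
from typing import Optional


class DistributionInterface(ABC):
    """Hypothetical stand-in for DistributionModule (plain abc, no Flax)."""

    @property
    @abstractmethod
    def num_params(self) -> int:
        """Total number of learnable parameters."""

    @abstractmethod
    def sample(self, rngs=None):
        """Draw one sample from the distribution."""

    @abstractmethod
    def log_prob(self, x: Optional[float] = None):
        """Log-probability of x, or of a fresh sample when x is None."""


class PointMass(DistributionInterface):
    """Hypothetical subclass: all probability mass on one fixed value."""

    def __init__(self, value: float):
        self.value = value

    @property
    def num_params(self) -> int:
        return 0  # the value is fixed, so nothing is learnable

    def sample(self, rngs=None):
        return self.value  # deterministic, so rngs is unused

    def log_prob(self, x: Optional[float] = None):
        if x is None:  # documented fallback: draw a sample internally
            x = self.sample()
        return 0.0 if x == self.value else -math.inf


# The abstract base refuses direct instantiation.
try:
    DistributionInterface()
    base_instantiable = True
except TypeError:
    base_instantiable = False

dist = PointMass(2.0)
```

A subclass that leaves any of the three members unimplemented would raise the same TypeError as the base class when instantiated.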
Defines a Gaussian (Normal) distribution using Flax with trainable mean and standard deviation parameters. Includes methods for sampling from the distribution and computing log-probabilities of given inputs.
1.2 GaussianDistribution(shape, mu_prior=0.0, std_prior=0.1, mu_init=0.0, rho_init=-7.0, rngs=nnx.Rngs(0), **kwargs)
Learnable Gaussian distribution using Flax.
Represents a diagonal Gaussian distribution with trainable mean and standard deviation parameters. The standard deviation is derived from rho using a softplus transformation, which guarantees that it is positive.
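Taking softplus(rho) = log(1 + exp(rho)), as stated above, the mapping can be checked numerically with `jax.nn.softplus`. With the default rho_init=-7.0 the initial standard deviation is about 9.1e-4, so the distribution starts out nearly deterministic. The inverse, rho = log(exp(std) - 1), is included only as a hypothetical helper for choosing rho_init from a target standard deviation; it is not part of the documented API.

```python
import jax
import jax.numpy as jnp

# std = softplus(rho) = log(1 + exp(rho)) is positive for any real rho.
rho = jnp.array([-7.0, 0.0, 3.0])
std = jax.nn.softplus(rho)

# Hypothetical inverse: rho = log(exp(std) - 1), using expm1 for accuracy.
rho_back = jnp.log(jnp.expm1(std))
```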
Notes
Assumes a diagonal covariance matrix. KL divergence between distributions can be estimated from differences of log-probabilities obtained from log_prob.
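The note above relies on the Monte Carlo identity KL(q || p) = E_{w ~ q}[log q(w) - log p(w)]. A minimal sketch of that estimator follows, with a closed-form check for diagonal Gaussians. The posterior means and rho values are hypothetical; the prior uses the documented defaults mu_prior=0.0 and std_prior=0.1. Whether illia's log_prob returns exactly this difference is not stated in this documentation, so the sketch computes both terms explicitly with `jax.scipy.stats.norm.logpdf`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical variational posterior q: diagonal Gaussian, std = softplus(rho).
mu_q = jnp.array([0.05, -0.02, 0.0])
std_q = jax.nn.softplus(jnp.array([-4.0, -3.0, -5.0]))

# Prior p using the documented defaults mu_prior=0.0, std_prior=0.1.
mu_p, std_p = 0.0, 0.1

# Draw many samples w ~ q.
eps = jax.random.normal(jax.random.PRNGKey(0), (10_000, 3))
w = mu_q + std_q * eps

# Monte Carlo estimate: KL(q || p) ~= mean over samples of log q(w) - log p(w).
log_q = norm.logpdf(w, mu_q, std_q).sum(axis=-1)
log_p = norm.logpdf(w, mu_p, std_p).sum(axis=-1)
kl_mc = jnp.mean(log_q - log_p)

# Closed-form KL between diagonal Gaussians, for comparison.
kl_exact = jnp.sum(
    jnp.log(std_p / std_q)
    + (std_q**2 + (mu_q - mu_p) ** 2) / (2 * std_p**2)
    - 0.5
)
```

In training, a single sample per step is common; the estimate is unbiased but noisy, which is why many samples are averaged here for the comparison.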
Initializes the GaussianDistribution module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
shape | tuple[int, ...] | Shape of the learnable parameters. | required |
mu_prior | float | Mean of the Gaussian prior. | 0.0 |
std_prior | float | Standard deviation of the Gaussian prior. | 0.1 |
mu_init | float | Initial value for the learnable mean. | 0.0 |
rho_init | float | Initial value for the learnable rho parameter (the standard deviation is softplus(rho)). | -7.0 |
rngs | Rngs | RNG container for parameter initialization. | Rngs(0) |
**kwargs | Any | Additional arguments passed to the base class. | {} |
Returns:
Type | Description |
---|---|
None | None. |
Source code in illia/distributions/jax/gaussian.py
1.2.1 num_params (property)
Returns the total number of learnable parameters in the distribution.
Returns:
Type | Description |
---|---|
int | Integer representing the total number of learnable parameters. |
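For a concrete count, a sketch assuming the learnable state is exactly one mean array and one rho array, each of the constructor's shape: the count is then 2 * prod(shape). The prior mean and standard deviation are passed as plain floats, so the sketch assumes they are not counted. The constant fills below are for illustration only and may not match how illia initializes its parameters.

```python
import math

import jax.numpy as jnp

shape = (3, 4)  # hypothetical parameter shape

# Assumption: one mean array and one rho array, each of the given shape;
# prior mean/std are plain floats and are not counted.
mu = jnp.full(shape, 0.0)    # constant fill with mu_init, illustration only
rho = jnp.full(shape, -7.0)  # constant fill with rho_init, illustration only

num_params = mu.size + rho.size
expected = 2 * math.prod(shape)
```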
1.2.2 __call__()
Performs a forward pass by sampling from the distribution.
Returns:
Type | Description |
---|---|
Array | A sample from the distribution. |
Source code in illia/distributions/jax/gaussian.py
1.2.3 log_prob(x=None)
Computes the log-probability of an input sample. If no sample is provided, a new one is drawn internally from the current distribution before computing the log-probability.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Optional[Array] | Optional sample tensor to evaluate. | None |
Returns:
Type | Description |
---|---|
Array | Scalar tensor representing the computed log-probability. |
Notes
This method supports both user-supplied samples and internally generated ones for convenience when evaluating likelihoods.
Source code in illia/distributions/jax/gaussian.py
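For a diagonal Gaussian, the scalar log-probability is a sum of independent per-element terms: log N(x; mu, sigma^2) = -1/2 log(2 pi) - log sigma - (x - mu)^2 / (2 sigma^2). A minimal sketch of that sum follows, cross-checked against `jax.scipy.stats.norm.logpdf`. This documentation does not say whether illia evaluates x under the learnable distribution, the prior, or a combination of both, so `diag_gaussian_log_prob` and the example values are hypothetical.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def diag_gaussian_log_prob(x, mu, std):
    """Hypothetical helper: summed elementwise N(x; mu, std**2) log-densities."""
    elementwise = (
        -0.5 * math.log(2 * math.pi)
        - jnp.log(std)
        - (x - mu) ** 2 / (2 * std**2)
    )
    return jnp.sum(elementwise)  # reduce to a scalar, as documented


mu = jnp.zeros((2, 3))
std = jax.nn.softplus(jnp.full((2, 3), -7.0))  # default rho_init
x = mu + 0.001  # hypothetical sample

lp = diag_gaussian_log_prob(x, mu, std)
reference = norm.logpdf(x, mu, std).sum()  # cross-check with JAX's built-in
```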
1.2.4 sample(rngs=nnx.Rngs(0))
Generates and returns a sample from the Gaussian distribution.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rngs | Rngs | RNG container used for sampling. | Rngs(0) |
Returns:
Type | Description |
---|---|
Array | Sample array matching the shape and structure defined by the distribution parameters. |
Source code in illia/distributions/jax/gaussian.py
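The documentation only says that sample "generates and returns a sample". A common way to sample a Gaussian with trainable mean and softplus-parameterized standard deviation is the reparameterization trick, w = mu + softplus(rho) * eps with eps ~ N(0, I), which keeps the sample differentiable with respect to both parameters. The sketch below assumes that technique; `reparam_sample` is a hypothetical helper, and a raw `jax.random` key stands in for the nnx.Rngs container.

```python
import jax
import jax.numpy as jnp


def reparam_sample(key, mu, rho):
    """Hypothetical helper: w = mu + softplus(rho) * eps with eps ~ N(0, I)."""
    eps = jax.random.normal(key, mu.shape)
    return mu + jax.nn.softplus(rho) * eps


mu = jnp.zeros((2, 3))
rho = jnp.full((2, 3), -7.0)
key = jax.random.PRNGKey(0)  # stands in for the nnx.Rngs container

w = reparam_sample(key, mu, rho)
w_again = reparam_sample(key, mu, rho)  # same key, same sample

# The noise does not depend on mu or rho, so gradients reach both parameters.
grad_mu = jax.grad(lambda m: reparam_sample(key, m, rho).sum())(mu)
grad_rho = jax.grad(lambda r: reparam_sample(key, mu, r).sum())(rho)
```

Because the same key always yields the same noise, fresh keys (or a stateful RNG container) are needed to obtain different samples across calls.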