1. Distributions
1.1 DistributionModule
Abstract base for probabilistic distribution modules in JAX. Defines the required interface for sampling, computing log-probabilities, and counting learnable parameters.
Notes
This class is abstract and cannot be instantiated directly. All abstract methods must be implemented by subclasses.
Source code in illia/distributions/jax/base.py
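The contract above can be illustrated with a small, self-contained sketch. It does not import illia: DistributionSketch is a hypothetical stand-in that mirrors the three abstract members, and its sample method takes a raw jax.random key in place of the nnx.Rngs container used by the real signature. StandardNormal is a toy subclass with no learnable parameters.

```python
import abc
from typing import Optional

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class DistributionSketch(abc.ABC):
    """Hypothetical stand-in mirroring the DistributionModule interface.

    Uses a raw jax.random key instead of an nnx.Rngs container.
    """

    @abc.abstractmethod
    def sample(self, key: jax.Array) -> jax.Array:
        """Draw a sample using the given PRNG key."""

    @abc.abstractmethod
    def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
        """Return a scalar log-probability for x (or for a fresh sample)."""

    @property
    @abc.abstractmethod
    def num_params(self) -> int:
        """Return the number of learnable parameters."""


class StandardNormal(DistributionSketch):
    """Toy fixed N(0, I) distribution over a vector of size `dim`."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def sample(self, key: jax.Array) -> jax.Array:
        return jax.random.normal(key, (self.dim,))

    def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
        if x is None:
            # Draw internally; a fixed key keeps this toy example deterministic.
            x = self.sample(jax.random.PRNGKey(0))
        # Independent dimensions: sum the per-element log-densities.
        return jnp.sum(norm.logpdf(x))

    @property
    def num_params(self) -> int:
        # Nothing is learnable in this toy distribution.
        return 0
```

Instantiating DistributionSketch directly raises a TypeError, which matches the note that the base class cannot be instantiated until every abstract member is implemented.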
1.1.1 num_params
abstractmethod, property
Return the number of learnable parameters in the distribution.
Returns:
| Type | Description |
|---|---|
| int | Total count of learnable parameters. |
1.1.2 log_prob(x=None)
abstractmethod
Compute the log-probability of a provided sample. If no sample is passed, one is drawn internally.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Optional[Array] | Optional sample to evaluate. If None, a new sample is drawn from the distribution. | None |
Returns:
| Type | Description |
|---|---|
| jax.Array | Scalar log-probability value. |
Notes
Works with both user-supplied and internally drawn samples.
Source code in illia/distributions/jax/base.py
1.1.3 sample(rngs=nnx.Rngs(0))
abstractmethod
Draw a sample from the distribution using the given RNG.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rngs | Rngs | RNG container used for sampling. | Rngs(0) |
Returns:
| Type | Description |
|---|---|
| jax.Array | A sample drawn from the distribution. |
Notes
Sampling should be reproducible given the same RNG.
Source code in illia/distributions/jax/base.py
1.2
GaussianDistribution
Learnable Gaussian distribution with diagonal covariance.
Represents a Gaussian with a trainable mean and a trainable standard deviation. The standard deviation is not stored directly: it is derived from an unconstrained parameter rho through a softplus transformation, which guarantees that it stays positive.
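A quick numeric check of the softplus mapping, assuming the standard deviation is exactly softplus(rho) with no extra offset or scaling. Under that assumption, the default rho_init = -7.0 gives an initial standard deviation of roughly 9.1e-4, so the distribution starts out very narrow around its mean.

```python
import jax
import jax.numpy as jnp

# Softplus maps any real rho to a strictly positive standard deviation:
#   sigma = log(1 + exp(rho))
rho = jnp.array([-7.0, 0.0, 3.0])
sigma = jax.nn.softplus(rho)

# Same quantity written as logaddexp(0, rho) = log(exp(0) + exp(rho)),
# which stays numerically stable for large |rho|.
sigma_manual = jnp.logaddexp(0.0, rho)
```

Because softplus is smooth and strictly increasing, gradient updates on rho move the standard deviation without ever pushing it to zero or below.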
Notes
Assumes diagonal covariance. The KL divergence can be estimated from differences of log-probabilities computed with log_prob.
Source code in illia/distributions/jax/gaussian.py
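The note about estimating KL from log-probability differences corresponds to the standard Monte Carlo estimator KL(q || p) ≈ mean of log q(w) - log p(w) over samples w ~ q. The sketch below implements that estimator for diagonal Gaussians in plain JAX and compares it with the closed form; mc_kl and exact_kl are hypothetical helpers, not illia functions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def mc_kl(key, mu_q, sigma_q, mu_p, sigma_p, num_samples=10_000):
    """Monte Carlo estimate of KL(q || p) for diagonal Gaussians.

    Draws w ~ q and averages log q(w) - log p(w), summed over dimensions.
    """
    eps = jax.random.normal(key, (num_samples,) + mu_q.shape)
    w = mu_q + sigma_q * eps  # samples from q
    event_axes = tuple(range(1, w.ndim))
    log_q = jnp.sum(norm.logpdf(w, mu_q, sigma_q), axis=event_axes)
    log_p = jnp.sum(norm.logpdf(w, mu_p, sigma_p), axis=event_axes)
    return jnp.mean(log_q - log_p)


def exact_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(q || p) for diagonal Gaussians, summed over dimensions."""
    return jnp.sum(
        jnp.log(sigma_p / sigma_q)
        + (sigma_q**2 + (mu_q - mu_p) ** 2) / (2 * sigma_p**2)
        - 0.5
    )
```

For example, with mu_q = [0.5, -0.2] and sigma_q = [0.3, 0.2] against a standard normal, the closed form gives about 2.023, and the 10,000-sample estimate lands close to it. The estimator is useful when no closed form is available, at the cost of sampling noise.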
1.2.1 num_params
property
Return the number of learnable parameters in the distribution.
Returns:
| Type | Description |
|---|---|
| int | Total count of learnable parameters. |
1.2.2 __call__()
Perform a forward pass by drawing a sample.
Returns:
| Type | Description |
|---|---|
| jax.Array | A sample from the distribution. |
Source code in illia/distributions/jax/gaussian.py
1.2.3 __init__(shape, mu_prior=0.0, std_prior=0.1, mu_init=0.0, rho_init=-7.0, rngs=nnx.Rngs(0), **kwargs)
Initialize a learnable Gaussian distribution module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shape | tuple[int, ...] | Shape of the learnable parameters. | required |
| mu_prior | float | Mean of the Gaussian prior. | 0.0 |
| std_prior | float | Standard deviation of the Gaussian prior. | 0.1 |
| mu_init | float | Initial value for the learnable mean. | 0.0 |
| rho_init | float | Initial value for the learnable rho, from which the standard deviation is derived via softplus. | -7.0 |
| rngs | Rngs | RNG container for parameter initialization. | Rngs(0) |
| **kwargs | Any | Extra arguments passed to the base class. | {} |
Returns:
| Type | Description |
|---|---|
| None | None |
Source code in illia/distributions/jax/gaussian.py
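To make the roles of the constructor arguments concrete, here is a hypothetical plain-JAX stand-in (GaussianSketch, not part of illia). It assumes that mu and rho are each learnable arrays of the given shape, filled with mu_init and rho_init, and that num_params counts both. The real module may instead use its rngs to randomize the initial values, and its exact parameter count is not stated in this reference.

```python
import jax
import jax.numpy as jnp


class GaussianSketch:
    """Hypothetical plain-JAX stand-in for GaussianDistribution's state.

    Assumes mu and rho are both learnable arrays of `shape`, filled with
    mu_init and rho_init; the real module may use its RNGs to perturb them.
    """

    def __init__(self, shape, mu_prior=0.0, std_prior=0.1, mu_init=0.0, rho_init=-7.0):
        # Fixed prior hyperparameters (not learnable).
        self.mu_prior = mu_prior
        self.std_prior = std_prior
        # Learnable posterior parameters, one entry per element of `shape`.
        self.mu = jnp.full(shape, mu_init)
        self.rho = jnp.full(shape, rho_init)

    @property
    def std(self) -> jax.Array:
        # Positive standard deviation derived from rho via softplus.
        return jax.nn.softplus(self.rho)

    @property
    def num_params(self) -> int:
        # Under the assumption above: one mean and one rho per element.
        return int(self.mu.size + self.rho.size)
```

With shape=(3, 4), this sketch reports 24 learnable values: 12 means and 12 rho entries.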
1.2.4 log_prob(x=None)
Compute the log-probability of a given sample. If no sample is provided, one is drawn internally.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Optional[Array] | Optional input sample to evaluate. If None, a new sample is drawn from the distribution. | None |
Returns:
| Type | Description |
|---|---|
| jax.Array | Scalar log-probability value. |
Notes
Supports both user-supplied and internally drawn samples.
Source code in illia/distributions/jax/gaussian.py
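The optional-sample behaviour described above can be sketched for a diagonal Gaussian as follows. diag_gaussian_log_prob is a hypothetical helper that evaluates only the distribution defined by mu and rho; whether the library's log_prob also involves the prior (mu_prior, std_prior) is not stated in this reference, so treat the sketch as illustrating the interface, not the exact computation.

```python
from typing import Optional

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def diag_gaussian_log_prob(
    key: jax.Array,
    mu: jax.Array,
    rho: jax.Array,
    x: Optional[jax.Array] = None,
) -> jax.Array:
    """Scalar log-density of x under N(mu, softplus(rho)^2), diagonal.

    If x is None, a sample is drawn from the same distribution first.
    """
    sigma = jax.nn.softplus(rho)
    if x is None:
        x = mu + sigma * jax.random.normal(key, mu.shape)
    # Independent dimensions: the joint log-density is the sum of the marginals.
    return jnp.sum(norm.logpdf(x, mu, sigma))
```

Summing over all elements is what turns per-element densities into the single scalar listed in the Returns table.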
1.2.5 sample(rngs=nnx.Rngs(0))
Draw a sample from the Gaussian distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rngs | Rngs | RNG container used for sampling. | Rngs(0) |
Returns:
| Type | Description |
|---|---|
| jax.Array | A sample drawn from the distribution. |
Notes
Sampling is reproducible with the same RNG.
Source code in illia/distributions/jax/gaussian.py
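A minimal sketch of drawing from a learnable diagonal Gaussian, assuming the usual reparameterization w = mu + softplus(rho) * eps with eps ~ N(0, I); this reference does not confirm that the library samples this way. reparam_sample is a hypothetical helper that takes a raw jax.random key instead of an nnx.Rngs container, which also makes the reproducibility note easy to check: the same key always yields the same sample.

```python
import jax
import jax.numpy as jnp


def reparam_sample(key: jax.Array, mu: jax.Array, rho: jax.Array) -> jax.Array:
    """Draw w = mu + softplus(rho) * eps with eps ~ N(0, I).

    Writing the sample as a deterministic function of (mu, rho) and noise
    lets gradients flow back to mu and rho.
    """
    eps = jax.random.normal(key, mu.shape)
    return mu + jax.nn.softplus(rho) * eps
```

Because the noise is isolated in eps, jax.grad of any loss built from the sample reaches mu and rho directly; for instance, the gradient of the summed sample with respect to mu is a vector of ones.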