1. Distributions
1.1 gaussian
This module implements a learnable Gaussian distribution for the JAX backend of illia.
1.1.1 GaussianDistribution(shape, mu_prior=0.0, std_prior=0.1, mu_init=0.0, rho_init=-7.0, rngs=nnx.Rngs(0))
A learnable Gaussian distribution whose mean and standard deviation (through rho) are trained, with a Gaussian prior set by mu_prior and std_prior.
The constructor accepts the following arguments.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
shape | tuple[int, ...] | The shape of the learnable parameters. | required |
mu_prior | float | The mean of the prior distribution. | 0.0 |
std_prior | float | The standard deviation of the prior distribution. | 0.1 |
mu_init | float | The initial value of the mean. | 0.0 |
rho_init | float | The initial value of rho, which determines the initial standard deviation. | -7.0 |
rngs | Rngs | Flax NNX RNG container. | nnx.Rngs(0) |
Source code in illia/distributions/jax/gaussian.py (lines 23-58)
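This section only states that rho_init "affects" the initial standard deviation. A common convention for a mean/rho parameterization (as in Bayes by Backprop) is std = softplus(rho) = log(1 + exp(rho)), which keeps the standard deviation positive while rho stays unconstrained. The framework-free sketch below assumes that mapping; it is not taken from the illia source, and the helper name softplus is illustrative.

```python
import math


def softplus(rho: float) -> float:
    """Numerically stable softplus, log(1 + exp(rho)); always positive.

    Assumption: the distribution maps rho to std with this function.
    """
    if rho > 0:
        # Rewrite to avoid overflow in exp() for large positive rho.
        return rho + math.log1p(math.exp(-rho))
    return math.log1p(math.exp(rho))


# Standard deviation implied by a few rho values under the assumed mapping.
for rho in (-7.0, -3.0, 0.0):
    print(f"rho={rho:5.1f} -> std={softplus(rho):.6f}")
```

Under this assumption, the default rho_init = -7.0 starts the distribution with a standard deviation of roughly 9.1e-4, so early samples stay very close to mu_init.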
1.1.1.1 num_params (property)
Returns the number of parameters in the module.
Returns:
Type | Description |
---|---|
int | The number of parameters. |
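This section does not say how num_params counts. If the module stores exactly two learnable arrays of the given shape, a mean and a rho, the count would be twice the number of elements in shape. The hypothetical count_params helper below only makes that assumed arithmetic explicit; the real definition is in the source file.

```python
import math


def count_params(shape: tuple[int, ...], tensors_per_shape: int = 2) -> int:
    """Hypothetical parameter count for a mean/rho Gaussian.

    Assumes `tensors_per_shape` learnable arrays (mu and rho by default),
    each with `shape`, so the total is that many copies of prod(shape).
    """
    return tensors_per_shape * math.prod(shape)


print(count_params((3, 4)))  # 2 arrays of 12 elements each
```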
1.1.1.2 __call__()
Performs the forward pass of the module, returning a sample drawn from the distribution.
Returns:
Type | Description |
---|---|
Array | A sampled JAX array. |
Source code in illia/distributions/jax/gaussian.py (lines 77-85)
1.1.1.3 log_prob(x=None)
Computes the log probability of a given sample.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | Optional[Array] | An optional sample. If None, a new sample is drawn from the distribution. | None |
Returns:
Type | Description |
---|---|
Array | The log probability of the sample as a JAX array. |
Source code in illia/distributions/jax/gaussian.py (lines 87-128)
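For reference, the log density of a diagonal Gaussian summed over all entries is sum_i [-0.5 log(2π) - log σ_i - (x_i - μ_i)^2 / (2σ_i^2)]. This section does not say whether log_prob returns this posterior log density alone or combines it with the prior set by mu_prior and std_prior; a common pattern in Bayes by Backprop is the one-sample KL estimate log q(x) - log p(x). The plain-Python sketch below shows both under those assumptions; the function names are hypothetical and lists stand in for JAX arrays.

```python
import math
from typing import Sequence


def gaussian_log_prob(x: Sequence[float], mu: Sequence[float], std: Sequence[float]) -> float:
    """Sum over entries of log N(x_i; mu_i, std_i**2) for a diagonal Gaussian."""
    if not (len(x) == len(mu) == len(std)):
        raise ValueError("x, mu and std must have the same length")
    log_norm = -0.5 * math.log(2.0 * math.pi)  # constant term per entry
    total = 0.0
    for xi, mi, si in zip(x, mu, std):
        total += log_norm - math.log(si) - (xi - mi) ** 2 / (2.0 * si**2)
    return total


def kl_single_sample(
    x: Sequence[float],
    mu: Sequence[float],
    std: Sequence[float],
    mu_prior: float,
    std_prior: float,
) -> float:
    """One-sample Monte Carlo KL estimate, log q(x) - log p(x).

    q is the learnable posterior N(mu, std**2); p is assumed to be a fixed
    prior N(mu_prior, std_prior**2) shared by every entry.
    """
    n = len(x)
    log_q = gaussian_log_prob(x, mu, std)
    log_p = gaussian_log_prob(x, [mu_prior] * n, [std_prior] * n)
    return log_q - log_p


x = [0.01, -0.02]
print(gaussian_log_prob(x, mu=[0.0, 0.0], std=[0.1, 0.1]))
print(kl_single_sample(x, mu=[0.0, 0.0], std=[0.01, 0.01], mu_prior=0.0, std_prior=0.1))
```

When the posterior equals the prior, the KL estimate is exactly zero for any sample, which is a quick sanity check for either variant.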
1.1.1.4 sample(rngs=nnx.Rngs(0))
Draws a sample from the distribution using its current parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
rngs | Rngs | Flax NNX RNG container. | nnx.Rngs(0) |
Returns:
Type | Description |
---|---|
Array | A sampled JAX array. |
Source code in illia/distributions/jax/gaussian.py (lines 60-75)
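Sampling a learnable Gaussian is commonly done with the reparameterization trick, x = mu + std * eps with eps drawn from N(0, 1), so that gradients can flow to mu and rho. The sketch below assumes that approach and the softplus mapping from rho to std; it is not the illia implementation. A seeded random.Random stands in for the nnx.Rngs container: reusing the same seed reproduces the same draw, which is the role the rngs argument plays.

```python
import math
import random


def softplus(rho: float) -> float:
    """Stable log(1 + exp(rho)); assumed mapping from rho to std."""
    if rho > 0:
        return rho + math.log1p(math.exp(-rho))
    return math.log1p(math.exp(rho))


def sample(mu: list[float], rho: list[float], rng: random.Random) -> list[float]:
    """Reparameterized draw: x = mu + softplus(rho) * eps, eps ~ N(0, 1)."""
    return [m + softplus(r) * rng.gauss(0.0, 1.0) for m, r in zip(mu, rho)]


mu = [0.0, 1.0, -2.0]
rho = [-7.0, -7.0, -7.0]
a = sample(mu, rho, random.Random(0))
b = sample(mu, rho, random.Random(0))
print(a == b)  # True: same seed, same draw
```

Because the randomness lives entirely in eps, the sample is a deterministic, differentiable function of mu and rho once the noise is fixed.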