
1. Distributions

Base class for building distribution modules using Flax and JAX.

Provides a standardized interface for sampling, computing log probabilities, and reporting the number of parameters in custom probabilistic layers.

Notes

This class is abstract and should not be instantiated directly. Subclasses must implement all abstract methods to specify distribution behavior.

1.1 DistributionModule

Abstract base for Flax-based probabilistic distribution modules.

Defines a required interface for sampling, computing log-probabilities, and retrieving parameter counts. Subclasses must implement all abstract methods to provide specific distribution logic.

Notes

Avoid direct instantiation; this class serves as a blueprint for derived classes.
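The contract can be sketched with plain Python's `abc` module (a Flax-free sketch for brevity; the real base class builds on Flax/NNX, and the method names below come from the sections that follow):

```python
from abc import ABC, abstractmethod


class DistributionModule(ABC):
    """Blueprint mirroring the interface documented below."""

    @property
    @abstractmethod
    def num_params(self) -> int:
        """Total number of learnable parameters."""

    @abstractmethod
    def sample(self):
        """Draw a sample from the distribution."""

    @abstractmethod
    def log_prob(self, x=None):
        """Log-probability of x (or of a fresh internal sample)."""


# Direct instantiation fails, as the Notes above warn.
try:
    DistributionModule()
except TypeError as err:
    print("cannot instantiate:", err)
```

Any concrete subclass must implement all three members before it can be instantiated.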

1.1.1 num_params abstractmethod property

Returns the total number of learnable parameters in the distribution.

Returns:

| Type | Description |
| --- | --- |
| `int` | Integer representing the total number of learnable parameters. |

1.1.2 log_prob(x=None) abstractmethod

Computes the log-probability of an input sample. If no sample is provided, a new one is drawn internally from the current distribution before computing the log-probability.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Optional[Array]` | Optional sample tensor to evaluate. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Scalar tensor representing the computed log-probability. |

Notes

This method supports both user-supplied samples and internally generated ones for convenience when evaluating likelihoods.

Source code in illia/distributions/jax/base.py
@abstractmethod
def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
    """
    Computes the log-probability of an input sample.
    If no sample is provided, a new one is drawn internally from the
    current distribution before computing the log-probability.

    Args:
        x: Optional sample tensor to evaluate.

    Returns:
        Scalar tensor representing the computed log-probability.

    Notes:
        This method supports both user-supplied samples and internally
        generated ones for convenience when evaluating likelihoods.
    """

1.1.3 sample(rngs=nnx.Rngs(0)) abstractmethod

Generates and returns a sample from the underlying distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rngs` | `Rngs` | RNG container used for sampling. | `Rngs(0)` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Sample array matching the shape and structure defined by the distribution parameters. |

Source code in illia/distributions/jax/base.py
@abstractmethod
def sample(self, rngs: Rngs = nnx.Rngs(0)) -> jax.Array:
    """
    Generates and returns a sample from the underlying distribution.

    Args:
        rngs: RNG container used for sampling.

    Returns:
        Sample array matching the shape and structure defined by
        the distribution parameters.
    """

Defines a Gaussian (Normal) distribution using Flax with trainable mean and standard deviation parameters. Includes methods for sampling from the distribution and computing log-probabilities of given inputs.

1.2 GaussianDistribution(shape, mu_prior=0.0, std_prior=0.1, mu_init=0.0, rho_init=-7.0, rngs=nnx.Rngs(0), **kwargs)

Learnable Gaussian distribution using Flax.

Represents a diagonal Gaussian distribution with trainable mean and standard deviation parameters. The standard deviation is derived from rho using a softplus transformation to ensure positivity.

Notes

Assumes a diagonal covariance matrix. KL divergence between distributions can be computed using log-probability differences obtained from log_prob.
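The softplus transform mentioned above, `sigma = log(1 + exp(rho))`, maps any real `rho` to a strictly positive `sigma`; with the default `rho_init = -7.0`, the initial standard deviation is roughly `exp(-7) ≈ 9e-4`. A quick plain-Python check of the transform:

```python
import math


def softplus(rho: float) -> float:
    # log1p(exp(rho)): the same transform the JAX code below applies to rho.
    return math.log1p(math.exp(rho))


# Always positive, even for very negative rho.
for rho in (-7.0, 0.0, 3.0):
    assert softplus(rho) > 0.0

# Near the default initialization, softplus(rho) is approximately exp(rho).
print(softplus(-7.0))  # ~0.000911
```

This is why a strongly negative `rho_init` yields a distribution that starts out tightly concentrated around its mean.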

Initializes the GaussianDistribution module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `shape` | `tuple[int, ...]` | Shape of the learnable parameters. | *required* |
| `mu_prior` | `float` | Mean of the Gaussian prior. | `0.0` |
| `std_prior` | `float` | Standard deviation of the prior. | `0.1` |
| `mu_init` | `float` | Initial value for the learnable mean. | `0.0` |
| `rho_init` | `float` | Initial value for the learnable rho parameter. | `-7.0` |
| `rngs` | `Rngs` | RNG container for parameter initialization. | `Rngs(0)` |
| `**kwargs` | `Any` | Additional arguments passed to the base class. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `None` | None. |

Source code in illia/distributions/jax/gaussian.py
def __init__(
    self,
    shape: tuple[int, ...],
    mu_prior: float = 0.0,
    std_prior: float = 0.1,
    mu_init: float = 0.0,
    rho_init: float = -7.0,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Initializes the GaussianDistribution module.

    Args:
        shape: Shape of the learnable parameters.
        mu_prior: Mean of the Gaussian prior.
        std_prior: Standard deviation of the prior.
        mu_init: Initial value for the learnable mean.
        rho_init: Initial value for the learnable rho parameter.
        rngs: RNG container for parameter initialization.
        **kwargs: Additional arguments passed to the base class.

    Returns:
        None.
    """

    # Call super-class constructor
    super().__init__(**kwargs)

    # Set attributes
    self.shape = shape
    self.mu_prior = mu_prior
    self.std_prior = std_prior
    self.mu_init = mu_init
    self.rho_init = rho_init

    # Define initial mu and rho
    self.mu = nnx.Param(
        self.mu_init + 0.1 * jax.random.normal(rngs.params(), self.shape)
    )
    self.rho = nnx.Param(
        self.rho_init + 0.1 * jax.random.normal(rngs.params(), self.shape)
    )

1.2.1 num_params property

Returns the total number of learnable parameters in the distribution.

Returns:

| Type | Description |
| --- | --- |
| `int` | Integer representing the total number of learnable parameters. |

1.2.2 __call__()

Performs a forward pass by sampling from the distribution.

Returns:

| Type | Description |
| --- | --- |
| `Array` | A sample from the distribution. |

Source code in illia/distributions/jax/gaussian.py
def __call__(self) -> jax.Array:
    """
    Performs a forward pass by sampling from the distribution.

    Returns:
        A sample from the distribution.
    """

    return self.sample()

1.2.3 log_prob(x=None)

Computes the log-probability of an input sample. If no sample is provided, a new one is drawn internally from the current distribution before computing the log-probability.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Optional[Array]` | Optional sample tensor to evaluate. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Scalar tensor representing the computed log-probability. |

Notes

This method supports both user-supplied samples and internally generated ones for convenience when evaluating likelihoods.

Source code in illia/distributions/jax/gaussian.py
def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
    """
    Computes the log-probability of an input sample.
    If no sample is provided, a new one is drawn internally from the
    current distribution before computing the log-probability.

    Args:
        x: Optional sample tensor to evaluate.

    Returns:
        Scalar tensor representing the computed log-probability.

    Notes:
        This method supports both user-supplied samples and internally
        generated ones for convenience when evaluating likelihoods.
    """

    # Sample if x is None
    if x is None:
        x = self.sample()

    # Define pi variable
    pi: jax.Array = jnp.acos(jnp.zeros(1)) * 2

    # Compute log priors
    log_prior = (
        -jnp.log(jnp.sqrt(2 * pi))
        - jnp.log(self.std_prior)
        - (((x - self.mu_prior) ** 2) / (2 * self.std_prior**2))
        - 0.5
    )

    # Compute sigma
    sigma: jax.Array = jnp.log1p(jnp.exp(jnp.asarray(self.rho)))

    # Compute log posteriors
    log_posteriors = (
        -jnp.log(jnp.sqrt(2 * pi))
        - jnp.log(sigma)
        - (((x - self.mu) ** 2) / (2 * sigma**2))
        - 0.5
    )

    # Compute final log probs
    log_probs = log_posteriors.sum() - log_prior.sum()

    return log_probs
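For a single scalar, the per-element expression above is the usual Gaussian log-density plus a constant `-0.5` that appears in both the prior and posterior terms, so it cancels in the returned difference. A plain-Python mirror of that arithmetic (the values of `x`, `mu`, and `sigma` below are hypothetical, chosen only for illustration):

```python
import math


def gaussian_term(x: float, mu: float, sigma: float) -> float:
    # Same per-element expression as the JAX code above.
    return (
        -math.log(math.sqrt(2.0 * math.pi))
        - math.log(sigma)
        - (x - mu) ** 2 / (2.0 * sigma**2)
        - 0.5
    )


x = 0.05
log_prior = gaussian_term(x, mu=0.0, sigma=0.1)        # mu_prior, std_prior
log_posterior = gaussian_term(x, mu=0.02, sigma=0.05)  # learned mu, softplus(rho)

# The method returns posterior minus prior; the -0.5 constants cancel.
diff = log_posterior - log_prior
print(diff)  # ~0.6381
```

Summed over all elements, this posterior-minus-prior difference is exactly the quantity used to estimate the KL divergence mentioned in the Notes.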

1.2.4 sample(rngs=nnx.Rngs(0))

Generates and returns a sample from the Gaussian distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `rngs` | `Rngs` | RNG container used for sampling. | `Rngs(0)` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Sample array matching the shape and structure defined by the distribution parameters. |

Source code in illia/distributions/jax/gaussian.py
def sample(self, rngs: Rngs = nnx.Rngs(0)) -> jax.Array:
    """
    Generates and returns a sample from the Gaussian distribution.

    Args:
        rngs: RNG container used for sampling.

    Returns:
        Sample array matching the shape and structure defined by
        the distribution parameters.
    """

    # Compute epsilon and sigma
    eps: jax.Array = jax.random.normal(rngs.params(), self.rho.shape)
    sigma: jax.Array = jnp.log1p(jnp.exp(jnp.asarray(self.rho)))

    return self.mu + sigma * eps
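The return line above is the reparameterization trick: a sample is written as `mu + sigma * eps` with `eps ~ N(0, 1)`, so gradients can flow through `mu` and `rho` while the randomness is isolated in `eps`. With a fixed `eps` the arithmetic is deterministic (plain-Python sketch with hypothetical values):

```python
import math

mu = 1.0
rho = -7.0
sigma = math.log1p(math.exp(rho))  # softplus, as in the code above

eps = 2.0  # would normally be drawn from N(0, 1)
x = mu + sigma * eps

# sigma is tiny at rho = -7, so the sample stays very close to mu.
print(x)  # ~1.0018
```

Because `sample()` re-derives `sigma` from `rho` on every call, updating `rho` during training directly changes how spread out future samples are.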