1. Distributions

1.1 DistributionModule

Abstract base for probabilistic distribution modules in JAX. Defines the required interface for sampling, computing log-probabilities, and counting learnable parameters.

Notes

This class is abstract and cannot be instantiated directly. All abstract methods must be implemented by subclasses.

Source code in illia/distributions/jax/base.py
class DistributionModule(nnx.Module, ABC):
    """
    Abstract base for probabilistic distribution modules in JAX.
    Defines the required interface for sampling, computing
    log-probabilities, and counting learnable parameters.

    Notes:
        This class is abstract and cannot be instantiated directly.
        All abstract methods must be implemented by subclasses.
    """

    @abstractmethod
    def sample(self, rngs: Rngs = nnx.Rngs(0)) -> jax.Array:
        """
        Draw a sample from the distribution using the given RNG.

        Args:
            rngs: RNG container used for sampling.

        Returns:
            jax.Array: A sample drawn from the distribution.

        Notes:
            Sampling should be reproducible given the same RNG.
        """

    @abstractmethod
    def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
        """
        Compute the log-probability of a provided sample. If no
        sample is passed, one is drawn internally.

        Args:
            x: Optional sample to evaluate. If None, a new sample is
                drawn from the distribution.

        Returns:
            jax.Array: Scalar log-probability value.

        Notes:
            Works with both user-supplied and internally drawn
            samples.
        """

    @property
    @abstractmethod
    def num_params(self) -> int:
        """
        Return the number of learnable parameters in the
        distribution.

        Returns:
            int: Total count of learnable parameters.
        """

1.1.1 num_params abstractmethod property

Return the number of learnable parameters in the distribution.

Returns:

    int: Total count of learnable parameters.

1.1.2 log_prob(x=None) abstractmethod

Compute the log-probability of a provided sample. If no sample is passed, one is drawn internally.

Parameters:

    x (Optional[jax.Array], default None): Optional sample to evaluate. If None, a new sample is drawn from the distribution.

Returns:

    jax.Array: Scalar log-probability value.

Notes

Works with both user-supplied and internally drawn samples.

Source code in illia/distributions/jax/base.py
@abstractmethod
def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
    """
    Compute the log-probability of a provided sample. If no
    sample is passed, one is drawn internally.

    Args:
        x: Optional sample to evaluate. If None, a new sample is
            drawn from the distribution.

    Returns:
        jax.Array: Scalar log-probability value.

    Notes:
        Works with both user-supplied and internally drawn
        samples.
    """

1.1.3 sample(rngs=nnx.Rngs(0)) abstractmethod

Draw a sample from the distribution using the given RNG.

Parameters:

    rngs (Rngs, default nnx.Rngs(0)): RNG container used for sampling.

Returns:

    jax.Array: A sample drawn from the distribution.

Notes

Sampling should be reproducible given the same RNG.

Source code in illia/distributions/jax/base.py
@abstractmethod
def sample(self, rngs: Rngs = nnx.Rngs(0)) -> jax.Array:
    """
    Draw a sample from the distribution using the given RNG.

    Args:
        rngs: RNG container used for sampling.

    Returns:
        jax.Array: A sample drawn from the distribution.

    Notes:
        Sampling should be reproducible given the same RNG.
    """

1.2 GaussianDistribution

Learnable Gaussian distribution with diagonal covariance. Represents a Gaussian with trainable mean and standard deviation. The standard deviation is derived from rho using a softplus transformation to ensure positivity.

Notes

Assumes diagonal covariance. KL divergence can be estimated via log-probability differences from log_prob.

Source code in illia/distributions/jax/gaussian.py
class GaussianDistribution(DistributionModule):
    """
    Learnable Gaussian distribution with diagonal covariance.
    Represents a Gaussian with trainable mean and standard
    deviation. The standard deviation is derived from `rho`
    using a softplus transformation to ensure positivity.

    Notes:
        Assumes diagonal covariance. KL divergence can be
        estimated via log-probability differences from
        `log_prob`.
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        mu_prior: float = 0.0,
        std_prior: float = 0.1,
        mu_init: float = 0.0,
        rho_init: float = -7.0,
        rngs: Rngs = nnx.Rngs(0),
        **kwargs: Any,
    ) -> None:
        """
        Initialize a learnable Gaussian distribution module.

        Args:
            shape: Shape of the learnable parameters.
            mu_prior: Mean of the Gaussian prior.
            std_prior: Standard deviation of the prior.
            mu_init: Initial value for the learnable mean.
            rho_init: Initial value for the learnable rho.
            rngs: RNG container for parameter initialization.
            **kwargs: Extra arguments passed to the base class.

        Returns:
            None
        """

        super().__init__(**kwargs)

        self.shape = shape
        self.mu_prior = mu_prior
        self.std_prior = std_prior
        self.mu_init = mu_init
        self.rho_init = rho_init

        # Define initial mu and rho
        self.mu = nnx.Param(
            self.mu_init + 0.1 * jax.random.normal(rngs.params(), self.shape)
        )
        self.rho = nnx.Param(
            self.rho_init + 0.1 * jax.random.normal(rngs.params(), self.shape)
        )

    def sample(self, rngs: Rngs = nnx.Rngs(0)) -> jax.Array:
        """
        Draw a sample from the Gaussian distribution.

        Args:
            rngs: RNG container used for sampling.

        Returns:
            jax.Array: A sample drawn from the distribution.

        Notes:
            Sampling is reproducible with the same RNG.
        """

        # Compute epsilon and sigma
        eps: jax.Array = jax.random.normal(rngs.params(), self.rho.shape)
        sigma: jax.Array = jnp.log1p(jnp.exp(jnp.asarray(self.rho)))

        return self.mu + sigma * eps

    def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
        """
        Compute the log-probability of a given sample. If no
        sample is provided, one is drawn internally.

        Args:
            x: Optional input sample to evaluate. If None,
                a new sample is drawn from the distribution.

        Returns:
            jax.Array: Scalar log-probability value.

        Notes:
            Supports both user-supplied and internally drawn
            samples.
        """

        # Sample if x is None
        if x is None:
            x = self.sample()

        # Define pi variable
        pi: jax.Array = jnp.acos(jnp.zeros(1)) * 2

        # Compute log priors
        log_prior = (
            -jnp.log(jnp.sqrt(2 * pi))
            - jnp.log(self.std_prior)
            - (((x - self.mu_prior) ** 2) / (2 * self.std_prior**2))
            - 0.5
        )

        # Compute sigma
        sigma: jax.Array = jnp.log1p(jnp.exp(jnp.asarray(self.rho)))

        # Compute log posteriors
        log_posteriors = (
            -jnp.log(jnp.sqrt(2 * pi))
            - jnp.log(sigma)
            - (((x - self.mu) ** 2) / (2 * sigma**2))
            - 0.5
        )

        # Compute final log probs
        log_probs = log_posteriors.sum() - log_prior.sum()

        return log_probs

    @property
    def num_params(self) -> int:
        """
        Return the number of learnable parameters in the
        distribution.

        Returns:
            int: Total count of learnable parameters.
        """

        return len(self.mu.reshape(-1))

    def __call__(self) -> jax.Array:
        """
        Perform a forward pass by drawing a sample.

        Returns:
            jax.Array: A sample from the distribution.
        """

        return self.sample()
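
A short usage sketch, assuming the module path illia/distributions/jax/gaussian.py shown above (variable names are illustrative):

from flax import nnx

from illia.distributions.jax.gaussian import GaussianDistribution

dist = GaussianDistribution(shape=(3, 2), rngs=nnx.Rngs(params=42))
w = dist.sample(nnx.Rngs(params=0))  # reparameterized draw, shape (3, 2)
lp = dist.log_prob(w)                # scalar: summed log-posterior minus log-prior
print(dist.num_params)               # 6, the number of elements in mu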

1.2.1 num_params property

Return the number of learnable parameters in the distribution.

Returns:

    int: Total count of learnable parameters.

1.2.2 __call__()

Perform a forward pass by drawing a sample.

Returns:

    jax.Array: A sample from the distribution.

Source code in illia/distributions/jax/gaussian.py
def __call__(self) -> jax.Array:
    """
    Perform a forward pass by drawing a sample.

    Returns:
        jax.Array: A sample from the distribution.
    """

    return self.sample()

1.2.3 __init__(shape, mu_prior=0.0, std_prior=0.1, mu_init=0.0, rho_init=-7.0, rngs=nnx.Rngs(0), **kwargs)

Initialize a learnable Gaussian distribution module.

Parameters:

    shape (tuple[int, ...], required): Shape of the learnable parameters.
    mu_prior (float, default 0.0): Mean of the Gaussian prior.
    std_prior (float, default 0.1): Standard deviation of the prior.
    mu_init (float, default 0.0): Initial value for the learnable mean.
    rho_init (float, default -7.0): Initial value for the learnable rho.
    rngs (Rngs, default nnx.Rngs(0)): RNG container for parameter initialization.
    **kwargs (Any): Extra arguments passed to the base class.

Returns:

    None

Source code in illia/distributions/jax/gaussian.py
def __init__(
    self,
    shape: tuple[int, ...],
    mu_prior: float = 0.0,
    std_prior: float = 0.1,
    mu_init: float = 0.0,
    rho_init: float = -7.0,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Initialize a learnable Gaussian distribution module.

    Args:
        shape: Shape of the learnable parameters.
        mu_prior: Mean of the Gaussian prior.
        std_prior: Standard deviation of the prior.
        mu_init: Initial value for the learnable mean.
        rho_init: Initial value for the learnable rho.
        rngs: RNG container for parameter initialization.
        **kwargs: Extra arguments passed to the base class.

    Returns:
        None
    """

    super().__init__(**kwargs)

    self.shape = shape
    self.mu_prior = mu_prior
    self.std_prior = std_prior
    self.mu_init = mu_init
    self.rho_init = rho_init

    # Define initial mu and rho
    self.mu = nnx.Param(
        self.mu_init + 0.1 * jax.random.normal(rngs.params(), self.shape)
    )
    self.rho = nnx.Param(
        self.rho_init + 0.1 * jax.random.normal(rngs.params(), self.shape)
    )
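
Since sigma = softplus(rho) = log(1 + exp(rho)), the default rho_init = -7.0 starts the posterior with a very small standard deviation while mu starts near mu_init. A quick illustrative check of that relationship:

import jax.numpy as jnp

sigma0 = jnp.log1p(jnp.exp(-7.0))  # softplus of the default rho_init
print(sigma0)                      # ~9.1e-4: samples start tightly around mu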

1.2.4 log_prob(x=None)

Compute the log-probability of a given sample. If no sample is provided, one is drawn internally.

Parameters:

    x (Optional[jax.Array], default None): Optional input sample to evaluate. If None, a new sample is drawn from the distribution.

Returns:

    jax.Array: Scalar log-probability value.

Notes

Supports both user-supplied and internally drawn samples.

Source code in illia/distributions/jax/gaussian.py
def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
    """
    Compute the log-probability of a given sample. If no
    sample is provided, one is drawn internally.

    Args:
        x: Optional input sample to evaluate. If None,
            a new sample is drawn from the distribution.

    Returns:
        jax.Array: Scalar log-probability value.

    Notes:
        Supports both user-supplied and internally drawn
        samples.
    """

    # Sample if x is None
    if x is None:
        x = self.sample()

    # Define pi variable
    pi: jax.Array = jnp.acos(jnp.zeros(1)) * 2

    # Compute log priors
    log_prior = (
        -jnp.log(jnp.sqrt(2 * pi))
        - jnp.log(self.std_prior)
        - (((x - self.mu_prior) ** 2) / (2 * self.std_prior**2))
        - 0.5
    )

    # Compute sigma
    sigma: jax.Array = jnp.log1p(jnp.exp(jnp.asarray(self.rho)))

    # Compute log posteriors
    log_posteriors = (
        -jnp.log(jnp.sqrt(2 * pi))
        - jnp.log(sigma)
        - (((x - self.mu) ** 2) / (2 * sigma**2))
        - 0.5
    )

    # Compute final log probs
    log_probs = log_posteriors.sum() - log_prior.sum()

    return log_probs
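
As the source shows, the returned value is the summed log-density under the learnable posterior minus the summed log-density under the prior, so averaging it over fresh posterior samples gives a Monte Carlo estimate of KL(posterior || prior), as the Notes suggest. A minimal sketch, assuming the import path above (n_samples and kl_estimate are illustrative names):

import jax.numpy as jnp
from flax import nnx

from illia.distributions.jax.gaussian import GaussianDistribution

dist = GaussianDistribution(shape=(10,))
n_samples = 8

# Average log q(w) - log p(w) over fresh posterior samples
kl_estimate = jnp.mean(
    jnp.stack(
        [dist.log_prob(dist.sample(nnx.Rngs(params=i))) for i in range(n_samples)]
    )
)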

1.2.5 sample(rngs=nnx.Rngs(0))

Draw a sample from the Gaussian distribution.

Parameters:

    rngs (Rngs, default nnx.Rngs(0)): RNG container used for sampling.

Returns:

    jax.Array: A sample drawn from the distribution.

Notes

Sampling is reproducible with the same RNG.

Source code in illia/distributions/jax/gaussian.py
def sample(self, rngs: Rngs = nnx.Rngs(0)) -> jax.Array:
    """
    Draw a sample from the Gaussian distribution.

    Args:
        rngs: RNG container used for sampling.

    Returns:
        jax.Array: A sample drawn from the distribution.

    Notes:
        Sampling is reproducible with the same RNG.
    """

    # Compute epsilon and sigma
    eps: jax.Array = jax.random.normal(rngs.params(), self.rho.shape)
    sigma: jax.Array = jnp.log1p(jnp.exp(jnp.asarray(self.rho)))

    return self.mu + sigma * eps
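
A quick reproducibility check, assuming the import path above: passing an Rngs container with the same seed yields the same draw.

import jax.numpy as jnp
from flax import nnx

from illia.distributions.jax.gaussian import GaussianDistribution

dist = GaussianDistribution(shape=(4,))
a = dist.sample(nnx.Rngs(0))
b = dist.sample(nnx.Rngs(0))
assert jnp.array_equal(a, b)  # same RNG state, same sample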