
1. Distributions

1.1 gaussian

This module contains the code for the Gaussian distribution.

1.1.1 GaussianDistribution(shape, mu_prior=0.0, std_prior=0.1, mu_init=0.0, rho_init=-7.0, rngs=nnx.Rngs(0))

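A learnable Gaussian distribution. The mean mu and the unconstrained scale rho are trainable nnx.Param arrays of the given shape, and the standard deviation is obtained as sigma = log(1 + exp(rho)). A fixed Gaussian prior N(mu_prior, std_prior^2) is used by log_prob.

Constructor for GaussianDistribution.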

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| shape | tuple[int, ...] | The shape of the parameters. | required |
| mu_prior | float | The mean of the Gaussian prior. | 0.0 |
| std_prior | float | The standard deviation of the Gaussian prior. | 0.1 |
| mu_init | float | The initial mean value. | 0.0 |
| rho_init | float | The initial rho value; the standard deviation is log(1 + exp(rho)). | -7.0 |
| rngs | Rngs | Flax NNX RNG container. | nnx.Rngs(0) |
Source code in illia/distributions/jax/gaussian.py
def __init__(
    self,
    shape: tuple[int, ...],
    mu_prior: float = 0.0,
    std_prior: float = 0.1,
    mu_init: float = 0.0,
    rho_init: float = -7.0,
    rngs: Rngs = nnx.Rngs(0),
) -> None:
    """
    Constructor for GaussianDistribution.

    Args:
        shape: The shape of the parameters.
        mu_prior: The mean of the Gaussian prior.
        std_prior: The standard deviation of the Gaussian prior.
        mu_init: The initial mean value.
        rho_init: The initial rho value; the standard deviation is
            log(1 + exp(rho)).
        rngs: Flax NNX RNG container. Defaults to nnx.Rngs(0).
    """

    # Call super-class constructor
    super().__init__()

    # Define priors
    self.mu_prior = mu_prior
    self.std_prior = std_prior

    # Initialize mu around mu_init and rho around rho_init with a small
    # Gaussian jitter (the 0.1 jitter scale is an assumed value)
    self.mu = nnx.Param(
        mu_init + 0.1 * jax.random.normal(rngs.params(), shape)
    )
    self.rho = nnx.Param(
        rho_init + 0.1 * jax.random.normal(rngs.params(), shape)
    )
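Because rho_init reaches the standard deviation only through the softplus transform sigma = log(1 + exp(rho)), it helps to see what scale the default of -7.0 implies. A minimal sketch in plain JAX; rho_to_std and std_to_rho are hypothetical helpers for illustration, not part of illia:

```python
import jax
import jax.numpy as jnp


def rho_to_std(rho: jax.Array) -> jax.Array:
    # Softplus maps an unconstrained rho to a strictly positive std,
    # equal to log1p(exp(rho)) as used by sample() and log_prob().
    return jax.nn.softplus(rho)


def std_to_rho(std: jax.Array) -> jax.Array:
    # Inverse softplus: the rho that yields a target standard deviation.
    return jnp.log(jnp.expm1(std))


default_std = rho_to_std(jnp.asarray(-7.0))
print(float(default_std))  # roughly 9.1e-4

# rho value that would start the posterior at the default prior std of 0.1
print(float(std_to_rho(jnp.asarray(0.1))))
```

A very negative rho_init therefore starts the posterior almost deterministic around mu, well below the default prior standard deviation of 0.1.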

1.1.1.1 num_params property

Returns the number of parameters in the module.

Returns:

| Type | Description |
|---|---|
| int | The number of parameters as an integer. |
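The source of num_params is not shown on this page, so its exact counting rule is not stated here. A minimal sketch, assuming it counts the scalar entries of one parameter tensor of the constructor's shape (mu and rho share that shape, so counting both would double the figure); count_entries is a hypothetical helper, not illia API:

```python
import math


def count_entries(shape: tuple[int, ...]) -> int:
    # Number of scalar entries in an array of this shape; an empty shape
    # (a scalar parameter) has exactly one entry.
    return math.prod(shape)


mu_count = count_entries((3, 4))
print(mu_count)  # 12
print(2 * mu_count)  # 24, if mu and rho were both counted
```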

1.1.1.2 __call__()

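Performs the forward pass of the module by returning a sample from the distribution (it calls sample() with its default rngs).

Returns:

| Type | Description |
|---|---|
| Array | A sampled JAX array. |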

Source code in illia/distributions/jax/gaussian.py
def __call__(self) -> jax.Array:
    """
    Performs the forward pass of the module.

    Returns:
        A sampled JAX array.
    """

    return self.sample()

1.1.1.3 log_prob(x=None)

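Computes the log-density ratio log q(x) - log p(x) between the learnable posterior q = N(mu, sigma^2) and the fixed prior p = N(mu_prior, std_prior^2), summed over all elements. When x is drawn from q, the result is a single-sample Monte Carlo estimate of KL(q || p).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Optional[Array] | An optional sampled array. If None, a sample is drawn from the posterior. | None |

Returns:

| Type | Description |
|---|---|
| Array | The summed log-density ratio as a scalar JAX array. |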
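In terms of the constructor arguments, with q the learnable posterior, p the fixed prior, sigma_prior = std_prior, and i ranging over all elements of x, the method evaluates:

```latex
\sigma_i = \log\bigl(1 + e^{\rho_i}\bigr), \qquad
\log \mathcal{N}(x \mid m, s) = -\log\sqrt{2\pi} - \log s - \frac{(x - m)^2}{2 s^2}

\mathrm{log\_prob}(x) = \sum_i \log \mathcal{N}(x_i \mid \mu_i, \sigma_i)
  - \sum_i \log \mathcal{N}(x_i \mid \mu_{\mathrm{prior}}, \sigma_{\mathrm{prior}})

\mathbb{E}_{x \sim q}\bigl[\mathrm{log\_prob}(x)\bigr] = \mathrm{KL}(q \,\|\, p)
```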

Source code in illia/distributions/jax/gaussian.py
def log_prob(self, x: Optional[jax.Array] = None) -> jax.Array:
    """
    Computes the log-density ratio between the posterior and the prior.

    Args:
        x: An optional sampled array. If None, a sample is
            generated.

    Returns:
        log q(x) - log p(x), summed over all elements, as a scalar
        JAX array.
    """

    # Sample if x is None
    if x is None:
        x = self.sample()

    # Compute log prior density log N(x | mu_prior, std_prior^2)
    log_prior = (
        -jnp.log(jnp.sqrt(2 * jnp.pi))
        - jnp.log(self.std_prior)
        - (((x - self.mu_prior) ** 2) / (2 * self.std_prior**2))
    )

    # Compute sigma = log(1 + exp(rho)) with an overflow-safe softplus
    sigma: jax.Array = jax.nn.softplus(self.rho.value)

    # Compute log posterior density log N(x | mu, sigma^2)
    log_posteriors = (
        -jnp.log(jnp.sqrt(2 * jnp.pi))
        - jnp.log(sigma)
        - (((x - self.mu.value) ** 2) / (2 * sigma**2))
    )

    # Summed log-ratio; its expectation under the posterior is KL(q || p)
    log_probs = log_posteriors.sum() - log_prior.sum()

    return log_probs
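To check what log_prob() computes, the sketch below re-implements it as a pure function over plain arrays using jax.scipy.stats.norm.logpdf, then compares a Monte Carlo average of the log-ratio over draws from q against the closed-form KL divergence between diagonal Gaussians. log_ratio, kl_closed_form, and one_sample are hypothetical helpers, not illia API; the shape, rho = -3.0, and the PRNG key are arbitrary illustration values.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_ratio(x, mu, rho, mu_prior=0.0, std_prior=0.1):
    """Summed log q(x) - log p(x), mirroring what log_prob() returns."""
    sigma = jax.nn.softplus(rho)  # equals log1p(exp(rho))
    log_q = norm.logpdf(x, loc=mu, scale=sigma)  # learnable posterior q
    log_p = norm.logpdf(x, loc=mu_prior, scale=std_prior)  # fixed prior p
    return log_q.sum() - log_p.sum()


def kl_closed_form(mu, rho, mu_prior=0.0, std_prior=0.1):
    """Exact KL(q || p) between diagonal Gaussians, summed over elements."""
    sigma = jax.nn.softplus(rho)
    kl = (
        jnp.log(std_prior / sigma)
        + (sigma**2 + (mu - mu_prior) ** 2) / (2 * std_prior**2)
        - 0.5
    )
    return kl.sum()


mu = jnp.zeros(5)
rho = jnp.full(5, -3.0)


def one_sample(key):
    # Draw x ~ q with the same reparameterization as sample(), then score it.
    x = mu + jax.nn.softplus(rho) * jax.random.normal(key, mu.shape)
    return log_ratio(x, mu, rho)


keys = jax.random.split(jax.random.PRNGKey(0), 10_000)
mc_estimate = jax.vmap(one_sample)(keys).mean()
print(float(mc_estimate), float(kl_closed_form(mu, rho)))
```

Averaging the log-ratio over many draws approaches the exact KL; a single call, which is what the method returns, is an unbiased but noisy estimate of it.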

1.1.1.4 sample(rngs=nnx.Rngs(0))

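Samples from the distribution using the current parameters. The sample is reparameterized as x = mu + sigma * eps, with eps drawn from a standard normal and sigma = log(1 + exp(rho)), so gradients flow to mu and rho.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| rngs | Rngs | Flax NNX RNG container. | nnx.Rngs(0) |

Returns:

| Type | Description |
|---|---|
| Array | A sampled JAX array. |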

Source code in illia/distributions/jax/gaussian.py
def sample(self, rngs: Rngs = nnx.Rngs(0)) -> jax.Array:
    """
    Samples from the distribution using the current parameters.

    Args:
        rngs: Flax NNX RNG container. Defaults to nnx.Rngs(0).

    Returns:
        A sampled JAX array.
    """

    # Draw standard normal noise and map rho to a positive sigma
    # (overflow-safe softplus, equal to log(1 + exp(rho)))
    eps: jax.Array = jax.random.normal(rngs.params(), self.rho.value.shape)
    sigma: jax.Array = jax.nn.softplus(self.rho.value)

    # Reparameterization trick: x = mu + sigma * eps
    return self.mu.value + sigma * eps
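The reparameterization used by sample() can be checked in plain JAX. A minimal sketch, where reparam_sample is a hypothetical stand-alone version operating on arrays rather than nnx.Param attributes, and the values of mu, rho, and the PRNG keys are arbitrary:

```python
import jax
import jax.numpy as jnp


def reparam_sample(mu, rho, key):
    # Mirrors sample(): x = mu + softplus(rho) * eps with eps ~ N(0, I).
    eps = jax.random.normal(key, rho.shape)
    return mu + jax.nn.softplus(rho) * eps


mu = jnp.array([1.0, -2.0])
rho = jnp.array([0.0, 1.0])

# Many independent draws: the empirical mean and std should approach
# mu and softplus(rho) respectively.
keys = jax.random.split(jax.random.PRNGKey(42), 20_000)
draws = jax.vmap(lambda k: reparam_sample(mu, rho, k))(keys)
print(draws.mean(axis=0), draws.std(axis=0), jax.nn.softplus(rho))

# Gradients reach mu and rho because eps does not depend on them.
key = jax.random.PRNGKey(0)
grad_mu, grad_rho = jax.grad(
    lambda m, r: reparam_sample(m, r, key).sum(), argnums=(0, 1)
)(mu, rho)
print(grad_mu)  # [1. 1.]
```

Drawing eps independently of mu and rho is what lets jax.grad differentiate a sampled value: the derivative with respect to mu is 1, and with respect to rho it is sigmoid(rho) * eps.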