
3. Neural Network Layers

3.1 BayesianModule

Abstract base class for Bayesian-aware modules in JAX. It tracks whether a module is Bayesian and controls parameter updates through freezing and unfreezing.

Notes

All derived classes must implement freeze and kl_cost to handle parameter management and compute the KL divergence cost.

Source code in illia/nn/jax/base.py
class BayesianModule(nnx.Module, ABC):
    """
    Abstract base for Bayesian-aware modules in JAX.
    Provides mechanisms to track if a module is Bayesian and control
    parameter updates through freezing/unfreezing.

    Notes:
        All derived classes must implement `freeze` and `kl_cost` to
        handle parameter management and compute the KL divergence cost.
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the Bayesian module with default flags.
        Sets `frozen` to False and `is_bayesian` to True.

        Args:
            **kwargs: Extra arguments passed to the base class.

        Returns:
            None.
        """

        super().__init__(**kwargs)

        self.frozen: bool = False
        self.is_bayesian: bool = True

    @abstractmethod
    def freeze(self) -> None:
        """
        Freeze the module's parameters to stop gradient computation.
        If weights or biases are not sampled yet, they are sampled first.
        Once frozen, parameters are not resampled or updated.

        Returns:
            None.

        Notes:
            Must be implemented by all subclasses.
        """

    def unfreeze(self) -> None:
        """
        Unfreeze the module by setting its `frozen` flag to False.
        Allows parameters to be sampled and updated again.

        Returns:
            None.
        """

        self.frozen = False

    @abstractmethod
    def kl_cost(self) -> tuple[jax.Array, int]:
        """
        Compute the KL divergence cost for all Bayesian parameters.

        Returns:
            tuple[jax.Array, int]: A tuple containing the KL divergence
                cost and the total number of parameters in the layer.

        Notes:
            Must be implemented by all subclasses.
        """

3.1.1 __init__(**kwargs)

Initialize the Bayesian module with default flags. Sets frozen to False and is_bayesian to True.

Parameters:

**kwargs (Any, default: {}): Extra arguments passed to the base class.

Returns:

None.

Source code in illia/nn/jax/base.py
def __init__(self, **kwargs: Any) -> None:
    """
    Initialize the Bayesian module with default flags.
    Sets `frozen` to False and `is_bayesian` to True.

    Args:
        **kwargs: Extra arguments passed to the base class.

    Returns:
        None.
    """

    super().__init__(**kwargs)

    self.frozen: bool = False
    self.is_bayesian: bool = True

3.1.2 freeze() abstractmethod

Freeze the module's parameters to stop gradient computation. If weights or biases are not sampled yet, they are sampled first. Once frozen, parameters are not resampled or updated.

Returns:

None.

Notes

Must be implemented by all subclasses.

Source code in illia/nn/jax/base.py
@abstractmethod
def freeze(self) -> None:
    """
    Freeze the module's parameters to stop gradient computation.
    If weights or biases are not sampled yet, they are sampled first.
    Once frozen, parameters are not resampled or updated.

    Returns:
        None.

    Notes:
        Must be implemented by all subclasses.
    """

3.1.3 kl_cost() abstractmethod

Compute the KL divergence cost for all Bayesian parameters.

Returns:

tuple[jax.Array, int]: A tuple containing the KL divergence cost and the total number of parameters in the layer.

Notes

Must be implemented by all subclasses.

Source code in illia/nn/jax/base.py
@abstractmethod
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Compute the KL divergence cost for all Bayesian parameters.

    Returns:
        tuple[jax.Array, int]: A tuple containing the KL divergence
            cost and the total number of parameters in the layer.

    Notes:
        Must be implemented by all subclasses.
    """

3.1.4 unfreeze()

Unfreeze the module by setting its frozen flag to False. Allows parameters to be sampled and updated again.

Returns:

None.

Source code in illia/nn/jax/base.py
def unfreeze(self) -> None:
    """
    Unfreeze the module by setting its `frozen` flag to False.
    Allows parameters to be sampled and updated again.

    Returns:
        None.
    """

    self.frozen = False
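The frozen flag switches a layer between resampling and reusing its parameters. The following sketch is a hypothetical ToyBayesianLinear written in plain JAX rather than with illia's classes. It assumes the same sample-unless-frozen pattern that the convolution layers use in __call__.

```python
import jax
import jax.numpy as jnp


class ToyBayesianLinear:
    """Hypothetical layer mimicking the sample-unless-frozen pattern."""

    def __init__(self, in_features, out_features, seed=0):
        self.key = jax.random.PRNGKey(seed)
        self.mu = jnp.zeros((out_features, in_features))
        self.sigma = 0.1 * jnp.ones((out_features, in_features))
        self.frozen = False
        self.weights = self._sample()

    def _sample(self):
        # Reparameterized draw mu + sigma * eps, with a fresh subkey each time.
        self.key, sub = jax.random.split(self.key)
        eps = jax.random.normal(sub, self.mu.shape)
        return self.mu + self.sigma * eps

    def freeze(self):
        self.frozen = True

    def unfreeze(self):
        self.frozen = False

    def __call__(self, x):
        # Resample on every call unless frozen; otherwise reuse stored weights.
        if not self.frozen:
            self.weights = self._sample()
        return x @ self.weights.T


x = jnp.ones((2, 3))
layer = ToyBayesianLinear(3, 4)
layer.freeze()
a = layer(x)
b = layer(x)
layer.unfreeze()
c = layer(x)
```

Two calls made while the layer is frozen return identical outputs. After unfreeze, the next call draws fresh weights.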

3.2 Conv1d

Bayesian 1D convolutional layer with optional weight and bias priors. Behaves like a standard Conv1d but treats weights and bias as random variables sampled from specified distributions. Parameters become fixed when the layer is frozen.
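This layer applies zero-padding symmetrically on both sides. The output length therefore follows the usual convolution arithmetic, L_out = floor((L + 2p - d(k - 1) - 1) / s) + 1, where p is padding, d is dilation, k is kernel_size and s is stride. Below is a quick check with jax.lax.conv_general_dilated, using the same NCH/OIH/NCH layout as the layer. conv1d_out_length is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def conv1d_out_length(length, kernel_size, stride=1, padding=0, dilation=1):
    # The effective kernel span is dilation * (kernel_size - 1) + 1.
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


batch, c_in, c_out, length = 2, 4, 6, 20
k, s, p, d = 3, 2, 1, 2

x = jnp.ones((batch, c_in, length))  # NCH layout
w = jnp.ones((c_out, c_in, k))       # OIH layout
y = jax.lax.conv_general_dilated(
    lhs=x,
    rhs=w,
    window_strides=[s],
    padding=[(p, p)],
    lhs_dilation=[1],
    rhs_dilation=[d],
    dimension_numbers=("NCH", "OIH", "NCH"),
)
# conv1d_out_length(20, 3, 2, 1, 2) -> 9, so y.shape is (2, 6, 9)
```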

Source code in illia/nn/jax/conv1d.py
class Conv1d(BayesianModule):
    """
    Bayesian 1D convolutional layer with optional weight and bias priors.
    Behaves like a standard Conv1d but treats weights and bias as random
    variables sampled from specified distributions. Parameters become fixed
    when the layer is frozen.
    """

    bias_distribution: Optional[GaussianDistribution] = None
    bias: Optional[nnx.Param] = None

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        weights_distribution: Optional[GaussianDistribution] = None,
        bias_distribution: Optional[GaussianDistribution] = None,
        use_bias: bool = True,
        rngs: Rngs = nnx.Rngs(0),
        **kwargs: Any,
    ) -> None:
        """
        Initializes a Bayesian 1D convolutional layer.

        Args:
            input_channels: Number of input feature channels.
            output_channels: Number of output feature channels.
            kernel_size: Size of the convolution kernel.
            stride: Stride of the convolution operation.
            padding: Amount of zero-padding on both sides.
            dilation: Spacing between kernel elements.
            groups: Number of blocked connections between input and output.
            weights_distribution: Distribution to initialize weights.
            bias_distribution: Distribution to initialize bias.
            use_bias: Whether to include a bias term.
            rngs: Random number generators for reproducibility.
            **kwargs: Extra arguments passed to the base class.

        Returns:
            None.

        Notes:
            Gaussian distributions are used by default if none are
            provided.
        """

        super().__init__(**kwargs)

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.use_bias = use_bias
        self.rngs = rngs

        # Set weights prior
        if weights_distribution is None:
            self.weights_distribution = GaussianDistribution(
                shape=(
                    self.output_channels,
                    self.input_channels // self.groups,
                    self.kernel_size,
                ),
                rngs=self.rngs,
            )
        else:
            self.weights_distribution = weights_distribution

        # Set bias prior
        if self.use_bias:
            if bias_distribution is None:
                self.bias_distribution = GaussianDistribution(
                    shape=(self.output_channels,),
                    rngs=self.rngs,
                )
            else:
                self.bias_distribution = bias_distribution
        else:
            self.bias_distribution = None

        # Sample initial weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample initial bias only if using bias
        if self.use_bias and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
        else:
            self.bias = None

    def freeze(self) -> None:
        """
        Freeze the module's parameters to stop gradient computation.
        If weights or biases are not sampled yet, they are sampled first.
        Once frozen, parameters are not resampled or updated.

        Returns:
            None.
        """

        # Set indicator
        self.frozen = True

        # Sample weights if they are undefined
        if self.weights is None:
            self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias if it is undefined and bias is used
        if self.use_bias and self.bias is None and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

        # Stop gradient computation
        self.weights = jax.lax.stop_gradient(self.weights)
        if self.use_bias:
            self.bias = jax.lax.stop_gradient(self.bias)

    def kl_cost(self) -> tuple[jax.Array, int]:
        """
        Compute the KL divergence cost for all Bayesian parameters.

        Returns:
            tuple[jax.Array, int]: A tuple containing the KL divergence
                cost and the total number of parameters in the layer.
        """

        # Compute log probs for weights
        log_probs: jax.Array = self.weights_distribution.log_prob(
            jnp.asarray(self.weights)
        )

        # Add bias log probs only if using bias
        if (
            self.use_bias
            and self.bias is not None
            and self.bias_distribution is not None
        ):
            log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

        # Compute number of parameters
        num_params: int = self.weights_distribution.num_params
        if self.use_bias and self.bias_distribution is not None:
            num_params += self.bias_distribution.num_params

        return log_probs, num_params

    def __call__(self, inputs: jax.Array) -> jax.Array:
        """
        Performs a forward pass through the Bayesian 1D convolutional
        layer. If the layer is not frozen, it samples new weights and bias
        from their respective distributions. If the layer is frozen, it
        reuses the stored parameters and raises a ValueError if the
        weights or bias are undefined.

        Args:
            inputs: Input tensor to the layer with shape
                (batch, channels, length).

        Returns:
            Output array after convolution with optional bias added.

        Raises:
            ValueError: If the layer is frozen but weights or bias are
                undefined.
        """

        # Sample if model not frozen
        if not self.frozen:
            # Sample weights
            self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

            # Sample bias only if using bias
            if self.use_bias and self.bias_distribution is not None:
                self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
        elif self.weights is None or (self.use_bias and self.bias is None):
            raise ValueError(
                "Module has been frozen with undefined weights and/or bias."
            )

        # Compute outputs
        outputs = jax.lax.conv_general_dilated(
            lhs=inputs,
            rhs=jnp.asarray(self.weights),
            window_strides=[self.stride],
            padding=[(self.padding, self.padding)],
            lhs_dilation=[1],
            rhs_dilation=[self.dilation],
            dimension_numbers=(
                "NCH",  # Input
                "OIH",  # Kernel
                "NCH",  # Output
            ),
            feature_group_count=self.groups,
        )

        # Add bias only if using bias
        if self.use_bias and self.bias is not None:
            outputs += jnp.reshape(
                a=jnp.asarray(self.bias), shape=(1, self.output_channels, 1)
            )

        return outputs

3.2.1 __call__(inputs)

Performs a forward pass through the Bayesian 1D convolutional layer. If the layer is not frozen, it samples new weights and bias from their respective distributions on every call. If the layer is frozen, it reuses the stored parameters and raises a ValueError if the weights or bias are undefined.

Parameters:

inputs (jax.Array, required): Input tensor with shape (batch, channels, length).

Returns:

jax.Array: Output array after convolution, with the bias added when use_bias is True.

Raises:

ValueError: If the layer is frozen but its weights or bias are undefined.

Source code in illia/nn/jax/conv1d.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    Performs a forward pass through the Bayesian 1D convolutional
    layer. If the layer is not frozen, it samples new weights and bias
    from their respective distributions. If the layer is frozen, it
    reuses the stored parameters and raises a ValueError if the
    weights or bias are undefined.

    Args:
        inputs: Input tensor to the layer with shape
            (batch, channels, length).

    Returns:
        Output array after convolution with optional bias added.

    Raises:
        ValueError: If the layer is frozen but weights or bias are
            undefined.
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias only if using bias
        if self.use_bias and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    elif self.weights is None or (self.use_bias and self.bias is None):
        raise ValueError(
            "Module has been frozen with undefined weights and/or bias."
        )

    # Compute outputs
    outputs = jax.lax.conv_general_dilated(
        lhs=inputs,
        rhs=jnp.asarray(self.weights),
        window_strides=[self.stride],
        padding=[(self.padding, self.padding)],
        lhs_dilation=[1],
        rhs_dilation=[self.dilation],
        dimension_numbers=(
            "NCH",  # Input
            "OIH",  # Kernel
            "NCH",  # Output
        ),
        feature_group_count=self.groups,
    )

    # Add bias only if using bias
    if self.use_bias and self.bias is not None:
        outputs += jnp.reshape(
            a=jnp.asarray(self.bias), shape=(1, self.output_channels, 1)
        )

    return outputs
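The following is a functional sketch of the unfrozen forward pass. It assumes the weights and bias are drawn with the reparameterization w = mu + sigma * eps; illia's GaussianDistribution.sample may parameterize this differently. bayesian_conv1d_forward and its arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def bayesian_conv1d_forward(key, inputs, w_mu, w_sigma, b_mu, b_sigma,
                            stride=1, padding=0, dilation=1, groups=1):
    """Hypothetical functional version of an unfrozen Conv1d forward pass."""
    k_w, k_b = jax.random.split(key)
    # Reparameterized samples for the kernel (OIH) and the bias (O,).
    weights = w_mu + w_sigma * jax.random.normal(k_w, w_mu.shape)
    bias = b_mu + b_sigma * jax.random.normal(k_b, b_mu.shape)
    outputs = jax.lax.conv_general_dilated(
        lhs=inputs,
        rhs=weights,
        window_strides=[stride],
        padding=[(padding, padding)],
        lhs_dilation=[1],
        rhs_dilation=[dilation],
        dimension_numbers=("NCH", "OIH", "NCH"),
        feature_group_count=groups,
    )
    # Broadcast the bias over batch and length: (1, C_out, 1).
    return outputs + bias.reshape(1, -1, 1)


key = jax.random.PRNGKey(0)
x = jnp.ones((2, 4, 10))
# groups=2: kernel in-channels = 4 // 2 = 2; zero sigma collapses samples to the mean.
y = bayesian_conv1d_forward(
    key, x,
    jnp.zeros((6, 2, 3)), jnp.zeros((6, 2, 3)),
    jnp.arange(6.0), jnp.zeros(6),
    padding=1, groups=2,
)
```

Setting sigma to zero makes the sample equal to the mean. This makes it easy to check that each output channel receives its own bias value.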

3.2.2 __init__(input_channels, output_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, weights_distribution=None, bias_distribution=None, use_bias=True, rngs=nnx.Rngs(0), **kwargs)

Initializes a Bayesian 1D convolutional layer.

Parameters:

input_channels (int, required): Number of input feature channels.
output_channels (int, required): Number of output feature channels.
kernel_size (int, required): Size of the convolution kernel.
stride (int, default: 1): Stride of the convolution operation.
padding (int, default: 0): Amount of zero-padding on both sides.
dilation (int, default: 1): Spacing between kernel elements.
groups (int, default: 1): Number of blocked connections between input and output.
weights_distribution (Optional[GaussianDistribution], default: None): Distribution to initialize weights.
bias_distribution (Optional[GaussianDistribution], default: None): Distribution to initialize bias.
use_bias (bool, default: True): Whether to include a bias term.
rngs (Rngs, default: Rngs(0)): Random number generators for reproducibility.
**kwargs (Any, default: {}): Extra arguments passed to the base class.

Returns:

None.

Notes

Gaussian distributions are used by default if none are provided.

Source code in illia/nn/jax/conv1d.py
def __init__(
    self,
    input_channels: int,
    output_channels: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
    use_bias: bool = True,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Initializes a Bayesian 1D convolutional layer.

    Args:
        input_channels: Number of input feature channels.
        output_channels: Number of output feature channels.
        kernel_size: Size of the convolution kernel.
        stride: Stride of the convolution operation.
        padding: Amount of zero-padding on both sides.
        dilation: Spacing between kernel elements.
        groups: Number of blocked connections between input and output.
        weights_distribution: Distribution to initialize weights.
        bias_distribution: Distribution to initialize bias.
        use_bias: Whether to include a bias term.
        rngs: Random number generators for reproducibility.
        **kwargs: Extra arguments passed to the base class.

    Returns:
        None.

    Notes:
        Gaussian distributions are used by default if none are
        provided.
    """

    super().__init__(**kwargs)

    self.input_channels = input_channels
    self.output_channels = output_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.groups = groups
    self.use_bias = use_bias
    self.rngs = rngs

    # Set weights prior
    if weights_distribution is None:
        self.weights_distribution = GaussianDistribution(
            shape=(
                self.output_channels,
                self.input_channels // self.groups,
                self.kernel_size,
            ),
            rngs=self.rngs,
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias prior
    if self.use_bias:
        if bias_distribution is None:
            self.bias_distribution = GaussianDistribution(
                shape=(self.output_channels,),
                rngs=self.rngs,
            )
        else:
            self.bias_distribution = bias_distribution
    else:
        self.bias_distribution = None

    # Sample initial weights
    self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample initial bias only if using bias
    if self.use_bias and self.bias_distribution is not None:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    else:
        self.bias = None
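When no distributions are passed, the priors are built with weight shape (output_channels, input_channels // groups, kernel_size) and bias shape (output_channels,). The hypothetical helper below reproduces those shapes and counts the sampled elements; that count is not necessarily what GaussianDistribution.num_params reports. The divisibility check is an addition. illia's __init__ does not validate groups, so an invalid value would likely only surface later inside jax.lax.conv_general_dilated.

```python
import math


def default_param_shapes(input_channels, output_channels, kernel_size,
                         groups=1, use_bias=True):
    """Hypothetical helper mirroring the default prior shapes in __init__."""
    # Grouped convolution needs both channel counts to split evenly.
    if input_channels % groups or output_channels % groups:
        raise ValueError("groups must divide input_channels and output_channels")
    weight_shape = (output_channels, input_channels // groups, kernel_size)
    bias_shape = (output_channels,) if use_bias else None
    # Number of sampled array elements (weights plus optional bias).
    num_elements = math.prod(weight_shape) + (output_channels if use_bias else 0)
    return weight_shape, bias_shape, num_elements


shapes = default_param_shapes(8, 16, 5, groups=4)
# → ((16, 2, 5), (16,), 176)
```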

3.2.3 freeze()

Freeze the module's parameters to stop gradient computation. If weights or biases are not sampled yet, they are sampled first. Once frozen, parameters are not resampled or updated.

Returns:

None.

Source code in illia/nn/jax/conv1d.py
def freeze(self) -> None:
    """
    Freeze the module's parameters to stop gradient computation.
    If weights or biases are not sampled yet, they are sampled first.
    Once frozen, parameters are not resampled or updated.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample bias if it is undefined and bias is used
    if self.use_bias and self.bias is None and self.bias_distribution is not None:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Stop gradient computation
    self.weights = jax.lax.stop_gradient(self.weights)
    if self.use_bias:
        self.bias = jax.lax.stop_gradient(self.bias)
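jax.lax.stop_gradient is the identity on values. It only blocks gradients through computations that JAX is differentiating. The sketch below uses hypothetical loss functions and applies stop_gradient inside the differentiated function to show the effect. When freeze is called eagerly, outside any transformation, the call leaves the stored values unchanged. Whether gradients are then blocked depends on how the frozen parameters enter the training step.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical quadratic loss of a linear map.
    return jnp.sum((x @ w) ** 2)


def frozen_loss(w, x):
    # stop_gradient marks w as a constant for differentiation only.
    return loss(jax.lax.stop_gradient(w), x)


w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])

grad_live = jax.grad(loss)(w, x)           # 2 * x.T @ (x @ w) = [2., 4.]
grad_frozen = jax.grad(frozen_loss)(w, x)  # [0., 0.]
same_values = jnp.array_equal(jax.lax.stop_gradient(w), w)  # True
```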

3.2.4 kl_cost()

Compute the KL divergence cost for all Bayesian parameters.

Returns:

tuple[jax.Array, int]: A tuple containing the KL divergence cost and the total number of parameters in the layer.

Source code in illia/nn/jax/conv1d.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Compute the KL divergence cost for all Bayesian parameters.

    Returns:
        tuple[jax.Array, int]: A tuple containing the KL divergence
            cost and the total number of parameters in the layer.
    """

    # Compute log probs for weights
    log_probs: jax.Array = self.weights_distribution.log_prob(
        jnp.asarray(self.weights)
    )

    # Add bias log probs only if using bias
    if (
        self.use_bias
        and self.bias is not None
        and self.bias_distribution is not None
    ):
        log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

    # Compute number of parameters
    num_params: int = self.weights_distribution.num_params
    if self.use_bias and self.bias_distribution is not None:
        num_params += self.bias_distribution.num_params

    return log_probs, num_params
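The method adds a weight term and an optional bias term, and sums the matching parameter counts. The sketch below reproduces that structure. For the per-tensor term it swaps in the closed-form KL divergence between a diagonal Gaussian and a standard normal prior, KL = 0.5 * sum(sigma^2 + mu^2 - 1 - 2 log sigma). This standard formula is only an illustration; illia's GaussianDistribution.log_prob may compute the term differently. layer_kl_cost is hypothetical.

```python
import jax.numpy as jnp


def gaussian_kl_to_standard_normal(mu, sigma):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over all elements."""
    return 0.5 * jnp.sum(sigma**2 + mu**2 - 1.0 - 2.0 * jnp.log(sigma))


def layer_kl_cost(w_mu, w_sigma, b_mu=None, b_sigma=None):
    """Hypothetical per-layer cost: weight term plus optional bias term."""
    kl = gaussian_kl_to_standard_normal(w_mu, w_sigma)
    num_params = w_mu.size
    if b_mu is not None and b_sigma is not None:
        kl = kl + gaussian_kl_to_standard_normal(b_mu, b_sigma)
        num_params += b_mu.size
    return kl, num_params


w_mu = jnp.ones((6, 4, 3))    # 72 elements, each contributing 0.5
w_sigma = jnp.ones((6, 4, 3))
b_mu = jnp.zeros(6)           # matches the prior, contributes 0
b_sigma = jnp.ones(6)
kl, n = layer_kl_cost(w_mu, w_sigma, b_mu, b_sigma)
```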

3.3 Conv2d

Bayesian 2D convolutional layer with optional weight and bias priors. Behaves like a standard Conv2d but treats weights and bias as random variables sampled from specified distributions. Parameters become fixed when the layer is frozen.
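jax.lax.conv_general_dilated takes padding as one (low, high) pair per spatial dimension. A (height, width) padding tuple therefore has to be expanded to [(ph, ph), (pw, pw)], and an integer kernel_size has to become a pair. Here is a minimal sketch with hypothetical helper names.

```python
import jax
import jax.numpy as jnp


def as_pair(value):
    # Expand an int to (value, value), as Conv2d does for kernel_size.
    return (value, value) if isinstance(value, int) else tuple(value)


def conv2d_nchw(inputs, weights, stride=(1, 1), padding=(0, 0),
                dilation=(1, 1), groups=1):
    """Hypothetical NCHW convolution with per-dimension symmetric padding."""
    ph, pw = padding
    return jax.lax.conv_general_dilated(
        lhs=inputs,
        rhs=weights,
        window_strides=stride,
        # One (low, high) pair per spatial dimension: height first, then width.
        padding=[(ph, ph), (pw, pw)],
        lhs_dilation=(1, 1),
        rhs_dilation=dilation,
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
        feature_group_count=groups,
    )


kh, kw = as_pair(3)
x = jnp.ones((1, 2, 8, 10))
w = jnp.ones((4, 2, kh, kw))
y = conv2d_nchw(x, w, padding=(2, 0))
```

With padding=(2, 0), the height grows by 4 and the width is left unpadded. An 8 x 10 input with a 3 x 3 kernel therefore gives a 10 x 8 output.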

Source code in illia/nn/jax/conv2d.py
class Conv2d(BayesianModule):
    """
    Bayesian 2D convolutional layer with optional weight and bias priors.
    Behaves like a standard Conv2d but treats weights and bias as random
    variables sampled from specified distributions. Parameters become fixed
    when the layer is frozen.
    """

    bias_distribution: Optional[GaussianDistribution] = None
    bias: Optional[nnx.Param] = None

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        kernel_size: int | tuple[int, int],
        stride: tuple[int, int] = (1, 1),
        padding: tuple[int, int] = (0, 0),
        dilation: tuple[int, int] = (1, 1),
        groups: int = 1,
        weights_distribution: Optional[GaussianDistribution] = None,
        bias_distribution: Optional[GaussianDistribution] = None,
        use_bias: bool = True,
        rngs: Rngs = nnx.Rngs(0),
        **kwargs: Any,
    ) -> None:
        """
        Initializes a Bayesian 2D convolutional layer.

        Args:
            input_channels: Number of input feature channels.
            output_channels: Number of output feature channels.
            kernel_size: Convolution kernel size. Int is converted to tuple.
            stride: Stride of the convolution operation.
            padding: Tuple specifying zero-padding for height and width.
            dilation: Spacing between kernel elements.
            groups: Number of blocked connections between input and output.
            weights_distribution: Distribution to initialize weights.
            bias_distribution: Distribution to initialize bias.
            use_bias: Whether to include a bias term.
            rngs: Random number generators for reproducibility.
            **kwargs: Extra arguments passed to the base class.

        Returns:
            None.

        Notes:
            Gaussian distributions are used by default if none are
            provided.
        """

        super().__init__(**kwargs)

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.use_bias = use_bias
        self.rngs = rngs

        # Set weights distribution
        if weights_distribution is None:
            # Extend kernel if we only have 1 value
            if isinstance(self.kernel_size, int):
                self.kernel_size = (self.kernel_size, self.kernel_size)

            self.weights_distribution: GaussianDistribution = GaussianDistribution(
                shape=(
                    self.output_channels,
                    self.input_channels // self.groups,
                    *self.kernel_size,
                ),
                rngs=self.rngs,
            )
        else:
            self.weights_distribution = weights_distribution

        # Set bias prior
        if self.use_bias:
            if bias_distribution is None:
                # Define bias distribution
                self.bias_distribution = GaussianDistribution(
                    shape=(self.output_channels,),
                    rngs=self.rngs,
                )
            else:
                self.bias_distribution = bias_distribution
        else:
            self.bias_distribution = None

        # Sample initial weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample initial bias only if using bias
        if self.use_bias and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
        else:
            self.bias = None

    def freeze(self) -> None:
        """
        Freeze the module's parameters to stop gradient computation.
        If weights or biases are not sampled yet, they are sampled first.
        Once frozen, parameters are not resampled or updated.

        Returns:
            None.
        """

        # Set indicator
        self.frozen = True

        # Sample weights if they are undefined
        if self.weights is None:
            self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias if it is undefined and bias is used
        if self.use_bias and self.bias is None and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

        # Stop gradient computation
        self.weights = jax.lax.stop_gradient(self.weights)
        if self.use_bias:
            self.bias = jax.lax.stop_gradient(self.bias)

    def kl_cost(self) -> tuple[jax.Array, int]:
        """
        Compute the KL divergence cost for all Bayesian parameters.

        Returns:
            tuple[jax.Array, int]: A tuple containing the KL divergence
                cost and the total number of parameters in the layer.
        """

        # Compute log probs for weights
        log_probs: jax.Array = self.weights_distribution.log_prob(
            jnp.asarray(self.weights)
        )

        # Add bias log probs only if using bias
        if (
            self.use_bias
            and self.bias is not None
            and self.bias_distribution is not None
        ):
            log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

        # Compute number of parameters
        num_params: int = self.weights_distribution.num_params
        if self.use_bias and self.bias_distribution is not None:
            num_params += self.bias_distribution.num_params

        return log_probs, num_params

    def __call__(self, inputs: jax.Array) -> jax.Array:
        """
        Performs a forward pass through the Bayesian 2D convolutional
        layer. If the layer is not frozen, it samples new weights and bias
        from their respective distributions. If the layer is frozen, it
        reuses the stored parameters and raises a ValueError if the
        weights or bias are undefined.

        Args:
            inputs: Input array with shape (batch, channels, height,
                width).

        Returns:
            Output array after convolution with optional bias addition.

        Raises:
            ValueError: If the layer is frozen but weights or bias are
                undefined.
        """

        # Sample if model not frozen
        if not self.frozen:
            # Sample weights
            self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

            # Sample bias only if using bias
            if self.use_bias and self.bias_distribution is not None:
                self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
        elif self.weights is None or (self.use_bias and self.bias is None):
            raise ValueError(
                "Module has been frozen with undefined weights and/or bias."
            )

        # Compute outputs
        outputs = jax.lax.conv_general_dilated(
            lhs=inputs,
            rhs=jnp.asarray(self.weights),
            window_strides=self.stride,
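            # One (low, high) pair per spatial dimension: height, then width
            padding=[
                (self.padding[0], self.padding[0]),
                (self.padding[1], self.padding[1]),
            ],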
            lhs_dilation=[1, 1],
            rhs_dilation=self.dilation,
            dimension_numbers=(
                "NCHW",  # Input
                "OIHW",  # Kernel
                "NCHW",  # Output
            ),
            feature_group_count=self.groups,
        )

        # Add bias only if using bias
        if self.use_bias and self.bias is not None:
            outputs += jnp.reshape(
                a=jnp.asarray(self.bias), shape=(1, self.output_channels, 1, 1)
            )

        return outputs

3.3.1 __call__(inputs)

Performs a forward pass through the Bayesian 2D convolutional layer. If the layer is not frozen, it samples new weights and bias from their respective distributions on every call. If the layer is frozen, it reuses the stored parameters and raises a ValueError if the weights or bias are undefined.

Parameters:

inputs (jax.Array, required): Input array with shape (batch, channels, height, width).

Returns:

jax.Array: Output array after convolution, with the bias added when use_bias is True.

Raises:

Type Description
ValueError

If the layer is frozen but weights or bias are undefined.

Source code in illia/nn/jax/conv2d.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    Performs a forward pass through the Bayesian Convolution 2D
    layer. If the layer is not frozen, it samples new weights and
    bias from their respective distributions. If the layer is
    frozen, it reuses the stored weights and bias; a ValueError is
    raised if either is undefined.

    Args:
        inputs: Input array with shape (batch, channels, height,
            width).

    Returns:
        Output array after convolution with optional bias addition.

    Raises:
        ValueError: If the layer is frozen but weights or bias are
            undefined.
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias only if using bias
        if self.use_bias and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    elif self.weights is None or (self.use_bias and self.bias is None):
        raise ValueError(
            "Module has been frozen with undefined weights and/or bias."
        )

    # Compute outputs
    outputs = jax.lax.conv_general_dilated(
        lhs=inputs,
        rhs=jnp.asarray(self.weights),
        window_strides=self.stride,
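        # One (low, high) pair per spatial dimension: height, then width
        padding=[
            (self.padding[0], self.padding[0]),
            (self.padding[1], self.padding[1]),
        ],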
        lhs_dilation=[1, 1],
        rhs_dilation=self.dilation,
        dimension_numbers=(
            "NCHW",  # Input
            "OIHW",  # Kernel
            "NCHW",  # Output
        ),
        feature_group_count=self.groups,
    )

    # Add bias only if using bias
    if self.use_bias and self.bias is not None:
        outputs += jnp.reshape(
            a=jnp.asarray(self.bias), shape=(1, self.output_channels, 1, 1)
        )

    return outputs
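Since the forward pass hands the actual computation to `jax.lax.conv_general_dilated` with `NCHW`/`OIHW`/`NCHW` layouts, a slow NumPy reference can make the layouts and the bias broadcast concrete. This is an illustrative sketch, not illia code: `conv2d_nchw_reference` is a hypothetical helper, it handles only `groups=1`, and it assumes `padding=(ph, pw)` means symmetric zero-padding of `ph` rows and `pw` columns.

```python
import numpy as np


def conv2d_nchw_reference(x, w, b=None, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
    """Hypothetical naive 2D convolution: x is NCHW, w is OIHW, groups=1 only."""
    n, c_in, h, wd = x.shape
    c_out, c_in_w, kh, kw = w.shape
    if c_in != c_in_w:
        raise ValueError("this sketch only supports groups=1")
    (sh, sw), (ph, pw), (dh, dw) = stride, padding, dilation

    # Symmetric zero-padding: ph rows above/below, pw columns left/right
    xp = np.pad(x, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
    h_out = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    w_out = (wd + 2 * pw - dw * (kw - 1) - 1) // sw + 1

    out = np.zeros((n, c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            # Dilated receptive field for output pixel (i, j): shape (N, C_in, kh, kw)
            patch = xp[:, :,
                       i * sh : i * sh + dh * (kh - 1) + 1 : dh,
                       j * sw : j * sw + dw * (kw - 1) + 1 : dw]
            # Cross-correlation: contract over (C_in, kh, kw) for every output channel
            out[:, :, i, j] = np.einsum("nckl,ockl->no", patch, w)

    if b is not None:
        # Same broadcast as the layer: bias (C_out,) -> (1, C_out, 1, 1)
        out = out + b.reshape(1, c_out, 1, 1)
    return out


x = np.ones((1, 1, 3, 3))
w = np.ones((1, 1, 2, 2))
y = conv2d_nchw_reference(x, w, b=np.array([0.5]))
print(y.shape)  # (1, 1, 2, 2)
print(y[0, 0])  # every entry is 4 ones + 0.5 bias = 4.5
```

The `einsum` spells out what the `OIHW` kernel layout means: axis 0 indexes output channels and axis 1 indexes input channels, which is why the layer's weight shape starts with `output_channels`.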

3.3.2 __init__(input_channels, output_channels, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, weights_distribution=None, bias_distribution=None, use_bias=True, rngs=nnx.Rngs(0), **kwargs)

Initializes a Bayesian 2D convolutional layer.

Parameters:

input_channels (int): Number of input feature channels. Required.
output_channels (int): Number of output feature channels. Required.
kernel_size (int | tuple[int, int]): Convolution kernel size. An int is converted to a tuple. Required.
stride (tuple[int, int]): Stride of the convolution operation. Default: (1, 1).
padding (tuple[int, int]): Zero-padding applied to height and width. Default: (0, 0).
dilation (tuple[int, int]): Spacing between kernel elements. Default: (1, 1).
groups (int): Number of blocked connections between input and output. Default: 1.
weights_distribution (Optional[GaussianDistribution]): Distribution to initialize weights. Default: None.
bias_distribution (Optional[GaussianDistribution]): Distribution to initialize bias. Default: None.
use_bias (bool): Whether to include a bias term. Default: True.
rngs (Rngs): Random number generators for reproducibility. Default: nnx.Rngs(0).
**kwargs (Any): Extra arguments passed to the base class.

Returns:

None.

Notes

Gaussian distributions are used by default if none are provided.

Source code in illia/nn/jax/conv2d.py
def __init__(
    self,
    input_channels: int,
    output_channels: int,
    kernel_size: int | tuple[int, int],
    stride: tuple[int, int] = (1, 1),
    padding: tuple[int, int] = (0, 0),
    dilation: tuple[int, int] = (1, 1),
    groups: int = 1,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
    use_bias: bool = True,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Initializes a Bayesian 2D convolutional layer.

    Args:
        input_channels: Number of input feature channels.
        output_channels: Number of output feature channels.
        kernel_size: Convolution kernel size. Int is converted to tuple.
        stride: Stride of the convolution operation.
        padding: Tuple specifying zero-padding for height and width.
        dilation: Spacing between kernel elements.
        groups: Number of blocked connections between input and output.
        weights_distribution: Distribution to initialize weights.
        bias_distribution: Distribution to initialize bias.
        use_bias: Whether to include a bias term.
        rngs: Random number generators for reproducibility.
        **kwargs: Extra arguments passed to the base class.

    Returns:
        None.

    Notes:
        Gaussian distributions are used by default if none are
        provided.
    """

    super().__init__(**kwargs)

    self.input_channels = input_channels
    self.output_channels = output_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.groups = groups
    self.use_bias = use_bias
    self.rngs = rngs

    # Set weights distribution
    if weights_distribution is None:
        # Extend kernel if we only have 1 value
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)

        self.weights_distribution: GaussianDistribution = GaussianDistribution(
            shape=(
                self.output_channels,
                self.input_channels // self.groups,
                *self.kernel_size,
            ),
            rngs=self.rngs,
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias prior
    if self.use_bias:
        if bias_distribution is None:
            # Define bias distribution
            self.bias_distribution = GaussianDistribution(
                shape=(self.output_channels,),
                rngs=self.rngs,
            )
        else:
            self.bias_distribution = bias_distribution
    else:
        self.bias_distribution = None

    # Sample initial weights
    self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample initial bias only if using bias
    if self.use_bias and self.bias_distribution is not None:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    else:
        self.bias = None
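The weight shape built in `__init__` is `(output_channels, input_channels // groups, *kernel_size)`, and the spatial output size follows the usual convolution arithmetic. The pure-Python sketch below (hypothetical helpers `conv2d_param_shapes` and `conv2d_output_size`) computes both. It adds a divisibility check on `groups` that the listed `__init__` does not perform, and it assumes a distribution's `num_params` is its element count.

```python
def conv2d_param_shapes(input_channels, output_channels, kernel_size, groups=1, use_bias=True):
    """Return (weight_shape, bias_shape, num_params) for a Conv2d-style layer."""
    # An int kernel size k is expanded to (k, k)
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    # Grouped convolution needs both channel counts to split evenly
    if input_channels % groups or output_channels % groups:
        raise ValueError("input_channels and output_channels must be divisible by groups")

    weight_shape = (output_channels, input_channels // groups, *kernel_size)
    bias_shape = (output_channels,) if use_bias else None

    # Element count of the weights, plus the bias if present
    num_params = 1
    for dim in weight_shape:
        num_params *= dim
    if use_bias:
        num_params += output_channels
    return weight_shape, bias_shape, num_params


def conv2d_output_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis with symmetric padding."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


print(conv2d_param_shapes(3, 16, 3))            # ((16, 3, 3, 3), (16,), 448)
print(conv2d_param_shapes(8, 16, 3, groups=4))  # ((16, 2, 3, 3), (16,), 304)
print(conv2d_output_size(32, 3, padding=1))     # 32
print(conv2d_output_size(32, 3, stride=2))      # 15
print(conv2d_output_size(32, 3, dilation=2))    # 28
```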

3.3.3 freeze()

Freeze the module's parameters to stop gradient computation. If weights or biases are not sampled yet, they are sampled first. Once frozen, parameters are not resampled or updated.

Returns:

None.

Source code in illia/nn/jax/conv2d.py
def freeze(self) -> None:
    """
    Freeze the module's parameters to stop gradient computation.
    If weights or biases are not sampled yet, they are sampled first.
    Once frozen, parameters are not resampled or updated.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample bias if it is undefined and bias is used
    if self.use_bias and self.bias is None and self.bias_distribution is not None:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Stop gradient computation
    self.weights = jax.lax.stop_gradient(self.weights)
    if self.use_bias:
        self.bias = jax.lax.stop_gradient(self.bias)
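The interplay between `frozen`, `freeze`, and `__call__` can be mirrored in a dependency-light toy. `ToyBayesianLinear` below is hypothetical and not part of illia: it stores an assumed Gaussian posterior (`mu`, `sigma`), draws a fresh reparameterized sample on each unfrozen call, and reuses the stored sample once frozen. It leaves out gradients entirely, so it says nothing about `jax.lax.stop_gradient`.

```python
import numpy as np


class ToyBayesianLinear:
    """Hypothetical stand-in mirroring the frozen/sample logic of the layers above."""

    def __init__(self, input_size, output_size, seed=0):
        self.rng = np.random.default_rng(seed)
        self.mu = np.zeros((output_size, input_size))         # assumed posterior mean
        self.sigma = np.full((output_size, input_size), 0.1)  # assumed posterior std
        self.frozen = False
        self.weights = self.sample()  # initial sample, as in __init__

    def sample(self):
        # Reparameterized Gaussian draw: mu + sigma * eps
        return self.mu + self.sigma * self.rng.standard_normal(self.mu.shape)

    def freeze(self):
        self.frozen = True
        # Only sample if nothing is stored; otherwise keep the latest draw
        if self.weights is None:
            self.weights = self.sample()

    def unfreeze(self):
        self.frozen = False

    def __call__(self, x):
        if not self.frozen:
            self.weights = self.sample()  # fresh draw on every unfrozen call
        elif self.weights is None:
            raise ValueError("Module has been frozen with undefined weights.")
        return x @ self.weights.T


layer = ToyBayesianLinear(input_size=3, output_size=2)
x = np.ones((1, 3))

a = layer(x)
b = layer(x)      # unfrozen: new weights, so b differs from a
layer.freeze()
c = layer(x)      # frozen: reuses the weights drawn for b
d = layer(x)      # still the same weights
layer.unfreeze()
e = layer(x)      # sampling resumes
```

Note that, as in the listed `freeze`, freezing keeps whichever sample was drawn last rather than drawing a new one.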

3.3.4 kl_cost()

Compute the KL divergence cost for all Bayesian parameters.

Returns:

tuple[jax.Array, int]: A tuple containing the KL divergence cost and the total number of parameters in the layer.

Source code in illia/nn/jax/conv2d.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Compute the KL divergence cost for all Bayesian parameters.

    Returns:
        tuple[jax.Array, int]: A tuple containing the KL divergence
            cost and the total number of parameters in the layer.
    """

    # Compute log probs for weights
    log_probs: jax.Array = self.weights_distribution.log_prob(
        jnp.asarray(self.weights)
    )

    # Add bias log probs only if using bias
    if (
        self.use_bias
        and self.bias is not None
        and self.bias_distribution is not None
    ):
        log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

    # Compute number of parameters
    num_params: int = self.weights_distribution.num_params
    if self.use_bias and self.bias_distribution is not None:
        num_params += self.bias_distribution.num_params

    return log_probs, num_params
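`kl_cost` returns the `log_prob` of the current weight sample (plus that of the bias, if used) together with a parameter count. Assuming `GaussianDistribution.log_prob` evaluates something like log q(w) - log p(w) for a Gaussian posterior q and prior p (its internals are not shown on this page), that quantity is a single-sample Monte Carlo estimate of KL(q || p). The NumPy sketch below compares that estimator, averaged over many draws, with the closed-form KL between diagonal Gaussians; all function names and the chosen posterior/prior values are hypothetical.

```python
import numpy as np


def gaussian_log_prob(x, mu, sigma):
    # Elementwise log N(x; mu, sigma^2)
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - (x - mu) ** 2 / (2.0 * sigma**2)


def kl_closed_form(mu_q, sigma_q, mu_p, sigma_p):
    # KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) summed over all elements
    return np.sum(
        np.log(sigma_p / sigma_q)
        + (sigma_q**2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p**2)
        - 0.5
    )


def kl_single_sample(w, mu_q, sigma_q, mu_p, sigma_p):
    # Monte Carlo estimate log q(w) - log p(w), summed over the last axis; w ~ q
    return np.sum(
        gaussian_log_prob(w, mu_q, sigma_q) - gaussian_log_prob(w, mu_p, sigma_p),
        axis=-1,
    )


rng = np.random.default_rng(0)
mu_q, sigma_q = np.full(4, 0.3), np.full(4, 0.5)  # assumed posterior
mu_p, sigma_p = np.zeros(4), np.ones(4)           # assumed standard-normal prior

# Reparameterized draws w = mu + sigma * eps, one row per sample
w = mu_q + sigma_q * rng.standard_normal((20_000, 4))
mc_estimate = kl_single_sample(w, mu_q, sigma_q, mu_p, sigma_p).mean()
exact = kl_closed_form(mu_q, sigma_q, mu_p, sigma_p)
print(exact)        # about 1.4526
print(mc_estimate)  # close to exact
```

A single draw is noisy, but it is unbiased, which is what makes a per-step sample such as the one `kl_cost` evaluates usable as a training signal.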

3.4 Embedding

Bayesian embedding layer with optional padding and max-norm constraints. Each embedding vector is sampled from a specified weight distribution. If the layer is frozen, embeddings are fixed and gradients are stopped.

Source code in illia/nn/jax/embedding.py
class Embedding(BayesianModule):
    """
    Bayesian embedding layer with optional padding and max-norm constraints.
    Each embedding vector is sampled from a specified weight distribution.
    If the layer is frozen, embeddings are fixed and gradients are stopped.
    """

    def __init__(
        self,
        num_embeddings: int,
        embeddings_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        weights_distribution: Optional[GaussianDistribution] = None,
        rngs: Rngs = nnx.Rngs(0),
        **kwargs: Any,
    ) -> None:
        """
        Initialize a Bayesian embedding layer with optional constraints.
        Sets up the embedding weight distribution and samples initial values.

        Args:
            num_embeddings: Size of the embedding dictionary.
            embeddings_dim: Dimension of each embedding vector.
            padding_idx: Index whose embeddings are ignored in gradient.
            max_norm: Maximum norm for each embedding vector.
            norm_type: p value for the p-norm in max_norm option.
            scale_grad_by_freq: Scale gradients by inverse word frequency.
            weights_distribution: Distribution to initialize embeddings.
            rngs: Random number generators for reproducibility.
            **kwargs: Extra arguments passed to the base class.

        Returns:
            None.

        Notes:
            Gaussian distributions are used by default if none are
            provided.
        """

        super().__init__(**kwargs)

        self.num_embeddings = num_embeddings
        self.embeddings_dim = embeddings_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.rngs = rngs

        # Set weights distribution
        if weights_distribution is None:
            self.weights_distribution = GaussianDistribution(
                (self.num_embeddings, self.embeddings_dim)
            )
        else:
            self.weights_distribution = weights_distribution

        # Sample initial weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    def freeze(self) -> None:
        """
        Freeze the module's parameters to stop gradient computation.
        If the weights are not sampled yet, they are sampled first.
        Once frozen, parameters are not resampled or updated.

        Returns:
            None.
        """

        # Set indicator
        self.frozen = True

        # Sample weights if they are undefined
        if self.weights is None:
            self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Stop gradient computation
        self.weights = jax.lax.stop_gradient(self.weights)

    def kl_cost(self) -> tuple[jax.Array, int]:
        """
        Compute the KL divergence cost for all Bayesian parameters.

        Returns:
            tuple[jax.Array, int]: A tuple containing the KL divergence
                cost and the total number of parameters in the layer.
        """

        # Compute log probs for weights
        log_probs: jax.Array = self.weights_distribution.log_prob(
            jnp.asarray(self.weights)
        )

        # Get number of parameters
        num_params: int = self.weights_distribution.num_params

        return log_probs, num_params

    def __call__(self, inputs: jax.Array) -> jax.Array:
        """
        Perform a forward pass using current embedding weights.

        Args:
            inputs: Array of indices into the embedding matrix.

        Returns:
            Array of shape [*, embeddings_dim] containing the embedding
            vectors corresponding to the input indices.

        Raises:
            ValueError: If the layer is frozen but weights are
                undefined.

        Notes:
            Embeddings at padding_idx are zeroed out, and vectors exceeding
            max_norm are renormalized if specified.
        """

        # Sample if model not frozen
        if not self.frozen:
            # Sample weights
            self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))
        elif self.weights is None:
            raise ValueError("Module has been frozen with undefined weights.")

        # Perform embedding lookup
        outputs = self.weights.value[inputs]

        # Apply padding_idx
        if self.padding_idx is not None:
            # Create mask for padding indices
            mask = inputs == self.padding_idx
            # Zero out embeddings for padding indices
            outputs = jnp.where(mask[..., None], 0.0, outputs)

        # Apply max_norm
        if self.max_norm is not None:
            norms = jnp.linalg.norm(outputs, axis=-1, ord=self.norm_type, keepdims=True)
            # Normalize vectors that exceed max_norm
            scale = jnp.minimum(1.0, self.max_norm / (norms + 1e-8))
            outputs = outputs * scale

        return outputs

3.4.1 __call__(inputs)

Perform a forward pass using current embedding weights.

Parameters:

inputs (Array): Array of indices into the embedding matrix. Required.

Returns:

Array: Array of shape [*, embeddings_dim] containing the embedding vectors corresponding to the input indices.

Raises:

ValueError: If the layer is frozen but weights are undefined.

Notes

Embeddings at padding_idx are zeroed out, and vectors exceeding max_norm are renormalized if specified.

Source code in illia/nn/jax/embedding.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    Perform a forward pass using current embedding weights.

    Args:
        inputs: Array of indices into the embedding matrix.

    Returns:
        Array of shape [*, embeddings_dim] containing the embedding
        vectors corresponding to the input indices.

    Raises:
        ValueError: If the layer is frozen but weights are
            undefined.

    Notes:
        Embeddings at padding_idx are zeroed out, and vectors exceeding
        max_norm are renormalized if specified.
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))
    elif self.weights is None:
        raise ValueError("Module has been frozen with undefined weights.")

    # Perform embedding lookup
    outputs = self.weights.value[inputs]

    # Apply padding_idx
    if self.padding_idx is not None:
        # Create mask for padding indices
        mask = inputs == self.padding_idx
        # Zero out embeddings for padding indices
        outputs = jnp.where(mask[..., None], 0.0, outputs)

    # Apply max_norm
    if self.max_norm is not None:
        norms = jnp.linalg.norm(outputs, axis=-1, ord=self.norm_type, keepdims=True)
        # Normalize vectors that exceed max_norm
        scale = jnp.minimum(1.0, self.max_norm / (norms + 1e-8))
        outputs = outputs * scale

    return outputs
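The lookup, padding mask, and max-norm rescaling in `__call__` are plain array operations, so a NumPy transcription shows what happens to each vector. `embed` is a hypothetical helper operating on an already sampled weight table; it follows the listed code, including the `1e-8` guard against division by zero.

```python
import numpy as np


def embed(table, indices, padding_idx=None, max_norm=None, norm_type=2.0):
    # Fancy indexing: indices of shape [*] -> vectors of shape [*, dim]
    out = table[indices]
    if padding_idx is not None:
        # Zero the vectors whose index equals padding_idx
        out = np.where((indices == padding_idx)[..., None], 0.0, out)
    if max_norm is not None:
        norms = np.linalg.norm(out, ord=norm_type, axis=-1, keepdims=True)
        # Scale down only vectors whose norm exceeds max_norm
        out = out * np.minimum(1.0, max_norm / (norms + 1e-8))
    return out


table = np.arange(12, dtype=float).reshape(4, 3)  # 4 embeddings of dimension 3
idx = np.array([[0, 1], [3, 0]])                  # indices of any integer shape [*]
out = embed(table, idx, padding_idx=0, max_norm=1.0)
print(out.shape)  # (2, 2, 3)
```

Two details follow from the listed `__call__`: `max_norm` rescales the looked-up vectors rather than the stored table (PyTorch's `nn.Embedding` instead renormalizes the referenced weight rows in place), and `scale_grad_by_freq` is stored by `__init__` but not read in the forward pass shown.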

3.4.2 __init__(num_embeddings, embeddings_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, weights_distribution=None, rngs=nnx.Rngs(0), **kwargs)

Initialize a Bayesian embedding layer with optional constraints. Sets up the embedding weight distribution and samples initial values.

Parameters:

num_embeddings (int): Size of the embedding dictionary. Required.
embeddings_dim (int): Dimension of each embedding vector. Required.
padding_idx (Optional[int]): Index whose embedding vectors are zeroed in the output, so they receive no gradient. Default: None.
max_norm (Optional[float]): Maximum norm for each embedding vector. Default: None.
norm_type (float): The p of the p-norm used for the max_norm check. Default: 2.0.
scale_grad_by_freq (bool): Scale gradients by inverse word frequency. Default: False.
weights_distribution (Optional[GaussianDistribution]): Distribution to initialize embeddings. Default: None.
rngs (Rngs): Random number generators for reproducibility. Default: nnx.Rngs(0).
**kwargs (Any): Extra arguments passed to the base class.

Returns:

None.

Notes

Gaussian distributions are used by default if none are provided.

Source code in illia/nn/jax/embedding.py
def __init__(
    self,
    num_embeddings: int,
    embeddings_dim: int,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    weights_distribution: Optional[GaussianDistribution] = None,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Initialize a Bayesian embedding layer with optional constraints.
    Sets up the embedding weight distribution and samples initial values.

    Args:
        num_embeddings: Size of the embedding dictionary.
        embeddings_dim: Dimension of each embedding vector.
        padding_idx: Index whose embeddings are ignored in gradient.
        max_norm: Maximum norm for each embedding vector.
        norm_type: p value for the p-norm in max_norm option.
        scale_grad_by_freq: Scale gradients by inverse word frequency.
        weights_distribution: Distribution to initialize embeddings.
        rngs: Random number generators for reproducibility.
        **kwargs: Extra arguments passed to the base class.

    Returns:
        None.

    Notes:
        Gaussian distributions are used by default if none are
        provided.
    """

    super().__init__(**kwargs)

    self.num_embeddings = num_embeddings
    self.embeddings_dim = embeddings_dim
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.rngs = rngs

    # Set weights distribution
    if weights_distribution is None:
        self.weights_distribution = GaussianDistribution(
            (self.num_embeddings, self.embeddings_dim)
        )
    else:
        self.weights_distribution = weights_distribution

    # Sample initial weights
    self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

3.4.3 freeze()

Freeze the module's parameters to stop gradient computation. If the weights are not sampled yet, they are sampled first. Once frozen, the weights are not resampled or updated.

Returns:

None.

Source code in illia/nn/jax/embedding.py
def freeze(self) -> None:
    """
    Freeze the module's parameters to stop gradient computation.
    If the weights are not sampled yet, they are sampled first.
    Once frozen, parameters are not resampled or updated.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Stop gradient computation
    self.weights = jax.lax.stop_gradient(self.weights)

3.4.4 kl_cost()

Compute the KL divergence cost for all Bayesian parameters.

Returns:

tuple[jax.Array, int]: A tuple containing the KL divergence cost and the total number of parameters in the layer.

Source code in illia/nn/jax/embedding.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Compute the KL divergence cost for all Bayesian parameters.

    Returns:
        tuple[jax.Array, int]: A tuple containing the KL divergence
            cost and the total number of parameters in the layer.
    """

    # Compute log probs for weights
    log_probs: jax.Array = self.weights_distribution.log_prob(
        jnp.asarray(self.weights)
    )

    # Get number of parameters
    num_params: int = self.weights_distribution.num_params

    return log_probs, num_params

3.5 Linear

Bayesian linear (fully connected) layer with optional weight and bias priors. Functions like a standard linear layer but treats weights and bias as probabilistic variables. Freezing the layer fixes parameters and stops gradient computation.

Source code in illia/nn/jax/linear.py
class Linear(BayesianModule):
    """
    Bayesian linear (fully connected) layer with optional weight and bias
    priors. Functions like a standard linear layer but treats weights and
    bias as probabilistic variables. Freezing the layer fixes parameters
    and stops gradient computation.
    """

    bias_distribution: Optional[GaussianDistribution] = None
    bias: Optional[nnx.Param] = None

    def __init__(
        self,
        input_size: int,
        output_size: int,
        weights_distribution: Optional[GaussianDistribution] = None,
        bias_distribution: Optional[GaussianDistribution] = None,
        use_bias: bool = True,
        precision: PrecisionLike = None,
        dot_general: DotGeneralT = lax.dot_general,
        rngs: Rngs = nnx.Rngs(0),
        **kwargs: Any,
    ) -> None:
        """
        Initialize a Bayesian linear layer with optional priors for weights
        and bias. Samples initial parameter values from the specified
        distributions.

        Args:
            input_size: Number of input features.
            output_size: Number of output features.
            weights_distribution: Distribution for weights.
            bias_distribution: Distribution for bias.
            use_bias: Whether to include a bias term.
            precision: Precision for dot product computations.
            dot_general: Function for generalized dot products.
            rngs: Random number generators for reproducibility.
            **kwargs: Extra arguments passed to the base class.

        Returns:
            None.

        Notes:
            Gaussian distributions are used by default if none are
            provided.
        """

        super().__init__(**kwargs)

        self.input_size = input_size
        self.output_size = output_size
        self.use_bias = use_bias
        self.precision = precision
        self.dot_general = dot_general
        self.rngs = rngs

        # Set weights prior
        if weights_distribution is None:
            self.weights_distribution = GaussianDistribution(
                shape=(self.output_size, self.input_size), rngs=self.rngs
            )
        else:
            self.weights_distribution = weights_distribution

        # Set bias prior
        if self.use_bias:
            if bias_distribution is None:
                self.bias_distribution = GaussianDistribution(
                    shape=(self.output_size,), rngs=self.rngs
                )
            else:
                self.bias_distribution = bias_distribution
        else:
            self.bias_distribution = None

        # Sample initial weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample initial bias only if using bias
        if self.use_bias and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
        else:
            self.bias = None

    def freeze(self) -> None:
        """
        Freeze the module's parameters to stop gradient computation.
        If weights or biases are not sampled yet, they are sampled first.
        Once frozen, parameters are not resampled or updated.

        Returns:
            None.
        """

        # Set indicator
        self.frozen = True

        # Sample weights if they are undefined
        if self.weights is None:
            self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias if it is undefined and bias is used
        if self.use_bias and self.bias is None and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

        # Stop gradient computation
        self.weights = jax.lax.stop_gradient(self.weights)
        if self.use_bias:
            self.bias = jax.lax.stop_gradient(self.bias)

    def kl_cost(self) -> tuple[jax.Array, int]:
        """
        Compute the KL divergence cost for all Bayesian parameters.

        Returns:
            tuple[jax.Array, int]: A tuple containing the KL divergence
                cost and the total number of parameters in the layer.
        """

        # Compute log probs for weights
        log_probs: jax.Array = self.weights_distribution.log_prob(
            jnp.asarray(self.weights)
        )

        # Add bias log probs only if using bias
        if (
            self.use_bias
            and self.bias is not None
            and self.bias_distribution is not None
        ):
            log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

        # Compute number of parameters
        num_params: int = self.weights_distribution.num_params
        if self.use_bias and self.bias_distribution is not None:
            num_params += self.bias_distribution.num_params

        return log_probs, num_params

    def __call__(self, inputs: jax.Array) -> jax.Array:
        """
        Perform a forward pass using current weights and bias. Samples new
        parameters if the layer is not frozen.

        Args:
            inputs: Input array with shape [*, input_size].

        Returns:
            Output array with shape [*, output_size].

        Raises:
            ValueError: If the layer is frozen but weights or bias are
                undefined.
        """

        # Sample if model not frozen
        if not self.frozen:
            # Sample weights
            self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

            # Sample bias only if using bias
            if self.use_bias and self.bias_distribution is not None:
                self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
        elif self.weights is None or (self.use_bias and self.bias is None):
            raise ValueError(
                "Module has been frozen with undefined weights and/or bias."
            )

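        # Compute outputs: contract the last input axis with axis 1 of the
        # (output_size, input_size) weights, honoring dot_general and precision
        outputs = self.dot_general(
            inputs,
            jnp.asarray(self.weights),
            (((inputs.ndim - 1,), (1,)), ((), ())),
            precision=self.precision,
        )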

        # Add bias only if using bias
        if self.use_bias and self.bias is not None:
            outputs += jnp.reshape(
                jnp.asarray(self.bias), (1,) * (outputs.ndim - 1) + (-1,)
            )

        return outputs
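Because an unfrozen layer draws new weights and bias on every call, repeated forward passes over the same input give samples from the predictive distribution. The sketch below does this in NumPy with a hypothetical Gaussian posterior (`w_mu`, `w_sigma`, `b_mu`, `b_sigma` are assumptions, not illia attributes) and summarizes the samples by their mean and standard deviation; with the real layer, the loop body would be a call to the unfrozen module.

```python
import numpy as np

rng = np.random.default_rng(0)
input_size, output_size, num_samples = 4, 2, 1000

# Hypothetical variational posterior over weights (output_size, input_size) and bias
w_mu = rng.standard_normal((output_size, input_size))
w_sigma = np.full((output_size, input_size), 0.05)
b_mu = np.zeros(output_size)
b_sigma = np.full(output_size, 0.05)


def stochastic_forward(x):
    # One unfrozen call: draw fresh weights and bias, then apply the affine map
    w = w_mu + w_sigma * rng.standard_normal(w_mu.shape)
    b = b_mu + b_sigma * rng.standard_normal(b_mu.shape)
    return x @ w.T + b


x = rng.standard_normal((8, input_size))
samples = np.stack([stochastic_forward(x) for _ in range(num_samples)])  # (S, 8, 2)
pred_mean = samples.mean(axis=0)  # predictive mean, shape (8, 2)
pred_std = samples.std(axis=0)    # spread across draws, a simple uncertainty signal
```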

3.5.1 __call__(inputs)

Perform a forward pass using current weights and bias. Samples new parameters if the layer is not frozen.

Parameters:

inputs (Array): Input array with shape [*, input_size]. Required.

Returns:

Array: Output array with shape [*, output_size].

Raises:

ValueError: If the layer is frozen but weights or bias are undefined.

Source code in illia/nn/jax/linear.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    Perform a forward pass using current weights and bias. Samples new
    parameters if the layer is not frozen.

    Args:
        inputs: Input array with shape [*, input_size].

    Returns:
        Output array with shape [*, output_size].

    Raises:
        ValueError: If the layer is frozen but weights or bias are
            undefined.
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias only if using bias
        if self.use_bias and self.bias_distribution is not None:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    elif self.weights is None or (self.use_bias and self.bias is None):
        raise ValueError(
            "Module has been frozen with undefined weights and/or bias."
        )

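    # Compute outputs: contract the last input axis with axis 1 of the
    # (output_size, input_size) weights, honoring dot_general and precision
    outputs = self.dot_general(
        inputs,
        jnp.asarray(self.weights),
        (((inputs.ndim - 1,), (1,)), ((), ())),
        precision=self.precision,
    )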

    # Add bias only if using bias
    if self.use_bias and self.bias is not None:
        outputs += jnp.reshape(
            jnp.asarray(self.bias), (1,) * (outputs.ndim - 1) + (-1,)
        )

    return outputs
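The forward pass multiplies by the transpose of the weights because they use an `(output_size, input_size)` layout, and it reshapes the bias to `(1, ..., 1, output_size)` so inputs with any number of leading dimensions work. A NumPy sketch of the same arithmetic, with a hypothetical `linear_forward` helper and fixed (already sampled) weights:

```python
import numpy as np


def linear_forward(inputs, weights, bias=None):
    # weights has shape (output_size, input_size), hence the transpose
    outputs = inputs @ weights.T
    if bias is not None:
        # (output_size,) -> (1, ..., 1, output_size) so it broadcasts over leading dims
        outputs = outputs + np.reshape(bias, (1,) * (outputs.ndim - 1) + (-1,))
    return outputs


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 4))  # [*, input_size] with two leading dims
w = rng.standard_normal((3, 4))     # (output_size, input_size)
b = rng.standard_normal(3)          # (output_size,)

y = linear_forward(x, w, b)
print(y.shape)  # (2, 5, 3)
```

Plain broadcasting of a `(output_size,)` bias would give the same result here; the explicit reshape just makes the intended alignment with the last axis visible.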

3.5.2 __init__(input_size, output_size, weights_distribution=None, bias_distribution=None, use_bias=True, precision=None, dot_general=lax.dot_general, rngs=nnx.Rngs(0), **kwargs)

Initialize a Bayesian linear layer with optional priors for weights and bias. Samples initial parameter values from the specified distributions.

Parameters:

input_size (int): Number of input features. Required.
output_size (int): Number of output features. Required.
weights_distribution (Optional[GaussianDistribution]): Distribution for weights. Default: None.
bias_distribution (Optional[GaussianDistribution]): Distribution for bias. Default: None.
use_bias (bool): Whether to include a bias term. Default: True.
precision (PrecisionLike): Precision for dot product computations. Default: None.
dot_general (DotGeneralT): Function for generalized dot products. Default: lax.dot_general.
rngs (Rngs): Random number generators for reproducibility. Default: nnx.Rngs(0).
**kwargs (Any): Extra arguments passed to the base class.

Returns:

None.

Notes

Gaussian distributions are used by default if none are provided.

Source code in illia/nn/jax/linear.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
    use_bias: bool = True,
    precision: PrecisionLike = None,
    dot_general: DotGeneralT = lax.dot_general,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Initialize a Bayesian linear layer with optional priors for weights
    and bias. Samples initial parameter values from the specified
    distributions.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        weights_distribution: Distribution for weights.
        bias_distribution: Distribution for bias.
        use_bias: Whether to include a bias term.
        precision: Precision for dot product computations.
        dot_general: Function for generalized dot products.
        rngs: Random number generators for reproducibility.
        **kwargs: Extra arguments passed to the base class.

    Returns:
        None.

    Notes:
        Gaussian distributions are used by default if none are
        provided.
    """

    super().__init__(**kwargs)

    self.input_size = input_size
    self.output_size = output_size
    self.use_bias = use_bias
    self.precision = precision
    self.dot_general = dot_general
    self.rngs = rngs

    # Set weights prior
    if weights_distribution is None:
        self.weights_distribution = GaussianDistribution(
            shape=(self.output_size, self.input_size), rngs=self.rngs
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias prior
    if self.use_bias:
        if bias_distribution is None:
            self.bias_distribution = GaussianDistribution(
                shape=(self.output_size,), rngs=self.rngs
            )
        else:
            self.bias_distribution = bias_distribution
    else:
        self.bias_distribution = None

    # Sample initial weights
    self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample initial bias only if using bias
    if self.use_bias and self.bias_distribution is not None:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    else:
        self.bias = None
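
In the source shown above, `__call__` computes `inputs @ self.weights.T` and never reads `self.precision` or `self.dot_general`. As a sketch of how those two stored attributes could be applied, the hypothetical helper below expresses the same product through `jax.lax.dot_general`. It contracts the last input axis with axis 1 of the `(output_size, input_size)` weight matrix, with no shared batch dimensions.

```python
import jax.numpy as jnp
from jax import lax


def linear_dot_general(inputs, weights, precision=None, dot_general=lax.dot_general):
    """Hypothetical helper: inputs [*, in] x weights (out, in) -> [*, out]."""
    # ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
    dimension_numbers = (((inputs.ndim - 1,), (1,)), ((), ()))
    return dot_general(inputs, weights, dimension_numbers, precision=precision)


x = jnp.ones((2, 3, 4))
w = jnp.arange(20.0).reshape(5, 4)
y = linear_dot_general(x, w, precision=lax.Precision.HIGHEST)
# x is all ones, so each output row holds the row sums of w: 6, 22, 38, 54, 70
```

The result matches `x @ w.T`. Routing the product through `dot_general` is what would let the `precision` argument take effect.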

3.5.3 freeze()

Freeze the module's parameters to stop gradient computation. If weights or biases are not sampled yet, they are sampled first. Once frozen, parameters are not resampled or updated.

Returns:

Type Description
None

None.

Source code in illia/nn/jax/linear.py
def freeze(self) -> None:
    """
    Freeze the module's parameters to stop gradient computation.
    If weights or biases are not sampled yet, they are sampled first.
    Once frozen, parameters are not resampled or updated.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample bias if it is undefined and bias is used
    if self.use_bias and self.bias is None and self.bias_distribution is not None:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Stop gradient computation
    self.weights = jax.lax.stop_gradient(self.weights)
    if self.use_bias:
        self.bias = jax.lax.stop_gradient(self.bias)
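
`jax.lax.stop_gradient` acts on the computation being differentiated. It zeroes the gradient only for values that pass through it inside the function given to `jax.grad` (or a transform built on it). The minimal sketch below uses a hypothetical `loss` function to contrast the gradient of a trainable weight vector with the gradient of the same vector routed through `stop_gradient`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([0.5, -1.0, 2.0])


def loss(w, frozen):
    # When frozen, block the gradient path through the weights
    w_used = jax.lax.stop_gradient(w) if frozen else w
    return jnp.sum(w_used * x)


grad_trainable = jax.grad(loss)(w, False)  # equals x: [1., 2., 3.]
grad_frozen = jax.grad(loss)(w, True)      # all zeros
```

The forward value is identical in both cases (4.5 here); only the gradient changes.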

3.5.4 kl_cost()

Compute the KL divergence cost for all Bayesian parameters.

Returns:

Type Description
tuple[Array, int]

A tuple containing the KL divergence cost and the total number of parameters in the layer.

Source code in illia/nn/jax/linear.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Compute the KL divergence cost for all Bayesian parameters.

    Returns:
        tuple[jax.Array, int]: A tuple containing the KL divergence
            cost and the total number of parameters in the layer.
    """

    # Compute log probs for weights
    log_probs: jax.Array = self.weights_distribution.log_prob(
        jnp.asarray(self.weights)
    )

    # Add bias log probs only if using bias
    if (
        self.use_bias
        and self.bias is not None
        and self.bias_distribution is not None
    ):
        log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

    # Compute number of parameters
    num_params: int = self.weights_distribution.num_params
    if self.use_bias and self.bias_distribution is not None:
        num_params += self.bias_distribution.num_params

    return log_probs, num_params
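
The method sums the weight and bias terms and counts parameters. It adds the bias contribution only when a bias is in use. The sketch below reproduces that bookkeeping under two assumptions. First, `gaussian_log_prob` is a hypothetical stand-in for `GaussianDistribution.log_prob`: a standard normal log-density summed over entries. Second, each distribution's `num_params` is assumed to equal the number of sampled entries.

```python
import math

import jax.numpy as jnp
from jax.scipy.stats import norm


def gaussian_log_prob(values, mu=0.0, sigma=1.0):
    # Hypothetical stand-in: total log-density of all entries under N(mu, sigma^2)
    return jnp.sum(norm.logpdf(values, loc=mu, scale=sigma))


def linear_kl_cost(weights, bias=None):
    """Hypothetical mirror of the bookkeeping in Linear.kl_cost."""
    log_probs = gaussian_log_prob(weights)
    num_params = weights.size  # assumption: one parameter per sampled entry
    if bias is not None:
        log_probs = log_probs + gaussian_log_prob(bias)
        num_params += bias.size
    return log_probs, num_params


w = jnp.zeros((5, 4))
b = jnp.zeros((5,))
log_probs, num_params = linear_kl_cost(w, b)
# num_params == 25; log_probs is approximately 25 * (-0.5 * log(2 * pi)), about -22.97
expected = -0.5 * math.log(2 * math.pi) * 25
```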

3.6 LSTM

Bayesian LSTM layer with embedding and probabilistic weights. All weights and biases are treated as random variables sampled from Gaussian distributions. Freezing the layer fixes parameters and stops gradient computation.

Source code in illia/nn/jax/lstm.py
class LSTM(BayesianModule):
    """
    Bayesian LSTM layer with embedding and probabilistic weights.
    All weights and biases are treated as random variables sampled from
    Gaussian distributions. Freezing the layer fixes parameters and
    stops gradient computation.
    """

    def __init__(
        self,
        num_embeddings: int,
        embeddings_dim: int,
        hidden_size: int,
        output_size: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        rngs: Rngs = nnx.Rngs(0),
        **kwargs: Any,
    ) -> None:
        """
        Initialize a Bayesian LSTM layer with embedding and probabilistic
        weights. Sets up all gate distributions and samples initial weights.

        Args:
            num_embeddings: Vocabulary size.
            embeddings_dim: Dimension of token embeddings.
            hidden_size: Number of units in the LSTM hidden state.
            output_size: Size of the output layer.
            padding_idx: Index in embeddings to ignore (optional).
            max_norm: Maximum norm for embeddings (optional).
            norm_type: p-norm for max_norm computation.
            scale_grad_by_freq: Scale gradients by token frequency.
            rngs: Random number generators for reproducibility.
            **kwargs: Extra arguments passed to the base class.

        Returns:
            None.

        Notes:
            All gate weights and biases use default Gaussian
            distributions; this constructor does not accept custom
            distributions.
        """

        super().__init__(**kwargs)

        self.num_embeddings = num_embeddings
        self.embeddings_dim = embeddings_dim
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.rngs = rngs

        # Define the Embedding layer
        self.embedding = Embedding(
            num_embeddings=self.num_embeddings,
            embeddings_dim=self.embeddings_dim,
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            rngs=self.rngs,
        )

        # Initialize weights
        # Forget gate
        self.wf_distribution = GaussianDistribution(
            (self.hidden_size, self.embeddings_dim + self.hidden_size)
        )
        self.bf_distribution = GaussianDistribution((self.hidden_size,))

        # Input gate
        self.wi_distribution = GaussianDistribution(
            (self.hidden_size, self.embeddings_dim + self.hidden_size)
        )
        self.bi_distribution = GaussianDistribution((self.hidden_size,))

        # Candidate gate
        self.wc_distribution = GaussianDistribution(
            (self.hidden_size, self.embeddings_dim + self.hidden_size)
        )
        self.bc_distribution = GaussianDistribution((self.hidden_size,))

        # Output gate
        self.wo_distribution = GaussianDistribution(
            (self.hidden_size, self.embeddings_dim + self.hidden_size)
        )
        self.bo_distribution = GaussianDistribution((self.hidden_size,))

        # Final output layer
        self.wv_distribution = GaussianDistribution(
            (self.output_size, self.hidden_size)
        )
        self.bv_distribution = GaussianDistribution((self.output_size,))

        # Sample initial weights and store them as parameters
        # Forget gate
        self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
        self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))

        # Input gate
        self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
        self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))

        # Candidate gate
        self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
        self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))

        # Output gate
        self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
        self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))

        # Final output layer
        self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
        self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))

    def freeze(self) -> None:
        """
        Freeze the module's parameters to stop gradient computation.
        If weights or biases are not sampled yet, they are sampled first.
        Once frozen, parameters are not resampled or updated.

        Returns:
            None.
        """

        # Set indicator
        self.frozen = True

        # Freeze embedding layer
        self.embedding.freeze()

        # Forget gate
        if self.wf is None:
            self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
        if self.bf is None:
            self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))
        self.wf = jax.lax.stop_gradient(self.wf)
        self.bf = jax.lax.stop_gradient(self.bf)

        # Input gate
        if self.wi is None:
            self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
        if self.bi is None:
            self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))
        self.wi = jax.lax.stop_gradient(self.wi)
        self.bi = jax.lax.stop_gradient(self.bi)

        # Candidate gate
        if self.wc is None:
            self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
        if self.bc is None:
            self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))
        self.wc = jax.lax.stop_gradient(self.wc)
        self.bc = jax.lax.stop_gradient(self.bc)

        # Output gate
        if self.wo is None:
            self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
        if self.bo is None:
            self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))
        self.wo = jax.lax.stop_gradient(self.wo)
        self.bo = jax.lax.stop_gradient(self.bo)

        # Final output layer
        if self.wv is None:
            self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
        if self.bv is None:
            self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))
        self.wv = jax.lax.stop_gradient(self.wv)
        self.bv = jax.lax.stop_gradient(self.bv)

    def kl_cost(self) -> tuple[jax.Array, int]:
        """
        Compute the KL divergence cost for all Bayesian parameters.

        Returns:
            tuple[jax.Array, int]: A tuple containing the KL divergence
                cost and the total number of parameters in the layer.
        """

        # Compute log probs for each pair of weights and bias
        log_probs_f = self.wf_distribution.log_prob(
            jnp.asarray(self.wf)
        ) + self.bf_distribution.log_prob(jnp.asarray(self.bf))
        log_probs_i = self.wi_distribution.log_prob(
            jnp.asarray(self.wi)
        ) + self.bi_distribution.log_prob(jnp.asarray(self.bi))
        log_probs_c = self.wc_distribution.log_prob(
            jnp.asarray(self.wc)
        ) + self.bc_distribution.log_prob(jnp.asarray(self.bc))
        log_probs_o = self.wo_distribution.log_prob(
            jnp.asarray(self.wo)
        ) + self.bo_distribution.log_prob(jnp.asarray(self.bo))
        log_probs_v = self.wv_distribution.log_prob(
            jnp.asarray(self.wv)
        ) + self.bv_distribution.log_prob(jnp.asarray(self.bv))

        # Compute the total loss
        log_probs = log_probs_f + log_probs_i + log_probs_c + log_probs_o + log_probs_v

        # Compute number of parameters
        num_params = (
            self.wf_distribution.num_params
            + self.bf_distribution.num_params
            + self.wi_distribution.num_params
            + self.bi_distribution.num_params
            + self.wc_distribution.num_params
            + self.bc_distribution.num_params
            + self.wo_distribution.num_params
            + self.bo_distribution.num_params
            + self.wv_distribution.num_params
            + self.bv_distribution.num_params
        )

        return log_probs, num_params

    def __call__(
        self,
        inputs: jax.Array,
        init_states: Optional[tuple[jax.Array, jax.Array]] = None,
    ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
        """
        Perform a forward pass through the Bayesian LSTM layer.

        Args:
            inputs: Token indices with shape [batch, seq_len, 1].
            init_states: Optional tuple of initial hidden and cell states.

        Returns:
            Tuple containing:
            - Output tensor of shape [batch, output_size].
            - Tuple of (hidden_state, cell_state) after sequence processing.

        Raises:
            ValueError: If the layer is frozen but weights are
                undefined.
        """

        # Sample weights if not frozen
        if not self.frozen:
            self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
            self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))
            self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
            self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))
            self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
            self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))
            self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
            self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))
            self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
            self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))
        elif any(
            p is None
            for p in [
                self.wf,
                self.bf,
                self.wi,
                self.bi,
                self.wc,
                self.bc,
                self.wo,
                self.bo,
                self.wv,
                self.bv,
            ]
        ):
            raise ValueError(
                "Module has been frozen with undefined weights and/or bias."
            )

        # Apply embedding layer to input indices
        inputs = jnp.squeeze(inputs, axis=-1)
        inputs = self.embedding(inputs)
        batch_size = jnp.shape(inputs)[0]
        seq_len = jnp.shape(inputs)[1]

        # Initialize h_t and c_t if init_states is None
        if init_states is None:
            h_t = jnp.zeros([batch_size, self.hidden_size])
            c_t = jnp.zeros([batch_size, self.hidden_size])
        else:
            h_t, c_t = init_states[0], init_states[1]

        # Process sequence
        for t in range(seq_len):
            # Shape: (batch_size, embedding_dim)
            x_t = inputs[:, t, :]

            # Concatenate input and hidden state
            # Shape: (batch_size, embedding_dim + hidden_size)
            z_t = jnp.concat([x_t, h_t], axis=1)

            # Forget gate
            ft = nnx.sigmoid(z_t @ self.wf.T + self.bf)

            # Input gate
            it = nnx.sigmoid(z_t @ self.wi.T + self.bi)

            # Candidate cell state
            can = nnx.tanh(z_t @ self.wc.T + self.bc)

            # Output gate
            ot = nnx.sigmoid(z_t @ self.wo.T + self.bo)

            # Update cell state
            c_t = c_t * ft + can * it

            # Update hidden state
            h_t = ot * nnx.tanh(c_t)

        # Compute final output
        y_t = h_t @ self.wv.T + self.bv

        return y_t, (h_t, c_t)

3.6.1 __call__(inputs, init_states=None)

Perform a forward pass through the Bayesian LSTM layer.

Parameters:

Name Type Description Default
inputs Array

Token indices with shape [batch, seq_len, 1].

required
init_states Optional[tuple[Array, Array]]

Optional tuple of initial hidden and cell states.

None

Returns:

Type Description
tuple[Array, tuple[Array, Array]]

Tuple containing:
  • Output array of shape [batch, output_size].
  • Tuple of (hidden_state, cell_state) after sequence processing, each of shape [batch, hidden_size].

Raises:

Type Description
ValueError

If the layer is frozen but weights are undefined.

Source code in illia/nn/jax/lstm.py
def __call__(
    self,
    inputs: jax.Array,
    init_states: Optional[tuple[jax.Array, jax.Array]] = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
    """
    Perform a forward pass through the Bayesian LSTM layer.

    Args:
        inputs: Token indices with shape [batch, seq_len, 1].
        init_states: Optional tuple of initial hidden and cell states.

    Returns:
        Tuple containing:
        - Output tensor of shape [batch, output_size].
        - Tuple of (hidden_state, cell_state) after sequence processing.

    Raises:
        ValueError: If the layer is frozen but weights are
            undefined.
    """

    # Sample weights if not frozen
    if not self.frozen:
        self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
        self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))
        self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
        self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))
        self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
        self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))
        self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
        self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))
        self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
        self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))
    elif any(
        p is None
        for p in [
            self.wf,
            self.bf,
            self.wi,
            self.bi,
            self.wc,
            self.bc,
            self.wo,
            self.bo,
            self.wv,
            self.bv,
        ]
    ):
        raise ValueError(
            "Module has been frozen with undefined weights and/or bias."
        )

    # Apply embedding layer to input indices
    inputs = jnp.squeeze(inputs, axis=-1)
    inputs = self.embedding(inputs)
    batch_size = jnp.shape(inputs)[0]
    seq_len = jnp.shape(inputs)[1]

    # Initialize h_t and c_t if init_states is None
    if init_states is None:
        h_t = jnp.zeros([batch_size, self.hidden_size])
        c_t = jnp.zeros([batch_size, self.hidden_size])
    else:
        h_t, c_t = init_states[0], init_states[1]

    # Process sequence
    for t in range(seq_len):
        # Shape: (batch_size, embedding_dim)
        x_t = inputs[:, t, :]

        # Concatenate input and hidden state
        # Shape: (batch_size, embedding_dim + hidden_size)
        z_t = jnp.concat([x_t, h_t], axis=1)

        # Forget gate
        ft = nnx.sigmoid(z_t @ self.wf.T + self.bf)

        # Input gate
        it = nnx.sigmoid(z_t @ self.wi.T + self.bi)

        # Candidate cell state
        can = nnx.tanh(z_t @ self.wc.T + self.bc)

        # Output gate
        ot = nnx.sigmoid(z_t @ self.wo.T + self.bo)

        # Update cell state
        c_t = c_t * ft + can * it

        # Update hidden state
        h_t = ot * nnx.tanh(c_t)

    # Compute final output
    y_t = h_t @ self.wv.T + self.bv

    return y_t, (h_t, c_t)
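
The forward pass above unrolls the sequence with a Python `for` loop. Under `jax.jit`, that loop is unrolled at trace time, so the traced program grows with `seq_len`. `jax.lax.scan` is the usual JAX way to express such a recurrence. The sketch below is an alternative formulation, not illia's implementation. A plain lookup table stands in for the Bayesian `Embedding` layer, and `params` is a hypothetical dict of already-sampled weights that uses the same gate equations and shapes as the source.

```python
import jax
import jax.numpy as jnp


def lstm_forward(tokens, table, params, init_states=None):
    """Embedding lookup followed by the gate recurrence, using lax.scan."""
    # tokens [batch, seq_len, 1] -> indices [batch, seq_len] -> [batch, seq_len, emb]
    x = table[jnp.squeeze(tokens, axis=-1)]
    batch_size, hidden_size = x.shape[0], params["bf"].shape[0]
    if init_states is None:
        init_states = (
            jnp.zeros((batch_size, hidden_size)),
            jnp.zeros((batch_size, hidden_size)),
        )

    def step(carry, x_t):
        h_t, c_t = carry
        # Concatenate input and hidden state: [batch, emb + hidden]
        z_t = jnp.concatenate([x_t, h_t], axis=1)
        ft = jax.nn.sigmoid(z_t @ params["wf"].T + params["bf"])  # forget gate
        it = jax.nn.sigmoid(z_t @ params["wi"].T + params["bi"])  # input gate
        can = jnp.tanh(z_t @ params["wc"].T + params["bc"])       # candidate
        ot = jax.nn.sigmoid(z_t @ params["wo"].T + params["bo"])  # output gate
        c_t = c_t * ft + can * it
        h_t = ot * jnp.tanh(c_t)
        return (h_t, c_t), None

    # lax.scan walks the leading axis, so move seq_len to the front
    (h_t, c_t), _ = jax.lax.scan(step, init_states, jnp.swapaxes(x, 0, 1))
    y = h_t @ params["wv"].T + params["bv"]
    return y, (h_t, c_t)


# Hypothetical sizes and random weights standing in for sampled values
vocab, emb, hidden, out = 10, 8, 16, 3
shapes = {}
for g in ("f", "i", "c", "o"):
    shapes["w" + g] = (hidden, emb + hidden)
    shapes["b" + g] = (hidden,)
shapes["wv"], shapes["bv"] = (out, hidden), (out,)

keys = jax.random.split(jax.random.PRNGKey(0), len(shapes) + 2)
params = {
    name: 0.1 * jax.random.normal(k, shape)
    for (name, shape), k in zip(shapes.items(), keys[2:])
}
table = jax.random.normal(keys[0], (vocab, emb))
tokens = jax.random.randint(keys[1], (4, 7, 1), 0, vocab)

y, (h, c) = lstm_forward(tokens, table, params)
```

The compiled `scan` body contains a single step regardless of sequence length. The trade-off is that the carry must keep the same shape and dtype at every step.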

3.6.2 __init__(num_embeddings, embeddings_dim, hidden_size, output_size, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, rngs=nnx.Rngs(0), **kwargs)

Initialize a Bayesian LSTM layer with embedding and probabilistic weights. Sets up all gate distributions and samples initial weights.

Parameters:

Name Type Description Default
num_embeddings int

Vocabulary size.

required
embeddings_dim int

Dimension of token embeddings.

required
hidden_size int

Number of units in the LSTM hidden state.

required
output_size int

Size of the output layer.

required
padding_idx Optional[int]

Index in embeddings to ignore (optional).

None
max_norm Optional[float]

Maximum norm for embeddings (optional).

None
norm_type float

p-norm for max_norm computation.

2.0
scale_grad_by_freq bool

Scale gradients by token frequency.

False
rngs Rngs

Random number generators for reproducibility.

Rngs(0)
**kwargs Any

Extra arguments passed to the base class.

{}

Returns:

Type Description
None

None.

Notes

All gate weights and biases use default Gaussian distributions; this constructor does not accept custom distributions.

Source code in illia/nn/jax/lstm.py
def __init__(
    self,
    num_embeddings: int,
    embeddings_dim: int,
    hidden_size: int,
    output_size: int,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Initialize a Bayesian LSTM layer with embedding and probabilistic
    weights. Sets up all gate distributions and samples initial weights.

    Args:
        num_embeddings: Vocabulary size.
        embeddings_dim: Dimension of token embeddings.
        hidden_size: Number of units in the LSTM hidden state.
        output_size: Size of the output layer.
        padding_idx: Index in embeddings to ignore (optional).
        max_norm: Maximum norm for embeddings (optional).
        norm_type: p-norm for max_norm computation.
        scale_grad_by_freq: Scale gradients by token frequency.
        rngs: Random number generators for reproducibility.
        **kwargs: Extra arguments passed to the base class.

    Returns:
        None.

    Notes:
        All gate weights and biases use default Gaussian
        distributions; this constructor does not accept custom
        distributions.
    """

    super().__init__(**kwargs)

    self.num_embeddings = num_embeddings
    self.embeddings_dim = embeddings_dim
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.rngs = rngs

    # Define the Embedding layer
    self.embedding = Embedding(
        num_embeddings=self.num_embeddings,
        embeddings_dim=self.embeddings_dim,
        padding_idx=self.padding_idx,
        max_norm=self.max_norm,
        norm_type=self.norm_type,
        scale_grad_by_freq=self.scale_grad_by_freq,
        rngs=self.rngs,
    )

    # Initialize weights
    # Forget gate
    self.wf_distribution = GaussianDistribution(
        (self.hidden_size, self.embeddings_dim + self.hidden_size)
    )
    self.bf_distribution = GaussianDistribution((self.hidden_size,))

    # Input gate
    self.wi_distribution = GaussianDistribution(
        (self.hidden_size, self.embeddings_dim + self.hidden_size)
    )
    self.bi_distribution = GaussianDistribution((self.hidden_size,))

    # Candidate gate
    self.wc_distribution = GaussianDistribution(
        (self.hidden_size, self.embeddings_dim + self.hidden_size)
    )
    self.bc_distribution = GaussianDistribution((self.hidden_size,))

    # Output gate
    self.wo_distribution = GaussianDistribution(
        (self.hidden_size, self.embeddings_dim + self.hidden_size)
    )
    self.bo_distribution = GaussianDistribution((self.hidden_size,))

    # Final output layer
    self.wv_distribution = GaussianDistribution(
        (self.output_size, self.hidden_size)
    )
    self.bv_distribution = GaussianDistribution((self.output_size,))

    # Sample initial weights and store them as parameters
    # Forget gate
    self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
    self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))

    # Input gate
    self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
    self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))

    # Candidate gate
    self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
    self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))

    # Output gate
    self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
    self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))

    # Final output layer
    self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
    self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))
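
Every gate weight acts on the concatenation of the token embedding and the previous hidden state, so its shape is `(hidden_size, embeddings_dim + hidden_size)`. The output projection maps `hidden_size` to `output_size`. The pure-Python sketch below tallies the entries of the ten distributions created above. It excludes the embedding table, matching the ten terms summed in `kl_cost`, and assumes `num_params` counts sampled entries. The helper name is hypothetical.

```python
import math


def lstm_param_shapes(embeddings_dim, hidden_size, output_size):
    """Hypothetical helper listing the shapes created in LSTM.__init__."""
    z_width = embeddings_dim + hidden_size  # width of concat([x_t, h_t])
    shapes = {}
    for gate in ("f", "i", "c", "o"):  # forget, input, candidate, output
        shapes[f"w{gate}"] = (hidden_size, z_width)
        shapes[f"b{gate}"] = (hidden_size,)
    shapes["wv"] = (output_size, hidden_size)  # final output layer
    shapes["bv"] = (output_size,)
    return shapes


shapes = lstm_param_shapes(embeddings_dim=8, hidden_size=16, output_size=3)
total = sum(math.prod(shape) for shape in shapes.values())
print(total)  # 1651
```

In closed form the count is `4 * H * (E + H + 1) + O * (H + 1)`. Here that is 4 · 16 · 25 + 3 · 17 = 1651.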

3.6.3 freeze()

Freeze the module's parameters to stop gradient computation. If weights or biases are not sampled yet, they are sampled first. Once frozen, parameters are not resampled or updated.

Returns:

Type Description
None

None.

Source code in illia/nn/jax/lstm.py
def freeze(self) -> None:
    """
    Freeze the module's parameters to stop gradient computation.
    If weights or biases are not sampled yet, they are sampled first.
    Once frozen, parameters are not resampled or updated.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Freeze embedding layer
    self.embedding.freeze()

    # Forget gate
    if self.wf is None:
        self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
    if self.bf is None:
        self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))
    self.wf = jax.lax.stop_gradient(self.wf)
    self.bf = jax.lax.stop_gradient(self.bf)

    # Input gate
    if self.wi is None:
        self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
    if self.bi is None:
        self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))
    self.wi = jax.lax.stop_gradient(self.wi)
    self.bi = jax.lax.stop_gradient(self.bi)

    # Candidate gate
    if self.wc is None:
        self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
    if self.bc is None:
        self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))
    self.wc = jax.lax.stop_gradient(self.wc)
    self.bc = jax.lax.stop_gradient(self.bc)

    # Output gate
    if self.wo is None:
        self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
    if self.bo is None:
        self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))
    self.wo = jax.lax.stop_gradient(self.wo)
    self.bo = jax.lax.stop_gradient(self.bo)

    # Final output layer
    if self.wv is None:
        self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
    if self.bv is None:
        self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))
    self.wv = jax.lax.stop_gradient(self.wv)
    self.bv = jax.lax.stop_gradient(self.bv)
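
Once `frozen` is set, `__call__` skips resampling and reuses the stored weights. `unfreeze`, inherited from `BayesianModule`, clears the flag so that sampling resumes. The hypothetical toy class below mirrors only that control flow with a single Gaussian weight matrix (no embedding, no `stop_gradient`). It shows that outputs vary between calls until the module is frozen.

```python
import jax
import jax.numpy as jnp


class TinyBayesianGate:
    """Hypothetical toy layer: resamples its weight on every call unless frozen."""

    def __init__(self, shape, seed=0):
        self.mu = jnp.zeros(shape)
        self.sigma = jnp.ones(shape)
        self.key = jax.random.PRNGKey(seed)
        self.frozen = False
        self.w = self._sample()

    def _sample(self):
        # Split the key so every draw uses fresh randomness
        self.key, subkey = jax.random.split(self.key)
        return self.mu + self.sigma * jax.random.normal(subkey, self.mu.shape)

    def freeze(self):
        self.frozen = True
        if self.w is None:  # sample first if nothing is stored yet
            self.w = self._sample()

    def unfreeze(self):
        self.frozen = False

    def __call__(self, x):
        if not self.frozen:
            self.w = self._sample()
        return x @ self.w.T


x = jnp.ones((1, 3))
gate = TinyBayesianGate((2, 3))
first, second = gate(x), gate(x)  # differ: weight resampled on each call
gate.freeze()
third, fourth = gate(x), gate(x)  # equal: weight held fixed
```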

3.6.4 kl_cost()

Compute the KL divergence cost for all Bayesian parameters.

Returns:

Type Description
tuple[Array, int]

A tuple containing the KL divergence cost and the total number of parameters in the layer.

Source code in illia/nn/jax/lstm.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Compute the KL divergence cost for all Bayesian parameters.

    Returns:
        tuple[jax.Array, int]: A tuple containing the KL divergence
            cost and the total number of parameters in the layer.
    """

    # Compute log probs for each pair of weights and bias
    log_probs_f = self.wf_distribution.log_prob(
        jnp.asarray(self.wf)
    ) + self.bf_distribution.log_prob(jnp.asarray(self.bf))
    log_probs_i = self.wi_distribution.log_prob(
        jnp.asarray(self.wi)
    ) + self.bi_distribution.log_prob(jnp.asarray(self.bi))
    log_probs_c = self.wc_distribution.log_prob(
        jnp.asarray(self.wc)
    ) + self.bc_distribution.log_prob(jnp.asarray(self.bc))
    log_probs_o = self.wo_distribution.log_prob(
        jnp.asarray(self.wo)
    ) + self.bo_distribution.log_prob(jnp.asarray(self.bo))
    log_probs_v = self.wv_distribution.log_prob(
        jnp.asarray(self.wv)
    ) + self.bv_distribution.log_prob(jnp.asarray(self.bv))

    # Compute the total loss
    log_probs = log_probs_f + log_probs_i + log_probs_c + log_probs_o + log_probs_v

    # Compute number of parameters
    num_params = (
        self.wf_distribution.num_params
        + self.bf_distribution.num_params
        + self.wi_distribution.num_params
        + self.bi_distribution.num_params
        + self.wc_distribution.num_params
        + self.bc_distribution.num_params
        + self.wo_distribution.num_params
        + self.bo_distribution.num_params
        + self.wv_distribution.num_params
        + self.bv_distribution.num_params
    )

    return log_probs, num_params
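
`Linear` and `LSTM` both return a `(cost, num_params)` pair, so a training loop can aggregate the pairs across layers. The sketch below shows one hypothetical way to do so. It sums the pairs, scales the total cost by a `kl_weight`, divides by the total parameter count, and adds the result to a data loss. The weighting scheme and the helper names are assumptions, not illia's API.

```python
import jax.numpy as jnp


def combine_kl_costs(costs):
    """Sum hypothetical (cost, num_params) pairs from several Bayesian layers."""
    total_cost = sum(cost for cost, _ in costs)
    total_params = sum(n for _, n in costs)
    return total_cost, total_params


def elbo_style_loss(nll, costs, kl_weight=1.0):
    # Assumed weighting: scale the summed cost by the total parameter count
    total_cost, total_params = combine_kl_costs(costs)
    return nll + kl_weight * total_cost / total_params


costs = [(jnp.array(2.0), 10), (jnp.array(4.0), 20)]  # e.g. a Linear and an LSTM
loss = elbo_style_loss(jnp.array(1.0), costs)  # 1.0 + 6.0 / 30
```

Dividing by the parameter count keeps the regularization term on a similar scale as models grow. `kl_weight` can then be tuned or annealed independently.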