3. Neural Network

This module contains the code for the Bayesian Conv1d layer.

3.1 Conv1d(input_channels, output_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, weights_distribution=None, bias_distribution=None, use_bias=True, rngs=nnx.Rngs(0), **kwargs)

This class is the Bayesian implementation of the Conv1d class.

Definition of a Bayesian Convolution 1D layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_channels | int | Number of input feature channels. | required |
| output_channels | int | Number of output feature channels. | required |
| kernel_size | int | Size of the convolutional kernel. | required |
| stride | int | Stride of the convolution operation. | 1 |
| padding | int | Amount of zero-padding added to both sides. | 0 |
| dilation | int | Spacing between kernel elements. | 1 |
| groups | int | Number of blocked connections between input and output. | 1 |
| weights_distribution | Optional[GaussianDistribution] | Distribution to initialize weights. | None |
| bias_distribution | Optional[GaussianDistribution] | Distribution to initialize bias. | None |
| use_bias | bool | Whether to include a bias term. | True |
| rngs | Rngs | Random number generators for reproducibility. | Rngs(0) |

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/conv1d.py
def __init__(
    self,
    input_channels: int,
    output_channels: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
    use_bias: bool = True,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Definition of a Bayesian Convolution 1D layer.

    Args:
        input_channels: Number of input feature channels.
        output_channels: Number of output feature channels.
        kernel_size: Size of the convolutional kernel.
        stride: Stride of the convolution operation.
        padding: Amount of zero-padding added to both sides.
        dilation: Spacing between kernel elements.
        groups: Number of blocked connections between input and output.
        weights_distribution: Distribution to initialize weights.
        bias_distribution: Distribution to initialize bias.
        use_bias: Whether to include a bias term.
        rngs: Random number generators for reproducibility.

    Returns:
        None.
    """

    # Call super class constructor
    super().__init__(**kwargs)

    # Set attributes
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.groups = groups
    self.use_bias = use_bias
    self.rngs = rngs

    # Set weights prior
    if weights_distribution is None:
        self.weights_distribution = GaussianDistribution(
            shape=(
                self.output_channels,
                self.input_channels // self.groups,
                self.kernel_size,
            ),
            rngs=self.rngs,
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias prior
    if self.use_bias:
        if bias_distribution is None:
            self.bias_distribution = GaussianDistribution(
                shape=(self.output_channels,),
                rngs=self.rngs,
            )
        else:
            self.bias_distribution = bias_distribution
    else:
        self.bias_distribution = None

    # Sample initial weights
    self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample initial bias only if using bias
    if self.use_bias and self.bias_distribution:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    else:
        self.bias = None
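
For reference, a minimal usage sketch. The import path illia.nn.jax.conv1d is assumed from the source path shown above; adjust it if the package exposes a different public API.

import jax.numpy as jnp
from flax import nnx

from illia.nn.jax.conv1d import Conv1d  # assumed import path

# Bayesian 1D convolution: 3 input channels, 8 output channels, kernel of size 5.
layer = Conv1d(
    input_channels=3,
    output_channels=8,
    kernel_size=5,
    padding=2,
    rngs=nnx.Rngs(42),
)

# Inputs follow the NCH layout used in __call__: (batch, channels, length).
x = jnp.ones((4, 3, 100))
y = layer(x)  # (4, 8, 100): length preserved with kernel_size=5, padding=2, stride=1

Each call resamples the weights and bias from their Gaussian distributions unless the layer has been frozen.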

3.1.1 __call__(inputs)

Applies the convolution operation to the inputs using current weights and bias. If the model is not frozen, samples new weights and bias before computation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | Array | Input array to be convolved. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Output array after applying convolution and bias. |

Source code in illia/nn/jax/conv1d.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    Applies the convolution operation to the inputs using current weights
    and bias. If the model is not frozen, samples new weights and bias
    before computation.

    Args:
        inputs: Input array to be convolved.

    Returns:
        Output array after applying convolution and bias.
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias only if using bias
        if self.bias is None and self.use_bias and self.bias_distribution:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Compute outputs
    outputs = jax.lax.conv_general_dilated(
        lhs=inputs,
        rhs=jnp.asarray(self.weights),
        window_strides=[self.stride],
        padding=[(self.padding, self.padding)],
        lhs_dilation=[1],
        rhs_dilation=[self.dilation],
        dimension_numbers=(
            "NCH",  # Input
            "OIH",  # Kernel
            "NCH",  # Output
        ),
        feature_group_count=self.groups,
    )

    # Add bias only if using bias
    if self.use_bias and self.bias is not None:
        outputs += jnp.reshape(
            a=jnp.asarray(self.bias), shape=(1, self.output_channels, 1)
        )

    return outputs

3.1.2 freeze()

Freezes the current module and all submodules that are instances of BayesianModule. Sets the frozen state to True.

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/conv1d.py
def freeze(self) -> None:
    """
    Freezes the current module and all submodules that are instances
    of BayesianModule. Sets the frozen state to True.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample bias if they are undefined and bias is used
    if self.use_bias and self.bias is None and self.bias_distribution:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Stop gradient computation
    self.weights = jax.lax.stop_gradient(self.weights)
    if self.use_bias and self.bias:
        self.bias = jax.lax.stop_gradient(self.bias)

3.1.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, int] | Tuple containing KL divergence cost and total number of parameters. |

Source code in illia/nn/jax/conv1d.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        Tuple containing KL divergence cost and total number of
        parameters.
    """

    # Compute log probs for weights
    log_probs: jax.Array = self.weights_distribution.log_prob(
        jnp.asarray(self.weights)
    )

    # Add bias log probs only if using bias
    if self.use_bias and self.bias and self.bias_distribution:
        log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

    # Compute number of parameters
    num_params: int = self.weights_distribution.num_params
    if self.use_bias and self.bias_distribution:
        num_params += self.bias_distribution.num_params

    return log_probs, num_params
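
A hedged sketch of how the returned values can feed a loss: combine the KL term with a task loss so the data-fit term is regularized toward the prior. The target array and kl_weight are illustrative names, not part of illia; layer and x continue the Conv1d example above.

import jax.numpy as jnp

# Dummy mean-squared-error data term (targets are illustrative).
targets = jnp.zeros((4, 8, 100))
task_loss = jnp.mean((layer(x) - targets) ** 2)

# KL term returned by the layer, scaled so it does not dominate the data term.
kl_cost, num_params = layer.kl_cost()
kl_weight = 1e-3
loss = task_loss + kl_weight * kl_cost / num_params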

This module contains the code for the Bayesian Conv2d layer.

3.2 Conv2d(input_channels, output_channels, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1, weights_distribution=None, bias_distribution=None, use_bias=True, rngs=nnx.Rngs(0), **kwargs)

This class is the Bayesian implementation of the Conv2d class.

Definition of a Bayesian Convolution 2D layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_channels | int | Number of input feature channels. | required |
| output_channels | int | Number of output feature channels. | required |
| kernel_size | int \| tuple[int, int] | Size of the convolutional kernel. | required |
| stride | tuple[int, int] | Stride of the convolution operation. | (1, 1) |
| padding | tuple[int, int] | Tuple for zero-padding on both sides. | (0, 0) |
| dilation | tuple[int, int] | Spacing between kernel elements. | (1, 1) |
| groups | int | Number of blocked connections between input and output. | 1 |
| weights_distribution | Optional[GaussianDistribution] | Distribution to initialize weights. | None |
| bias_distribution | Optional[GaussianDistribution] | Distribution to initialize bias. | None |
| use_bias | bool | Whether to include a bias term. | True |
| rngs | Rngs | Random number generators for reproducibility. | Rngs(0) |

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/conv2d.py
def __init__(
    self,
    input_channels: int,
    output_channels: int,
    kernel_size: int | tuple[int, int],
    stride: tuple[int, int] = (1, 1),
    padding: tuple[int, int] = (0, 0),
    dilation: tuple[int, int] = (1, 1),
    groups: int = 1,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
    use_bias: bool = True,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    Definition of a Bayesian Convolution 2D layer.

    Args:
        input_channels: Number of input feature channels.
        output_channels: Number of output feature channels.
        kernel_size: Size of the convolutional kernel.
        stride: Stride of the convolution operation.
        padding: Tuple for zero-padding on both sides.
        dilation: Spacing between kernel elements.
        groups: Number of blocked connections between input and output.
        weights_distribution: Distribution to initialize weights.
        bias_distribution: Distribution to initialize bias.
        use_bias: Whether to include a bias term.
        rngs: Random number generators for reproducibility.

    Returns:
        None.
    """

    # Call super class constructor
    super().__init__(**kwargs)

    # Set attributes
    self.input_channels = input_channels
    self.output_channels = output_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.groups = groups
    self.use_bias = use_bias
    self.rngs = rngs

    # Set weights distribution
    if weights_distribution is None:
        # Extend kernel if we only have 1 value
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)

        self.weights_distribution: GaussianDistribution = GaussianDistribution(
            shape=(
                self.output_channels,
                self.input_channels // self.groups,
                *self.kernel_size,
            ),
            rngs=self.rngs,
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias prior
    if self.use_bias:
        if bias_distribution is None:
            # Define weights distribution
            self.bias_distribution = GaussianDistribution(
                shape=(self.output_channels,),
                rngs=self.rngs,
            )
        else:
            self.bias_distribution = bias_distribution
    else:
        self.bias_distribution = None

    # Sample initial weights
    self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample initial bias only if using bias
    if self.use_bias and self.bias_distribution:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    else:
        self.bias = None
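
As with the 1D case, a minimal usage sketch (the import path illia.nn.jax.conv2d is assumed from the source path shown above):

import jax.numpy as jnp
from flax import nnx

from illia.nn.jax.conv2d import Conv2d  # assumed import path

# Bayesian 2D convolution with a 3x3 kernel and padding that preserves spatial size.
layer = Conv2d(
    input_channels=3,
    output_channels=16,
    kernel_size=(3, 3),
    padding=(1, 1),
    rngs=nnx.Rngs(0),
)

# Inputs follow the NCHW layout used in __call__: (batch, channels, height, width).
x = jnp.ones((2, 3, 32, 32))
y = layer(x)  # (2, 16, 32, 32)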

3.2.1 __call__(inputs)

Applies the convolution operation to the inputs using current weights and bias. If the model is not frozen, samples new weights and bias before computation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | Array | Input array to be convolved. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Output array after applying convolution and bias. |

Source code in illia/nn/jax/conv2d.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    Applies the convolution operation to the inputs using current weights
    and bias. If the model is not frozen, samples new weights and bias
    before computation.

    Args:
        inputs: Input array to be convolved.

    Returns:
        Output array after applying convolution and bias.
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias only if using bias
        if self.bias is None and self.use_bias and self.bias_distribution:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Compute outputs
    outputs = jax.lax.conv_general_dilated(
        lhs=inputs,
        rhs=jnp.asarray(self.weights),
        window_strides=self.stride,
        padding=[self.padding, self.padding],
        lhs_dilation=[1, 1],
        rhs_dilation=self.dilation,
        dimension_numbers=(
            "NCHW",  # Input
            "OIHW",  # Kernel
            "NCHW",  # Output
        ),
        feature_group_count=self.groups,
    )

    # Add bias only if using bias
    if self.use_bias and self.bias:
        outputs += jnp.reshape(
            a=jnp.asarray(self.bias), shape=(1, self.output_channels, 1, 1)
        )

    return outputs

3.2.2 freeze()

Freezes the current module and all submodules that are instances of BayesianModule. Sets the frozen state to True.

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/conv2d.py
def freeze(self) -> None:
    """
    Freezes the current module and all submodules that are instances
    of BayesianModule. Sets the frozen state to True.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample bias if they are undefined and bias is used
    if self.use_bias and self.bias is None and self.bias_distribution:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Stop gradient computation
    self.weights = jax.lax.stop_gradient(self.weights)
    if self.use_bias and self.bias:
        self.bias = jax.lax.stop_gradient(self.bias)
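
A short sketch of what freezing means in practice, continuing the Conv2d example above: once frozen, __call__ no longer resamples the weights, so repeated calls reuse the same sampled values, and stop_gradient prevents gradients from flowing into them.

# Before freezing, every call draws a fresh weight sample.
layer.freeze()

# After freezing, both calls use the same (fixed) weights and bias.
y1 = layer(x)
y2 = layer(x)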

3.2.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, int] | Tuple containing KL divergence cost and total number of parameters. |

Source code in illia/nn/jax/conv2d.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        Tuple containing KL divergence cost and total number of
        parameters.
    """

    # Compute log probs for weights
    log_probs: jax.Array = self.weights_distribution.log_prob(
        jnp.asarray(self.weights)
    )

    # Add bias log probs only if using bias
    if self.use_bias and self.bias and self.bias_distribution:
        log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

    # Compute number of parameters
    num_params: int = self.weights_distribution.num_params
    if self.use_bias and self.bias_distribution:
        num_params += self.bias_distribution.num_params

    return log_probs, num_params

This module contains the code for the Bayesian Embedding layer.

3.3 Embedding(num_embeddings, embeddings_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, weights_distribution=None, rngs=nnx.Rngs(0), **kwargs)

This class is the Bayesian implementation of the Embedding class.

This method is the constructor of the embedding class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_embeddings | int | Size of the dictionary of embeddings. | required |
| embeddings_dim | int | The size of each embedding vector. | required |
| padding_idx | Optional[int] | If specified, the entries at padding_idx do not contribute to the gradient. | None |
| max_norm | Optional[float] | If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. | None |
| norm_type | float | The p of the p-norm to compute for the max_norm option. | 2.0 |
| scale_grad_by_freq | bool | If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. | False |
| weights_distribution | Optional[GaussianDistribution] | Distribution for the weights of the layer. | None |
| rngs | Rngs | Random number generators for reproducibility. | Rngs(0) |

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/embedding.py
def __init__(
    self,
    num_embeddings: int,
    embeddings_dim: int,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    weights_distribution: Optional[GaussianDistribution] = None,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    This method is the constructor of the embedding class.

    Args:
        num_embeddings: size of the dictionary of embeddings.
        embeddings_dim: the size of each embedding vector.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
        norm_type: The p of the p-norm to compute for the max_norm
            option.
        scale_grad_by_freq: If given, this will scale gradients by
            the inverse of frequency of the words in the
            mini-batch.
        weights_distribution: distribution for the weights of the
            layer.
        rngs: Random number generators for reproducibility.

    Returns:
        None.
    """

    # Call super class constructor
    super().__init__(**kwargs)

    # Set attributes
    self.num_embeddings = num_embeddings
    self.embeddings_dim = embeddings_dim
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.rngs = rngs

    # Set weights distribution
    if weights_distribution is None:
        self.weights_distribution = GaussianDistribution(
            (self.num_embeddings, self.embeddings_dim)
        )
    else:
        self.weights_distribution = weights_distribution

    # Sample initial weights
    self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

3.3.1 __call__(inputs)

This method is the forward pass of the layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | Array | Input tensor. Dimensions: [*]. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Output tensor. Dimensions: [*, embedding dim]. |

Source code in illia/nn/jax/embedding.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    This method is the forward pass of the layer.

    Args:
        inputs: input tensor. Dimensions: [*].

    Returns:
        outputs tensor. Dimension: [*, embedding dim].
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Perform embedding lookup
    outputs = self.weights.value[inputs]

    # Apply padding_idx
    if self.padding_idx is not None:
        # Create mask for padding indices
        mask = inputs == self.padding_idx
        # Zero out embeddings for padding indices
        outputs = jnp.where(mask[..., None], 0.0, outputs)

    # Apply max_norm
    if self.max_norm is not None:
        norms = jnp.linalg.norm(outputs, axis=-1, ord=self.norm_type, keepdims=True)
        # Normalize vectors that exceed max_norm
        scale = jnp.minimum(1.0, self.max_norm / (norms + 1e-8))
        outputs = outputs * scale

    return outputs
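
A minimal lookup sketch showing the effect of padding_idx (the import path illia.nn.jax.embedding is assumed from the source path shown above):

import jax.numpy as jnp
from flax import nnx

from illia.nn.jax.embedding import Embedding  # assumed import path

# Bayesian embedding table; lookups at the padding index return zero vectors.
layer = Embedding(
    num_embeddings=10,
    embeddings_dim=4,
    padding_idx=0,
    rngs=nnx.Rngs(0),
)

tokens = jnp.array([[1, 2, 0], [3, 0, 0]])  # 0 is the padding index
outputs = layer(tokens)                     # shape (2, 3, 4)
# Positions where tokens == 0 map to all-zero embedding vectors.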

3.3.2 freeze()

Freezes the current module and all submodules that are instances of BayesianModule. Sets the frozen state to True.

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/embedding.py
def freeze(self) -> None:
    """
    Freezes the current module and all submodules that are instances
    of BayesianModule. Sets the frozen state to True.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Stop gradient computation
    self.weights = jax.lax.stop_gradient(self.weights)

3.3.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, int] | Tuple containing KL divergence cost and total number of parameters. |

Source code in illia/nn/jax/embedding.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        Tuple containing KL divergence cost and total number of
        parameters.
    """

    # Compute log probs for weights
    log_probs: jax.Array = self.weights_distribution.log_prob(
        jnp.asarray(self.weights)
    )

    # get number of parameters
    num_params: int = self.weights_distribution.num_params

    return log_probs, num_params

This module contains the code for the Bayesian Linear layer.

3.4 Linear(input_size, output_size, weights_distribution=None, bias_distribution=None, use_bias=True, precision=None, dot_general=lax.dot_general, rngs=nnx.Rngs(0), **kwargs)

This class is the Bayesian implementation of the Linear class.

This is the constructor of the Linear class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_size | int | Size of the input features. | required |
| output_size | int | Size of the output features. | required |
| weights_distribution | Optional[GaussianDistribution] | Prior distribution of the weights. | None |
| bias_distribution | Optional[GaussianDistribution] | Prior distribution of the bias. | None |
| use_bias | bool | Whether to include a bias term in the layer. | True |
| precision | PrecisionLike | Precision used in dot product operations. | None |
| dot_general | DotGeneralT | Function for computing generalized dot products. | lax.dot_general |
| rngs | Rngs | Random number generators for reproducibility. | Rngs(0) |

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/linear.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
    use_bias: bool = True,
    precision: PrecisionLike = None,
    dot_general: DotGeneralT = lax.dot_general,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """
    This is the constructor of the Linear class.

    Args:
        input_size: Size of the input features.
        output_size: Size of the output features.
        weights_distribution: Prior distribution of the weights.
        bias_distribution: Prior distribution of the bias.
        use_bias: Whether to include a bias term in the layer.
        precision: Precision used in dot product operations.
        dot_general: Function for computing generalized dot
            products.

    Returns:
        None.
    """

    # Call super class constructor
    super().__init__(**kwargs)

    # Set attributes
    self.input_size = input_size
    self.output_size = output_size
    self.use_bias = use_bias
    self.precision = precision
    self.dot_general = dot_general
    self.rngs = rngs

    # Set weights prior
    if weights_distribution is None:
        self.weights_distribution = GaussianDistribution(
            shape=(self.output_size, self.input_size), rngs=self.rngs
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias prior
    if self.use_bias:
        if bias_distribution is None:
            self.bias_distribution = GaussianDistribution(
                shape=(self.output_size,), rngs=self.rngs
            )
        else:
            self.bias_distribution = bias_distribution
    else:
        self.bias_distribution = None

    # Sample initial weights
    self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample initial bias only if using bias
    if self.use_bias and self.bias_distribution:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))
    else:
        self.bias = None
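
A minimal usage sketch that also passes an explicit weights distribution, whose shape must match the default prior, (output_size, input_size). Both import paths below are assumptions inferred from the source paths in this document, not confirmed API.

import jax.numpy as jnp
from flax import nnx

from illia.distributions.jax import GaussianDistribution  # assumed import path
from illia.nn.jax.linear import Linear                     # assumed import path

rngs = nnx.Rngs(0)

# Optional custom prior over the weights; shape (output_size, input_size).
weights_distribution = GaussianDistribution(shape=(8, 16), rngs=rngs)

layer = Linear(
    input_size=16,
    output_size=8,
    weights_distribution=weights_distribution,
    rngs=rngs,
)

x = jnp.ones((32, 16))
y = layer(x)  # (32, 8)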

3.4.1 __call__(inputs)

This method is the forward pass of the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | Array | Inputs of the model. Dimensions: [*, input size]. | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Output tensor. Dimensions: [*, output size]. |

Source code in illia/nn/jax/linear.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    This method is the forward pass of the model.

    Args:
        inputs: Inputs of the model. Dimensions: [*, input size].

    Returns:
        Output tensor. Dimension: [*, output size].
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

        # Sample bias only if using bias
        if self.bias is None and self.use_bias and self.bias_distribution:
            self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Compute outputs
    outputs = inputs @ self.weights.T

    # Add bias only if using bias
    if self.use_bias and self.bias:
        outputs += jnp.reshape(
            jnp.asarray(self.bias), (1,) * (outputs.ndim - 1) + (-1,)
        )

    return outputs

3.4.2 freeze()

Freezes the current module and all submodules that are instances of BayesianModule. Sets the frozen state to True.

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/linear.py
def freeze(self) -> None:
    """
    Freezes the current module and all submodules that are instances
    of BayesianModule. Sets the frozen state to True.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:
        self.weights = nnx.Param(self.weights_distribution.sample(self.rngs))

    # Sample bias if they are undefined and bias is used
    if self.use_bias and self.bias is None and self.bias_distribution:
        self.bias = nnx.Param(self.bias_distribution.sample(self.rngs))

    # Stop gradient computation
    self.weights = jax.lax.stop_gradient(self.weights)
    if self.use_bias and self.bias:
        self.bias = jax.lax.stop_gradient(self.bias)

3.4.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, int] | Tuple containing KL divergence cost and total number of parameters. |

Source code in illia/nn/jax/linear.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        Tuple containing KL divergence cost and total number of
        parameters.
    """

    # Compute log probs for weights
    log_probs: jax.Array = self.weights_distribution.log_prob(
        jnp.asarray(self.weights)
    )

    # Add bias log probs only if using bias
    if self.use_bias and self.bias and self.bias_distribution:
        log_probs += self.bias_distribution.log_prob(jnp.asarray(self.bias))

    # Compute number of parameters
    num_params: int = self.weights_distribution.num_params
    if self.use_bias and self.bias_distribution:
        num_params += self.bias_distribution.num_params

    return log_probs, num_params

This module contains the code for the Bayesian LSTM layer.

3.5 LSTM(num_embeddings, embeddings_dim, hidden_size, output_size, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, rngs=nnx.Rngs(0), **kwargs)

This class is the Bayesian implementation of the LSTM layer.

Definition of a Bayesian LSTM layer, built from a Bayesian Embedding layer and Gaussian-distributed gate weights.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| num_embeddings | int | Size of the dictionary of embeddings. | required |
| embeddings_dim | int | The size of each embedding vector. | required |
| hidden_size | int | Size of the LSTM hidden state. | required |
| output_size | int | Size of the output features. | required |
| padding_idx | Optional[int] | If specified, the entries at padding_idx do not contribute to the gradient. | None |
| max_norm | Optional[float] | If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. | None |
| norm_type | float | The p of the p-norm to compute for the max_norm option. | 2.0 |
| scale_grad_by_freq | bool | If given, this will scale gradients by the inverse of frequency of the words in the mini-batch. | False |
| rngs | Rngs | Random number generators for reproducibility. | Rngs(0) |

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/lstm.py
def __init__(
    self,
    num_embeddings: int,
    embeddings_dim: int,
    hidden_size: int,
    output_size: int,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    rngs: Rngs = nnx.Rngs(0),
    **kwargs: Any,
) -> None:
    """_summary_

    Args:
        num_embeddings (int): _description_
        embeddings_dim (int): _description_
        hidden_size (int): _description_
        output_size (int): _description_
        padding_idx (Optional[int], optional): _description_. Defaults to None.
        max_norm (Optional[float], optional): _description_. Defaults to None.
        norm_type (float, optional): _description_. Defaults to 2.0.
        scale_grad_by_freq (bool, optional): _description_. Defaults to False.
        rngs (Rngs, optional): _description_. Defaults to nnx.Rngs(0).

    Returns:
        None.
    """

    # Call super-class constructor
    super().__init__(**kwargs)

    # Set attributes
    self.num_embeddings = num_embeddings
    self.embeddings_dim = embeddings_dim
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.rngs = rngs

    # Define the Embedding layer
    self.embedding = Embedding(
        num_embeddings=self.num_embeddings,
        embeddings_dim=self.embeddings_dim,
        padding_idx=self.padding_idx,
        max_norm=self.max_norm,
        norm_type=self.norm_type,
        scale_grad_by_freq=self.scale_grad_by_freq,
        rngs=self.rngs,
    )

    # Initialize weights
    # Forget gate
    self.wf_distribution = GaussianDistribution(
        (self.hidden_size, self.embeddings_dim + self.hidden_size)
    )
    self.bf_distribution = GaussianDistribution((self.hidden_size,))

    # Input gate
    self.wi_distribution = GaussianDistribution(
        (self.hidden_size, self.embeddings_dim + self.hidden_size)
    )
    self.bi_distribution = GaussianDistribution((self.hidden_size,))

    # Candidate gate
    self.wc_distribution = GaussianDistribution(
        (self.hidden_size, self.embeddings_dim + self.hidden_size)
    )
    self.bc_distribution = GaussianDistribution((self.hidden_size,))

    # Output gate
    self.wo_distribution = GaussianDistribution(
        (self.hidden_size, self.embeddings_dim + self.hidden_size)
    )
    self.bo_distribution = GaussianDistribution((self.hidden_size,))

    # Final gate
    self.wv_distribution = GaussianDistribution(
        (self.output_size, self.hidden_size)
    )
    self.bv_distribution = GaussianDistribution((self.output_size,))

    # Sample initial weights and register buffers
    # Forget gate
    self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
    self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))

    # Input gate
    self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
    self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))

    # Candidate gate
    self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
    self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))

    # Output gate
    self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
    self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))

    # Final output layer
    self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
    self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))

3.5.1 __call__(inputs, init_states=None)

Performs a forward pass through the Bayesian LSTM layer. If the layer is not frozen, it samples weights and bias from their respective distributions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| inputs | Array | Input tensor with token indices. Shape: [batch, seq_len, 1]. | required |
| init_states | Optional[tuple[Array, Array]] | Optional initial hidden and cell states. | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, tuple[Array, Array]] | Tuple of (output, (hidden_state, cell_state)). |

Source code in illia/nn/jax/lstm.py
def __call__(
    self,
    inputs: jax.Array,
    init_states: Optional[tuple[jax.Array, jax.Array]] = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
    """
    Performs a forward pass through the Bayesian LSTM layer.
    If the layer is not frozen, it samples weights and bias
    from their respective distributions.

    Args:
        inputs: Input tensor with token indices. Shape: [batch, seq_len, 1]
        init_states: Optional initial hidden and cell states

    Returns:
        Tuple of (output, (hidden_state, cell_state))
    """

    # Sample weights if not frozen
    if not self.frozen:
        self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
        self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))
        self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
        self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))
        self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
        self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))
        self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
        self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))
        self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
        self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))
    else:
        if any(w is None for w in [self.wf, self.wi, self.wc, self.wo, self.wv]):
            self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
            self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))
            self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
            self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))
            self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
            self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))
            self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
            self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))
            self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
            self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))

    # Apply embedding layer to input indices
    inputs = jnp.squeeze(inputs, axis=-1)
    inputs = self.embedding(inputs)
    batch_size = jnp.shape(inputs)[0]
    seq_len = jnp.shape(inputs)[1]

    # Initialize h_t and c_t if init_states is None
    if init_states is None:
        h_t = jnp.zeros([batch_size, self.hidden_size])
        c_t = jnp.zeros([batch_size, self.hidden_size])
    else:
        h_t, c_t = init_states[0], init_states[1]

    # Process sequence
    for t in range(seq_len):
        # Shape: (batch_size, embedding_dim)
        x_t = inputs[:, t, :]

        # Concatenate input and hidden state
        # Shape: (batch_size, embedding_dim + hidden_size)
        z_t = jnp.concat([x_t, h_t], axis=1)

        # Forget gate
        ft = nnx.sigmoid(z_t @ self.wf.T + self.bf)

        # Input gate
        it = nnx.sigmoid(z_t @ self.wi.T + self.bi)

        # Candidate cell state
        can = nnx.tanh(z_t @ self.wc.T + self.bc)

        # Output gate
        ot = nnx.sigmoid(z_t @ self.wo.T + self.bo)

        # Update cell state
        c_t = c_t * ft + can * it

        # Update hidden state
        h_t = ot * nnx.tanh(c_t)

    # Compute final output
    y_t = h_t @ self.wv.T + self.bv

    return y_t, (h_t, c_t)
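
A minimal forward-pass sketch (the import path illia.nn.jax.lstm is assumed from the source path shown above). Note that __call__ expects token indices with a trailing singleton dimension, which it squeezes before the embedding lookup.

import jax.numpy as jnp
from flax import nnx

from illia.nn.jax.lstm import LSTM  # assumed import path

# Bayesian LSTM over a batch of 4 sequences of length 10.
layer = LSTM(
    num_embeddings=100,
    embeddings_dim=16,
    hidden_size=32,
    output_size=2,
    rngs=nnx.Rngs(0),
)

tokens = jnp.ones((4, 10, 1), dtype=jnp.int32)  # [batch, seq_len, 1]
outputs, (h_t, c_t) = layer(tokens)
# outputs: (4, 2); h_t and c_t: (4, 32)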

3.5.2 freeze()

Freezes the current module and all submodules that are instances of BayesianModule. Sets the frozen state to True.

Returns:

| Type | Description |
| --- | --- |
| None | None. |

Source code in illia/nn/jax/lstm.py
def freeze(self) -> None:
    """
    Freezes the current module and all submodules that are instances
    of BayesianModule. Sets the frozen state to True.

    Returns:
        None.
    """

    # Set indicator
    self.frozen = True

    # Freeze embedding layer
    self.embedding.freeze()

    # Forget gate
    if self.wf is None:
        self.wf = nnx.Param(self.wf_distribution.sample(self.rngs))
    if self.bf is None:
        self.bf = nnx.Param(self.bf_distribution.sample(self.rngs))
    self.wf = jax.lax.stop_gradient(self.wf)
    self.bf = jax.lax.stop_gradient(self.bf)

    # Input gate
    if self.wi is None:
        self.wi = nnx.Param(self.wi_distribution.sample(self.rngs))
    if self.bi is None:
        self.bi = nnx.Param(self.bi_distribution.sample(self.rngs))
    self.wi = jax.lax.stop_gradient(self.wi)
    self.bi = jax.lax.stop_gradient(self.bi)

    # Candidate gate
    if self.wc is None:
        self.wc = nnx.Param(self.wc_distribution.sample(self.rngs))
    if self.bc is None:
        self.bc = nnx.Param(self.bc_distribution.sample(self.rngs))
    self.wc = jax.lax.stop_gradient(self.wc)
    self.bc = jax.lax.stop_gradient(self.bc)

    # Output gate
    if self.wo is None:
        self.wo = nnx.Param(self.wo_distribution.sample(self.rngs))
    if self.bo is None:
        self.bo = nnx.Param(self.bo_distribution.sample(self.rngs))
    self.wo = jax.lax.stop_gradient(self.wo)
    self.bo = jax.lax.stop_gradient(self.bo)

    # Final output layer
    if self.wv is None:
        self.wv = nnx.Param(self.wv_distribution.sample(self.rngs))
    if self.bv is None:
        self.bv = nnx.Param(self.bv_distribution.sample(self.rngs))
    self.wv = jax.lax.stop_gradient(self.wv)
    self.bv = jax.lax.stop_gradient(self.bv)

3.5.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, int] | Tuple containing KL divergence cost and total number of parameters. |

Source code in illia/nn/jax/lstm.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        tuple containing KL divergence cost and total number of
        parameters.
    """

    # Compute log probs for each pair of weights and bias
    log_probs_f = self.wf_distribution.log_prob(
        jnp.asarray(self.wf)
    ) + self.bf_distribution.log_prob(jnp.asarray(self.bf))
    log_probs_i = self.wi_distribution.log_prob(
        jnp.asarray(self.wi)
    ) + self.bi_distribution.log_prob(jnp.asarray(self.bi))
    log_probs_c = self.wc_distribution.log_prob(
        jnp.asarray(self.wc)
    ) + self.bc_distribution.log_prob(jnp.asarray(self.bc))
    log_probs_o = self.wo_distribution.log_prob(
        jnp.asarray(self.wo)
    ) + self.bo_distribution.log_prob(jnp.asarray(self.bo))
    log_probs_v = self.wv_distribution.log_prob(
        jnp.asarray(self.wv)
    ) + self.bv_distribution.log_prob(jnp.asarray(self.bv))

    # Compute the total loss
    log_probs = log_probs_f + log_probs_i + log_probs_c + log_probs_o + log_probs_v

    # Compute number of parameters
    num_params = (
        self.wf_distribution.num_params
        + self.bf_distribution.num_params
        + self.wi_distribution.num_params
        + self.bi_distribution.num_params
        + self.wc_distribution.num_params
        + self.bc_distribution.num_params
        + self.wo_distribution.num_params
        + self.bo_distribution.num_params
        + self.wv_distribution.num_params
        + self.bv_distribution.num_params
    )

    return log_probs, num_params