
5. Neural Network

5.1 conv1d

This module contains the code for the bayesian Conv1D.

5.1.1 Conv1D(input_channels, output_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, weights_distribution=None, bias_distribution=None)

This class is the Bayesian implementation of the Conv1D class.

Definition of a Bayesian Convolution 1D layer.

Parameters:

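| Name | Type | Description | Default |
|---|---|---|---|
| input_channels | int | Number of channels in the input. | required |
| output_channels | int | Number of channels produced by the convolution. | required |
| kernel_size | int | Size of the convolving kernel. Only used to build the default weights distribution. | required |
| stride | int | Stride of the convolution. Defaults to 1. | 1 |
| padding | int | Zero-padding added to both sides of the input. Defaults to 0. | 0 |
| dilation | int | Spacing between kernel elements. Defaults to 1. | 1 |
| groups | int | Number of blocked connections from input channels to output channels. Defaults to 1. | 1 |
| weights_distribution | Optional[GaussianDistribution] | The distribution for the weights. If None, a GaussianDistribution of shape (output_channels, input_channels // groups, kernel_size) is created. | None |
| bias_distribution | Optional[GaussianDistribution] | The distribution for the bias. If None, a GaussianDistribution of shape (output_channels,) is created. | None |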
Source code in illia/nn/torch/conv1d.py
def __init__(
    self,
    input_channels: int,
    output_channels: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
    groups: int = 1,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
) -> None:
    """
    Definition of a Bayesian Convolution 1D layer.

    Args:
        input_channels: Number of channels in the input.
        output_channels: Number of channels produced by the
            convolution.
        kernel_size: Size of the convolving kernel.
        stride: Stride of the convolution. Defaults to 1.
        padding: Zero-padding added to both sides of the input.
            Defaults to 0.
        dilation: Spacing between kernel elements. Defaults to 1.
        groups: Number of blocked connections from input channels
            to output channels. Defaults to 1.
        weights_distribution: The distribution for the weights.
        bias_distribution: The distribution for the bias.
    """

    # Call super class constructor
    super().__init__()

    # Set attributes
    self.conv_params: tuple[int, ...] = (stride, padding, dilation, groups)

    # Set weights distribution
    if weights_distribution is None:
        # Define weights distribution
        self.weights_distribution: GaussianDistribution = GaussianDistribution(
            (output_channels, input_channels // groups, kernel_size)
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias distribution
    if bias_distribution is None:
        # Define bias distribution
        self.bias_distribution: GaussianDistribution = GaussianDistribution(
            (output_channels,)
        )
    else:
        self.bias_distribution = bias_distribution

    # Sample initial weights
    weights = self.weights_distribution.sample()
    bias = self.bias_distribution.sample()

    # Register buffers
    self.register_buffer("weights", weights)
    self.register_buffer("bias", bias)
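The default distributions built by the constructor fix the shapes of the sampled `weights` and `bias`. The minimal sketch below mirrors that shape logic in plain Python; `conv1d_param_shapes` is a hypothetical helper, not part of illia, and the element count it returns is simply the number of sampled scalars, not necessarily what `num_params()` reports. It also adds a divisibility check that the constructor does not perform: because the constructor uses `input_channels // groups`, a non-divisible channel count would be silently floored there and would likely only fail later inside `F.conv1d`.

```python
from math import prod


def conv1d_param_shapes(
    input_channels: int,
    output_channels: int,
    kernel_size: int,
    groups: int = 1,
) -> tuple[tuple[int, ...], tuple[int, ...], int]:
    """Hypothetical helper: shapes of the default weight and bias
    distributions of the Bayesian Conv1D, plus the number of scalars
    drawn per sample (weights + bias)."""
    # Extra check, not present in the original constructor: grouped
    # convolutions need both channel counts divisible by `groups`.
    if input_channels % groups or output_channels % groups:
        raise ValueError("channel counts must be divisible by groups")

    # Same shapes the constructor passes to GaussianDistribution
    weights_shape = (output_channels, input_channels // groups, kernel_size)
    bias_shape = (output_channels,)

    # Number of scalar values drawn by one sample() of each distribution
    num_sampled = prod(weights_shape) + prod(bias_shape)
    return weights_shape, bias_shape, num_sampled


w, b, n = conv1d_param_shapes(input_channels=4, output_channels=8, kernel_size=3, groups=2)
print(w, b, n)  # (8, 2, 3) (8,) 56
```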

5.1.1.1 forward(inputs)

Performs a forward pass through the Bayesian Convolution 1D layer. If the layer is not frozen, it samples new weights and bias from their respective distributions. If the layer is frozen and the weights or bias are undefined, it raises a ValueError instead of sampling.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | Tensor | Input tensor to the layer. Dimensions: [batch, input channels, input length]. | required |

Raises:

| Type | Description |
|---|---|
| ValueError | Module has been frozen with undefined weights. |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor after passing through the layer. Dimensions: [batch, output channels, output length]. |

Source code in illia/nn/torch/conv1d.py
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """
    Performs a forward pass through the Bayesian Convolution 1D
    layer. If the layer is not frozen, it samples new weights and
    bias from their respective distributions. If the layer is
    frozen and the weights or bias are undefined, it raises a
    ValueError.

    Args:
        inputs: Input tensor to the layer. Dimensions: [batch,
            input channels, input length].

    Raises:
        ValueError: Module has been frozen with undefined weights.

    Returns:
        Output tensor after passing through the layer. Dimensions:
            [batch, output channels, output length].
    """

    # Forward depending on frozen state
    if not self.frozen:
        self.weights = self.weights_distribution.sample()
        self.bias = self.bias_distribution.sample()
    elif self.weights is None or self.bias is None:
        raise ValueError("Module has been frozen with undefined weights")

    # Execute torch forward
    return F.conv1d(
        inputs, self.weights, self.bias, *self.conv_params  # type: ignore
    )
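Since the layer delegates the convolution to `F.conv1d`, the output length follows PyTorch's documented formula, floor((L_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1). A small sketch of that formula; the function name is illustrative and not part of illia.

```python
def conv1d_output_length(
    length_in: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    dilation: int = 1,
) -> int:
    """Length of the last output dimension of F.conv1d for an input
    of shape [batch, input channels, length_in]."""
    # Effective kernel extent once dilation inserts gaps between taps
    span = dilation * (kernel_size - 1) + 1
    # Integer division implements the floor in the formula
    return (length_in + 2 * padding - span) // stride + 1


print(conv1d_output_length(10, kernel_size=3))             # 8
print(conv1d_output_length(10, kernel_size=3, padding=1))  # 10
print(conv1d_output_length(10, kernel_size=3, stride=2))   # 4
```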

5.1.1.2 freeze()

Freezes the layer: sets the frozen state to True, samples the weights and bias if they are undefined, and detaches both from the computation graph, so later forward passes reuse the same fixed values.

Source code in illia/nn/torch/conv1d.py
@torch.jit.export
def freeze(self) -> None:
    """
    Freezes the layer: sets the frozen state to True, samples the
    weights and bias if they are undefined, and detaches them so
    subsequent forward passes reuse the same values.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:  # type: ignore
        self.weights = self.weights_distribution.sample()

    # Sample bias if it is undefined
    if self.bias is None:  # type: ignore
        self.bias = self.bias_distribution.sample()

    # Detach weights and bias
    self.weights = self.weights.detach()
    self.bias = self.bias.detach()
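To illustrate how `forward` and `freeze` interact, here is a toy, framework-free sketch: a single Gaussian "weight" that is resampled on every call until the layer is frozen, after which the stored sample is reused. `ToyBayesianWeight` is a hypothetical class that only mimics the control flow above; it uses `random.gauss` in place of `GaussianDistribution.sample()`, and detaching has no counterpart here.

```python
import random
from typing import Optional


class ToyBayesianWeight:
    """Hypothetical stand-in for a Bayesian layer with a scalar weight."""

    def __init__(self, mu: float, sigma: float, seed: int = 0) -> None:
        self.mu = mu
        self.sigma = sigma
        self.rng = random.Random(seed)
        self.frozen = False
        # Like the constructor, draw an initial sample
        self.weight: Optional[float] = self._sample()

    def _sample(self) -> float:
        # Plays the role of weights_distribution.sample()
        return self.rng.gauss(self.mu, self.sigma)

    def freeze(self) -> None:
        # Same steps as freeze(): set the flag, sample only if undefined
        self.frozen = True
        if self.weight is None:
            self.weight = self._sample()

    def forward(self, x: float) -> float:
        if not self.frozen:
            # Unfrozen: a fresh sample on every forward pass
            self.weight = self._sample()
        elif self.weight is None:
            raise ValueError("Module has been frozen with undefined weights")
        return self.weight * x


layer = ToyBayesianWeight(mu=1.0, sigma=0.5)
print(layer.forward(2.0) == layer.forward(2.0))  # False: resampled each call
layer.freeze()
print(layer.forward(2.0) == layer.forward(2.0))  # True: frozen sample reused
```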

5.1.1.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
|---|---|
| Tensor | KL divergence cost. |
| int | Total number of parameters. |

Source code in illia/nn/torch/conv1d.py
@torch.jit.export
def kl_cost(self) -> tuple[torch.Tensor, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        Tuple containing KL divergence cost and total number of
        parameters.
    """

    # Compute log probs
    log_probs: torch.Tensor = self.weights_distribution.log_prob(
        self.weights
    ) + self.bias_distribution.log_prob(self.bias)

    # Compute number of parameters
    num_params: int = (
        self.weights_distribution.num_params() + self.bias_distribution.num_params()
    )

    return log_probs, num_params
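`kl_cost` returns a pair rather than a single number, which leaves the caller to decide how to combine several layers. One plausible way, shown purely as an assumption and not as illia's own training loop, is to sum the costs and the parameter counts over all layers and normalise by the total count. The hypothetical `combine_kl_costs` below takes the `(cost, num_params)` pairs as plain floats and ints:

```python
from typing import Iterable


def combine_kl_costs(pairs: Iterable[tuple[float, int]]) -> float:
    """Hypothetical aggregation of (kl_cost, num_params) pairs from
    several Bayesian layers into one per-parameter cost."""
    total_cost = 0.0
    total_params = 0
    for cost, num_params in pairs:
        total_cost += cost
        total_params += num_params
    if total_params == 0:
        raise ValueError("no parameters to normalise by")
    # Average cost per parameter across all layers
    return total_cost / total_params


print(combine_kl_costs([(2.0, 4), (6.0, 4)]))  # 1.0
```

With `torch.Tensor` costs the same loop works unchanged, and the result stays a tensor that can be backpropagated.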

5.2 conv2d

This module contains the code for the Bayesian Conv2D.

5.2.1 Conv2D(input_channels, output_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, weights_distribution=None, bias_distribution=None)

This class is the Bayesian implementation of the Conv2D class.

Definition of a Bayesian Convolution 2D layer.

Parameters:

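| Name | Type | Description | Default |
|---|---|---|---|
| input_channels | int | Number of channels in the input image. | required |
| output_channels | int | Number of channels produced by the convolution. | required |
| kernel_size | Union[int, tuple[int, int]] | Size of the convolving kernel. An int is used for both height and width. Only used to build the default weights distribution. | required |
| stride | Union[int, tuple[int, int]] | Stride of the convolution. Defaults to 1. | 1 |
| padding | Union[int, tuple[int, int]] | Zero-padding added to all four sides of the input. Defaults to 0. | 0 |
| dilation | Union[int, tuple[int, int]] | Spacing between kernel elements. Defaults to 1. | 1 |
| groups | int | Number of blocked connections from input channels to output channels. Defaults to 1. | 1 |
| weights_distribution | Optional[GaussianDistribution] | The distribution for the weights. If None, a GaussianDistribution of shape (output_channels, input_channels // groups, kernel height, kernel width) is created. | None |
| bias_distribution | Optional[GaussianDistribution] | The distribution for the bias. If None, a GaussianDistribution of shape (output_channels,) is created. | None |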
Source code in illia/nn/torch/conv2d.py
def __init__(
    self,
    input_channels: int,
    output_channels: int,
    kernel_size: Union[int, tuple[int, int]],
    stride: Union[int, tuple[int, int]] = 1,
    padding: Union[int, tuple[int, int]] = 0,
    dilation: Union[int, tuple[int, int]] = 1,
    groups: int = 1,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
) -> None:
    """
    Definition of a Bayesian Convolution 2D layer.

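    Args:
        input_channels: Number of channels in the input image.
        output_channels: Number of channels produced by the
            convolution.
        kernel_size: Size of the convolving kernel.
        stride: Stride of the convolution. Defaults to 1.
        padding: Padding added to all four sides of the input.
            Defaults to 0.
        dilation: Spacing between kernel elements. Defaults to 1.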
        groups: Number of blocked connections from input channels
            to output channels. Defaults to 1.
        weights_distribution: The distribution for the weights.
        bias_distribution: The distribution for the bias.
    """

    # Call super class constructor
    super().__init__()

    # Set attributes
    self.conv_params: tuple[Union[int, tuple[int, int]], ...] = (
        stride,
        padding,
        dilation,
        groups,
    )

    # Set weights distribution
    if weights_distribution is None:
        # Extend kernel if we only have 1 value
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        # Define weights distribution
        self.weights_distribution: GaussianDistribution = GaussianDistribution(
            (output_channels, input_channels // groups, *kernel_size)
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias distribution
    if bias_distribution is None:
        # Define bias distribution
        self.bias_distribution: GaussianDistribution = GaussianDistribution(
            (output_channels,)
        )
    else:
        self.bias_distribution = bias_distribution

    # Sample initial weights
    weights = self.weights_distribution.sample()
    bias = self.bias_distribution.sample()

    # Register buffers
    self.register_buffer("weights", weights)
    self.register_buffer("bias", bias)
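When `weights_distribution` is not given, the constructor expands an integer `kernel_size` into a `(height, width)` pair before building the weight shape, while `stride`, `padding` and `dilation` are passed to `F.conv2d` unchanged (it accepts either an int or a pair). A small sketch of that shape logic; `as_pair` and `conv2d_weights_shape` are hypothetical helpers, not part of illia.

```python
from typing import Union

IntOrPair = Union[int, tuple[int, int]]


def as_pair(value: IntOrPair) -> tuple[int, int]:
    # Same rule as the constructor: a single int is used for both dims
    return (value, value) if isinstance(value, int) else (value[0], value[1])


def conv2d_weights_shape(
    input_channels: int,
    output_channels: int,
    kernel_size: IntOrPair,
    groups: int = 1,
) -> tuple[int, int, int, int]:
    kernel_h, kernel_w = as_pair(kernel_size)
    # Mirrors (output_channels, input_channels // groups, *kernel_size)
    return (output_channels, input_channels // groups, kernel_h, kernel_w)


print(conv2d_weights_shape(3, 16, 3))                # (16, 3, 3, 3)
print(conv2d_weights_shape(4, 8, (3, 5), groups=2))  # (8, 2, 3, 5)
```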

5.2.1.1 forward(inputs)

Performs a forward pass through the Bayesian Convolution 2D layer. If the layer is not frozen, it samples new weights and bias from their respective distributions. If the layer is frozen and the weights or bias are undefined, it samples them as well (unlike the Conv1D layer, which raises a ValueError in that case).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | Tensor | Input tensor to the layer. Dimensions: [batch, input channels, input height, input width]. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor after passing through the layer. Dimensions: [batch, output channels, output height, output width]. |

Source code in illia/nn/torch/conv2d.py
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """
    Performs a forward pass through the Bayesian Convolution 2D
    layer. If the layer is not frozen, it samples weights and bias
    from their respective distributions. If the layer is frozen
    and the weights or bias are not initialized, it also performs
    sampling.

    Args:
        inputs: Input tensor to the layer. Dimensions: [batch,
            input channels, input height, input width].

    Returns:
        Output tensor after passing through the layer. Dimensions:
            [batch, output channels, output height, output width].
    """

    # Forward depending on frozen state
    if not self.frozen:
        self.weights = self.weights_distribution.sample()
        self.bias = self.bias_distribution.sample()
    else:
        if self.weights is None or self.bias is None:
            self.weights = self.weights_distribution.sample()
            self.bias = self.bias_distribution.sample()

    # Execute torch forward
    return F.conv2d(  # pylint: disable=E1102
        inputs, self.weights, self.bias, *self.conv_params  # type: ignore
    )
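Because the layer delegates to `F.conv2d`, each spatial output size follows PyTorch's per-dimension formula, applied independently to height and width. A sketch, assuming each of `kernel_size`, `stride`, `padding` and `dilation` is either an int or a `(height, width)` pair as in the constructor signature; the helper names are hypothetical.

```python
from typing import Union

IntOrPair = Union[int, tuple[int, int]]


def _pair(value: IntOrPair) -> tuple[int, int]:
    # An int applies to both spatial dimensions
    return (value, value) if isinstance(value, int) else (value[0], value[1])


def conv2d_output_hw(
    height: int,
    width: int,
    kernel_size: IntOrPair,
    stride: IntOrPair = 1,
    padding: IntOrPair = 0,
    dilation: IntOrPair = 1,
) -> tuple[int, int]:
    """Output (height, width) of F.conv2d for an input of shape
    [batch, input channels, height, width]."""
    sizes = (height, width)
    k, s, p, d = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation)
    out = []
    for i in range(2):
        span = d[i] * (k[i] - 1) + 1  # dilated kernel extent
        out.append((sizes[i] + 2 * p[i] - span) // s[i] + 1)
    return out[0], out[1]


print(conv2d_output_hw(32, 32, kernel_size=3, padding=1))           # (32, 32)
print(conv2d_output_hw(28, 20, kernel_size=(5, 3), stride=(2, 1)))  # (12, 18)
```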

5.2.1.2 freeze()

Freezes the layer: sets the frozen state to True, samples the weights and bias if they are undefined, and detaches both from the computation graph, so later forward passes reuse the same fixed values.

Source code in illia/nn/torch/conv2d.py
@torch.jit.export
def freeze(self) -> None:
    """
    Freezes the layer: sets the frozen state to True, samples the
    weights and bias if they are undefined, and detaches them so
    subsequent forward passes reuse the same values.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:  # type: ignore
        self.weights = self.weights_distribution.sample()

    # Sample bias if it is undefined
    if self.bias is None:  # type: ignore
        self.bias = self.bias_distribution.sample()

    # Detach weights and bias
    self.weights = self.weights.detach()
    self.bias = self.bias.detach()

5.2.1.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
|---|---|
| Tensor | KL divergence cost. |
| int | Total number of parameters. |

Source code in illia/nn/torch/conv2d.py
@torch.jit.export
def kl_cost(self) -> tuple[torch.Tensor, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        Tuple containing KL divergence cost and total number of
        parameters.
    """

    # Compute log probs
    log_probs: torch.Tensor = self.weights_distribution.log_prob(
        self.weights
    ) + self.bias_distribution.log_prob(self.bias)

    # Compute number of parameters
    num_params: int = (
        self.weights_distribution.num_params() + self.bias_distribution.num_params()
    )

    return log_probs, num_params

5.3 embedding

This module contains the code for the Bayesian Embedding layer.

5.3.1 Embedding(num_embeddings, embeddings_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, weights_distribution=None)

This class is the Bayesian implementation of the Embedding class.

This method is the constructor of the Embedding class.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_embeddings | int | Size of the dictionary of embeddings. | required |
| embeddings_dim | int | Size of each embedding vector. | required |
| padding_idx | Optional[int] | If specified, the entries at padding_idx do not contribute to the gradient. | None |
| max_norm | Optional[float] | If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. | None |
| norm_type | float | The p of the p-norm to compute for the max_norm option. | 2.0 |
| scale_grad_by_freq | bool | If True, scales gradients by the inverse of the frequency of the words in the mini-batch. | False |
| sparse | bool | If True, the gradient w.r.t. the weight matrix will be a sparse tensor. | False |
| weights_distribution | Optional[GaussianDistribution] | Distribution for the weights of the layer. If None, a GaussianDistribution of shape (num_embeddings, embeddings_dim) is created. | None |
Source code in illia/nn/torch/embedding.py
def __init__(
    self,
    num_embeddings: int,
    embeddings_dim: int,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
    weights_distribution: Optional[GaussianDistribution] = None,
) -> None:
    """
    This method is the constructor of the Embedding class.

    Args:
        num_embeddings: Size of the dictionary of embeddings.
        embeddings_dim: Size of each embedding vector.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
        norm_type: The p of the p-norm to compute for the max_norm
            option.
        scale_grad_by_freq: If True, gradients are scaled by the
            inverse of the frequency of the words in the
            mini-batch.
        sparse: If True, the gradient w.r.t. the weight matrix will
            be a sparse tensor.
        weights_distribution: Distribution for the weights of the
            layer.
    """

    # Call super class constructor
    super().__init__()

    # Set embeddings attributes
    self.embedding_params: tuple[Any, ...] = (
        padding_idx,
        max_norm,
        norm_type,
        scale_grad_by_freq,
        sparse,
    )

    # Set weights distribution
    if weights_distribution is None:
        self.weights_distribution = GaussianDistribution(
            (num_embeddings, embeddings_dim)
        )
    else:
        self.weights_distribution = weights_distribution

    # Sample initial weights
    weights = self.weights_distribution.sample()

    # Register buffers
    self.register_buffer("weights", weights)
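The `max_norm` and `norm_type` options are stored in `embedding_params` and forwarded to `F.embedding`, which rescales any looked-up row whose p-norm exceeds `max_norm`. The pure-Python sketch below illustrates that rule on a single vector; `renorm_row` is hypothetical and is not PyTorch's implementation, which differs in details such as modifying the weight rows in place.

```python
from typing import Optional


def renorm_row(
    row: list[float], max_norm: Optional[float], norm_type: float = 2.0
) -> list[float]:
    """Hypothetical illustration of the max_norm rule for one embedding row."""
    if max_norm is None:
        return list(row)
    # p-norm of the row, with p = norm_type
    norm = sum(abs(x) ** norm_type for x in row) ** (1.0 / norm_type)
    if norm <= max_norm:
        return list(row)
    # Rescale so that the row's p-norm becomes max_norm
    return [x * max_norm / norm for x in row]


print(renorm_row([3.0, 4.0], max_norm=1.0))  # [0.6, 0.8]
print(renorm_row([0.3, 0.4], max_norm=1.0))  # [0.3, 0.4], already within the bound
```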

5.3.1.1 forward(inputs)

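Performs a forward pass through the Bayesian Embedding layer. If the layer is not frozen, it samples new weights from the weights distribution. If the layer is frozen and the weights are undefined, it raises a ValueError.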

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | Tensor | Input tensor of integer indices. Dimensions: [*]. | required |

Raises:

| Type | Description |
|---|---|
| ValueError | Module has been frozen with undefined weights. |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor. Dimensions: [*, embedding dim]. |

Source code in illia/nn/torch/embedding.py
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """
    Performs a forward pass through the Bayesian Embedding layer.
    If the layer is not frozen, it samples new weights from the
    weights distribution.

    Args:
        inputs: Input tensor of integer indices. Dimensions: [*].

    Raises:
        ValueError: Module has been frozen with undefined weights.

    Returns:
        Output tensor. Dimensions: [*, embedding dim].
    """

    # Forward depending on frozen state
    if not self.frozen:
        self.weights = self.weights_distribution.sample()
    elif self.weights is None:
        raise ValueError("Module has been frozen with undefined weights")

    # Run torch forward
    outputs: torch.Tensor = F.embedding(
        inputs, self.weights, *self.embedding_params  # type: ignore
    )

    return outputs
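`F.embedding` is a row lookup: every integer in `inputs` is replaced by the corresponding row of the sampled `weights`, so an input of shape `[*]` becomes `[*, embedding dim]`. A framework-free sketch on nested lists; `embedding_lookup` is a hypothetical helper that ignores the `max_norm`, `padding_idx` and gradient options.

```python
from typing import Any


def embedding_lookup(indices: Any, weights: list[list[float]]) -> Any:
    """Replace every integer index in a (possibly nested) list by the
    matching row of `weights`, mimicking F.embedding's shape rule."""
    if isinstance(indices, int):
        if not 0 <= indices < len(weights):
            raise IndexError(f"index {indices} out of range")
        return list(weights[indices])  # one row of length embedding dim
    # Recurse so the leading [*] dimensions are preserved
    return [embedding_lookup(item, weights) for item in indices]


weights = [[0.0, 0.0], [1.0, 1.5], [2.0, 2.5]]  # num_embeddings=3, dim=2
print(embedding_lookup([[1, 2], [0, 1]], weights))
# [[[1.0, 1.5], [2.0, 2.5]], [[0.0, 0.0], [1.0, 1.5]]]
```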

5.3.1.2 freeze()

Freezes the layer: sets the frozen state to True, samples the weights if they are undefined, and detaches them from the computation graph, so later forward passes reuse the same fixed values.

Source code in illia/nn/torch/embedding.py
@torch.jit.export
def freeze(self) -> None:
    """
    Freezes the layer: sets the frozen state to True, samples the
    weights if they are undefined, and detaches them.
    """

    # set indicator
    self.frozen = True

    # sample weights if they are undefined
    if self.weights is None:  # type: ignore
        self.weights = self.weights_distribution.sample()

    # detach weights
    self.weights = self.weights.detach()

5.3.1.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights.

Returns:

| Type | Description |
|---|---|
| Tensor | KL divergence cost. |
| int | Total number of parameters. |

Source code in illia/nn/torch/embedding.py
@torch.jit.export
def kl_cost(self) -> tuple[torch.Tensor, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights.

    Returns:
        Tuple containing KL divergence cost and total number of
        parameters.
    """

    # get log posterior and log prior
    log_probs: torch.Tensor = self.weights_distribution.log_prob(self.weights)

    # get number of parameters
    num_params: int = self.weights_distribution.num_params()

    return log_probs, num_params

5.4 linear

This module contains the code for the Bayesian Linear layer.

5.4.1 Linear(input_size, output_size, weights_distribution=None, bias_distribution=None)

This class is the Bayesian implementation of the torch Linear layer.

This is the constructor of the Linear class.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_size | int | Input size of the linear layer. | required |
| output_size | int | Output size of the linear layer. | required |
| weights_distribution | Optional[GaussianDistribution] | GaussianDistribution for the weights of the layer. If None, one of shape (output_size, input_size) is created. | None |
| bias_distribution | Optional[GaussianDistribution] | GaussianDistribution for the bias of the layer. If None, one of shape (output_size,) is created. | None |
Source code in illia/nn/torch/linear.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
) -> None:
    """
    This is the constructor of the Linear class.

    Args:
        input_size: Input size of the linear layer.
        output_size: Output size of the linear layer.
        weights_distribution: GaussianDistribution for the weights of the
            layer. Defaults to None.
        bias_distribution: GaussianDistribution for the bias of the layer.
            Defaults to None.
    """

    # Call super-class constructor
    super().__init__()

    # Set weights distribution
    if weights_distribution is None:
        self.weights_distribution: GaussianDistribution = GaussianDistribution(
            (output_size, input_size)
        )
    else:
        self.weights_distribution = weights_distribution

    # Set bias distribution
    if bias_distribution is None:
        self.bias_distribution: GaussianDistribution = GaussianDistribution(
            (output_size,)
        )
    else:
        self.bias_distribution = bias_distribution

    # Sample initial weights
    weights = self.weights_distribution.sample()
    bias = self.bias_distribution.sample()

    # Register buffers
    self.register_buffer("weights", weights)
    self.register_buffer("bias", bias)

5.4.1.1 forward(inputs)

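Performs a forward pass through the Bayesian Linear layer. If the layer is not frozen, it samples new weights and bias from their respective distributions. If the layer is frozen and the weights or bias are undefined, it raises a ValueError.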

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | Tensor | Input tensor. Dimensions: [*, input size]. | required |

Raises:

| Type | Description |
|---|---|
| ValueError | Module has been frozen with undefined weights. |

Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor. Dimensions: [*, output size]. |

Source code in illia/nn/torch/linear.py
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """
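    Performs a forward pass through the Bayesian Linear layer.
    If the layer is not frozen, it samples new weights and bias
    from their respective distributions.

    Args:
        inputs: Input tensor. Dimensions: [*, input size].

    Raises:
        ValueError: Module has been frozen with undefined weights.

    Returns:
        Output tensor. Dimensions: [*, output size].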
    """

    # Check if layer is frozen
    if not self.frozen:
        self.weights = self.weights_distribution.sample()
        self.bias = self.bias_distribution.sample()
    elif self.weights is None or self.bias is None:
        raise ValueError("Module has been frozen with undefined weights")

    # compute outputs
    outputs: torch.Tensor = F.linear(inputs, self.weights, self.bias)

    return outputs
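`F.linear` computes y = x W^T + b, which is why the default weight distribution has shape `(output_size, input_size)`: each output feature is the dot product of the input with one row of `weights`, plus its bias. A pure-Python sketch for a single input vector; `linear_forward` is a hypothetical helper, not part of illia.

```python
def linear_forward(
    x: list[float], weights: list[list[float]], bias: list[float]
) -> list[float]:
    """y = x W^T + b for one input vector; weights has shape
    (output_size, input_size) and bias has shape (output_size,)."""
    if any(len(row) != len(x) for row in weights):
        raise ValueError("each weight row must have input_size entries")
    if len(bias) != len(weights):
        raise ValueError("bias must have output_size entries")
    # One dot product per output feature, then add its bias term
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


weights = [[1.0, 0.0, 2.0], [0.5, 1.0, -1.0]]  # output_size=2, input_size=3
bias = [0.5, -0.25]
print(linear_forward([1.0, 2.0, 3.0], weights, bias))  # [7.5, -0.75]
```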

5.4.1.2 freeze()

Freezes the layer: sets the frozen state to True, samples the weights and bias if they are undefined, and detaches both from the computation graph, so later forward passes reuse the same fixed values.

Source code in illia/nn/torch/linear.py
@torch.jit.export
def freeze(self) -> None:
    """
    Freezes the layer: sets the frozen state to True, samples the
    weights and bias if they are undefined, and detaches them so
    subsequent forward passes reuse the same values.
    """

    # Set indicator
    self.frozen = True

    # Sample weights if they are undefined
    if self.weights is None:  # type: ignore
        self.weights = self.weights_distribution.sample()

    # Sample bias if it is undefined
    if self.bias is None:  # type: ignore
        self.bias = self.bias_distribution.sample()

    # Detach weights and bias
    self.weights = self.weights.detach()
    self.bias = self.bias.detach()

5.4.1.3 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
|---|---|
| Tensor | KL divergence cost. |
| int | Total number of parameters. |

Source code in illia/nn/torch/linear.py
@torch.jit.export
def kl_cost(self) -> tuple[torch.Tensor, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        Tuple containing KL divergence cost and total number of
        parameters.
    """

    # Compute log probs
    log_probs: torch.Tensor = self.weights_distribution.log_prob(
        self.weights
    ) + self.bias_distribution.log_prob(self.bias)

    # Compute the number of parameters
    num_params: int = (
        self.weights_distribution.num_params() + self.bias_distribution.num_params()
    )

    return log_probs, num_params