
2. Neural Network

2.1 linear

This module contains the code for the Bayesian Linear layer.

2.1.1 Linear(input_size, output_size, weights_distribution=None, bias_distribution=None, *, use_bias=True, precision=None, dot_general=lax.dot_general)

This class is the Bayesian implementation of the Linear class.

This is the constructor of the Linear class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_size` | `int` | Size of the input features. | *required* |
| `output_size` | `int` | Size of the output features. | *required* |
| `weights_distribution` | `Optional[GaussianDistribution]` | Prior distribution of the weights. | `None` |
| `bias_distribution` | `Optional[GaussianDistribution]` | Prior distribution of the bias. | `None` |
| `use_bias` | `bool` | Whether to include a bias term in the layer. | `True` |
| `precision` | `PrecisionLike` | Precision used in dot product operations. | `None` |
| `dot_general` | `DotGeneralT` | Function for computing generalized dot products. | `lax.dot_general` |
Source code in illia/nn/jax/linear.py
def __init__(
    self,
    input_size: int,
    output_size: int,
    weights_distribution: Optional[GaussianDistribution] = None,
    bias_distribution: Optional[GaussianDistribution] = None,
    *,
    use_bias: bool = True,
    precision: PrecisionLike = None,
    dot_general: DotGeneralT = lax.dot_general,
) -> None:
    """
    This is the constructor of the Linear class.

    Args:
        input_size: Size of the input features.
        output_size: Size of the output features.
        weights_distribution: Prior distribution of the weights.
        bias_distribution: Prior distribution of the bias.
        use_bias: Whether to include a bias term in the layer.
        precision: Precision used in dot product operations.
        dot_general: Function for computing generalized dot
            products.
    """

    # Call super class constructor
    super().__init__()

    # Set attributes
    self.backend_params: dict[str, Any] = {
        "use_bias": True,
        "precision": precision,
        "dot_general": dot_general,
    }

    # Set weights prior
    if weights_distribution is None:
        self.weights_distribution = GaussianDistribution((input_size, output_size))
    else:
        self.weights_distribution = weights_distribution

    # Set bias prior
    if self.backend_params["use_bias"]:
        if bias_distribution is None:
            self.bias_distribution = GaussianDistribution((output_size,))
        else:
            self.bias_distribution = bias_distribution

    return None
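
A minimal usage sketch (assuming the class is importable from the source path shown above and can be instantiated directly): passing only the feature sizes creates default Gaussian priors for the weights and bias.

```python
from illia.nn.jax.linear import Linear  # import path assumed from the source path above

# Default Gaussian priors are created for the (16, 8) weights and the (8,) bias
layer = Linear(input_size=16, output_size=8)

# A custom prior can be supplied instead, e.g.:
# layer = Linear(16, 8, weights_distribution=GaussianDistribution((16, 8)))
```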

2.1.1.1 __call__(inputs)

This method is the forward pass of the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `inputs` | `Array` | Inputs of the model. Dimensions: [*, input size]. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Output tensor. Dimension: [*, output size]. |

Source code in illia/nn/jax/linear.py
def __call__(self, inputs: jax.Array) -> jax.Array:
    """
    This method is the forward pass of the model.

    Args:
        inputs: Inputs of the model. Dimensions: [*, input size].

    Returns:
        Output tensor. Dimension: [*, output size].
    """

    # Sample if model not frozen
    if not self.frozen:
        # Sample weights
        self.weights = self.weights_distribution.sample()

        # Sample bias
        if self.backend_params["use_bias"]:
            self.bias = self.bias_distribution.sample()

    # Compute outputs
    inputs, _, _ = dtypes.promote_dtype(
        (inputs, self.weights, self.bias), dtype=self.dtype  # type: ignore
    )
    outputs = self.dot_general(  # type: ignore
        inputs,
        self.weights,
        (((inputs.ndim - 1,), (0,)), ((), ())),
        precision=self.precision,  # type: ignore
    )
    if self.backend_params["use_bias"]:
        outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))

    return outputs
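
A minimal forward-pass sketch, assuming the layer object can be called directly on a JAX array as in the signature above; while the layer is not frozen, fresh weights and bias are sampled on every call, so repeated calls on the same input generally return different outputs.

```python
import jax.numpy as jnp

from illia.nn.jax.linear import Linear  # import path assumed from the source path above

layer = Linear(input_size=16, output_size=8)

x = jnp.ones((4, 16))  # batch of 4 inputs with 16 features each
y1 = layer(x)          # weights and bias are sampled for this call
y2 = layer(x)          # re-sampled, so y2 generally differs from y1
print(y1.shape)        # (4, 8)
```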

2.1.1.2 kl_cost()

Computes the Kullback-Leibler (KL) divergence cost for the layer's weights and bias.

Returns:

| Type | Description |
| --- | --- |
| `Array` | KL divergence cost. |
| `int` | Total number of parameters. |

Source code in illia/nn/jax/linear.py
def kl_cost(self) -> tuple[jax.Array, int]:
    """
    Computes the Kullback-Leibler (KL) divergence cost for the
    layer's weights and bias.

    Returns:
        KL divergence cost.
        Total number of parameters.
    """

    # Compute log probs
    log_probs: jax.Array = self.weights_distribution.log_prob(
        self.weights
    ) + self.bias_distribution.log_prob(self.bias)

    # Compute the number of parameters
    num_params: int = (
        self.weights_distribution.num_params + self.bias_distribution.num_params
    )

    return log_probs, num_params
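
A hedged sketch of how the returned values might be used in a training objective, assuming kl_cost() is called after a forward pass so that sampled weights and bias exist; scaling by the parameter count is one common choice, not something prescribed by the library.

```python
import jax.numpy as jnp

from illia.nn.jax.linear import Linear  # import path assumed from the source path above

layer = Linear(input_size=16, output_size=8)
_ = layer(jnp.ones((4, 16)))  # forward pass so that sampled weights and bias exist

kl, num_params = layer.kl_cost()
kl_term = kl / num_params  # average cost per parameter (an assumption, not the library's API)
```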