
10. Graph Neural Network Layers

10.1 CGConv

Crystal Graph Convolutional operator for material property prediction.

Updates node features using neighboring nodes and edge features as:

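x'_i = x_i + sum_{j in N(i)} sigmoid(z_ij W_f + b_f) ⊙ softplus(z_ij W_s + b_s)

where z_ij = [x_i || x_j || e_ij] is the concatenation of the central node features, the neighbor features, and (if provided) the edge features, and ⊙ denotes element-wise multiplication. The sigmoid and softplus functions are applied element-wise. The sum corresponds to the default "add" aggregation; with aggr="mean" or aggr="max" it is replaced by the corresponding reduction.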
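The update rule can be checked numerically with a small NumPy reference. This is a sketch under stated assumptions, not the illia implementation: it uses sum ("add") aggregation, the PyG convention that row 0 of edge_index holds the source nodes j and row 1 the target nodes i, and weight matrices laid out as (in_features, out_features) so that z_ij W matches the formula (torch.nn.Linear stores the transpose). The function name cgconv_reference is hypothetical.

```python
import numpy as np


def sigmoid(a: np.ndarray) -> np.ndarray:
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-a))


def softplus(a: np.ndarray) -> np.ndarray:
    """Element-wise log(1 + exp(a)), computed stably via logaddexp."""
    return np.logaddexp(0.0, a)


def cgconv_reference(x, edge_index, edge_attr, W_f, b_f, W_s, b_s):
    """Hypothetical NumPy reference of the CGConv update with "add" aggregation.

    x: (|V|, F); edge_index: (2, |E|), row 0 = source j, row 1 = target i;
    edge_attr: (|E|, D); W_f, W_s: (2F + D, F); b_f, b_s: (F,).
    """
    src, dst = edge_index
    # z_ij = [x_i || x_j || e_ij], one row per edge
    z = np.concatenate([x[dst], x[src], edge_attr], axis=-1)
    messages = sigmoid(z @ W_f + b_f) * softplus(z @ W_s + b_s)
    agg = np.zeros_like(x)
    np.add.at(agg, dst, messages)  # sum the messages arriving at each node i
    return x + agg  # residual connection


rng = np.random.default_rng(0)
num_nodes, F, D = 4, 3, 2
x = rng.normal(size=(num_nodes, F))
edge_index = np.array([[0, 1, 2], [1, 2, 1]])  # edges 0->1, 1->2, 2->1
edge_attr = rng.normal(size=(edge_index.shape[1], D))
W_f = rng.normal(size=(2 * F + D, F))
W_s = rng.normal(size=(2 * F + D, F))
b_f = np.zeros(F)
b_s = np.zeros(F)

out = cgconv_reference(x, edge_index, edge_attr, W_f, b_f, W_s, b_s)
print(out.shape)  # (4, 3)
```

Nodes 0 and 3 receive no messages, so they keep their input features. Every other node increases in each coordinate, because a sigmoid gate times a softplus term is always positive.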

Parameters:

  • channels (int | tuple[int, int], required): Size of the input node features. If a tuple (Fs, Ft), gives the source and target feature dimensions.
  • dim (int, default 0): Dimensionality D of the edge features. It must equal the last dimension of edge_attr; leave it at 0 when edges carry no features.
  • aggr (str, default 'add'): Aggregation method ("add", "mean", or "max").
  • **kwargs (Any): Additional keyword arguments passed to MessagePassing.
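The widths of the two gate projections follow directly from channels and dim. The hypothetical helper below (not part of illia) mirrors the arithmetic in __init__: z_ij concatenates target (Ft), source (Fs), and edge (D) features, and each projection maps back to Ft.

```python
def gate_layer_shape(channels, dim=0):
    """Hypothetical helper: (in_features, out_features) of lin_f and lin_s."""
    if isinstance(channels, int):
        channels = (channels, channels)  # homogeneous graph: Fs == Ft
    # z_ij = [x_i || x_j || e_ij] has Ft + Fs + D columns; output width is Ft
    return sum(channels) + dim, channels[1]


print(gate_layer_shape(64))           # (128, 64)
print(gate_layer_shape((16, 32), 8))  # (56, 32)
```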

Returns:

  • None

Shapes
  • input: node features (|V|, F) or ((|Vs|, Fs), (|Vt|, Ft)) if bipartite, edge indices (2, |E|), edge features (|E|, D) optional.
  • output: node features (|V|, F) or (|Vt|, Ft) if bipartite.
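In the bipartite case the central node i comes from the target set and the neighbor j from the source set, so z_ij has Ft + Fs + D columns while both projections output Ft features; that is what allows the residual to add x[1]. The NumPy sketch below checks these shapes, assuming zero biases and the same edge_index convention (row 0 indexes source nodes, row 1 target nodes).

```python
import numpy as np

rng = np.random.default_rng(1)
num_src, num_dst, F_s, F_t, D = 5, 2, 4, 3, 2
x_src = rng.normal(size=(num_src, F_s))  # x[0]: source node features (|Vs|, Fs)
x_dst = rng.normal(size=(num_dst, F_t))  # x[1]: target node features (|Vt|, Ft)
edge_index = np.array([[0, 3, 4], [0, 0, 1]])  # row 0 -> x_src, row 1 -> x_dst
edge_attr = rng.normal(size=(edge_index.shape[1], D))

# Both projections map (Ft + Fs + D) -> Ft, like Linear(sum(channels) + dim, channels[1])
W_f = rng.normal(size=(F_t + F_s + D, F_t))
W_s = rng.normal(size=(F_t + F_s + D, F_t))

src, dst = edge_index
z = np.concatenate([x_dst[dst], x_src[src], edge_attr], axis=-1)  # (|E|, Ft + Fs + D)
gate = 1.0 / (1.0 + np.exp(-(z @ W_f)))  # sigmoid, zero bias
core = np.logaddexp(0.0, z @ W_s)  # softplus, zero bias
agg = np.zeros((num_dst, F_t))
np.add.at(agg, dst, gate * core)  # sum messages per target node
out = agg + x_dst  # residual uses x[1], which is why the output width is Ft
print(out.shape)  # (2, 3)
```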
Notes

Based on "Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties" (https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301)

Source code in illia/nn/pyg/cg_conv.py
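# Imports (assumed; the listing starts at line 15 of the module and omits them)
from typing import Any

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor, PairTensor


class CGConv(MessagePassing):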
    r"""
    Crystal Graph Convolutional operator for material property prediction.

    Updates node features using neighboring nodes and edge features as:

        x'_i = x_i + sum_{j in N(i)} sigmoid(z_ij W_f + b_f) *
               softplus(z_ij W_s + b_s)

    where z_ij is the concatenation of central node features, neighbor
    features, and edge features. Applies element-wise sigmoid and
    softplus functions.

    Args:
        channels: Size of input features. If tuple, represents source and
            target feature dimensions.
        dim: Dimensionality of edge features.
        aggr: Aggregation method ("add", "mean", "max").
        **kwargs: Additional arguments for MessagePassing.

    Returns:
        None.

    Shapes:
        - input: node features (|V|, F) or ((|Vs|, Fs), (|Vt|, Ft)) if
          bipartite, edge indices (2, |E|), edge features (|E|, D) optional.
        - output: node features (|V|, F) or (|Vt|, Ft) if bipartite.

    Notes:
        Based on "Crystal Graph Convolutional Neural Networks for an
        Accurate and Interpretable Prediction of Material Properties"
        (https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301)
    """

    def __init__(
        self,
        channels: int | tuple[int, int],
        dim: int = 0,
        aggr: str = "add",
        **kwargs: Any,
    ) -> None:
        """
        Initializes the CGConv layer with linear transformations.

        Args:
            channels: Size of input features. Tuple for source and target.
            dim: Dimensionality of edge features.
            aggr: Aggregation operator ("add", "mean", "max").
            **kwargs: Extra arguments passed to the base class.

        Returns:
            None.
        """

        super().__init__(aggr=aggr, **kwargs)

        self.channels = channels
        self.dim = dim

        if isinstance(channels, int):
            channels = (channels, channels)

        # Define linear layers
        self.lin_f = Linear(sum(channels) + dim, channels[1])
        self.lin_s = Linear(sum(channels) + dim, channels[1])

    def reset_parameters(self) -> None:
        """
        Resets the parameters of the linear layers.

        Returns:
            None.
        """

        # Only the two gate projections hold learnable parameters
        self.lin_f.reset_parameters()
        self.lin_s.reset_parameters()

    def forward(
        self, x: Tensor | PairTensor, edge_index: Adj, edge_attr: OptTensor = None
    ) -> Tensor:
        """
        Performs a forward pass of the convolutional layer.

        Args:
            x: Input node features, as a single tensor or a pair if bipartite.
            edge_index: Edge indices.
            edge_attr: Optional edge features.

        Returns:
            Node features after applying the convolution.
        """

        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: PairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
        # Residual connection with the target-node features
        out = out + x[1]
        return out

    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
        """
        Constructs messages passed to neighboring nodes.

        Args:
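            x_i: Features of the central (target) node of each edge.
            x_j: Features of the neighboring (source) node of each edge.
            edge_attr: Optional edge features.

        Returns:
            Per-edge messages, which are then aggregated at the target nodes.
        """

        # z_ij = [x_i || x_j || e_ij]; edge features are omitted when absent
        if edge_attr is None:
            z = torch.cat([x_i, x_j], dim=-1)
        else:
            z = torch.cat([x_i, x_j, edge_attr], dim=-1)

        # Element-wise gate (sigmoid) times core (softplus)
        # pylint: disable=E1102
        return self.lin_f(z).sigmoid() * F.softplus(input=self.lin_s(z))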

    def __repr__(self) -> str:
        """
        Returns a string representation of the module.

        Returns:
            String with class name, channels, and edge feature dimension.
        """

        return f"{self.__class__.__name__}({self.channels}, dim={self.dim})"

10.1.1 __init__(channels, dim=0, aggr='add', **kwargs)

Initializes the CGConv layer with linear transformations.

Parameters:

  • channels (int | tuple[int, int], required): Size of the input features. If a tuple, gives the source and target feature dimensions.
  • dim (int, default 0): Dimensionality of the edge features.
  • aggr (str, default 'add'): Aggregation operator ("add", "mean", or "max").
  • **kwargs (Any): Extra arguments passed to the MessagePassing base class.

Returns:

  • None

Source code in illia/nn/pyg/cg_conv.py
def __init__(
    self,
    channels: int | tuple[int, int],
    dim: int = 0,
    aggr: str = "add",
    **kwargs: Any,
) -> None:
    """
    Initializes the CGConv layer with linear transformations.

    Args:
        channels: Size of input features. Tuple for source and target.
        dim: Dimensionality of edge features.
        aggr: Aggregation operator ("add", "mean", "max").
        **kwargs: Extra arguments passed to the base class.

    Returns:
        None.
    """

    super().__init__(aggr=aggr, **kwargs)

    self.channels = channels
    self.dim = dim

    if isinstance(channels, int):
        channels = (channels, channels)

    # Define linear layers
    self.lin_f = Linear(sum(channels) + dim, channels[1])
    self.lin_s = Linear(sum(channels) + dim, channels[1])
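The aggr argument decides how the per-edge messages arriving at the same target node are combined before the residual is added. The hypothetical helper below sketches the three documented options in NumPy; filling nodes with no incoming edges with 0 is an assumption of this sketch.

```python
import numpy as np


def aggregate(messages, dst, num_nodes, aggr="add"):
    """Hypothetical helper: combine per-edge messages at their target nodes."""
    num_features = messages.shape[1]
    if aggr == "add":
        out = np.zeros((num_nodes, num_features))
        np.add.at(out, dst, messages)
    elif aggr == "mean":
        out = np.zeros((num_nodes, num_features))
        np.add.at(out, dst, messages)
        counts = np.bincount(dst, minlength=num_nodes).reshape(-1, 1)
        out = out / np.maximum(counts, 1)  # nodes without edges stay 0
    elif aggr == "max":
        out = np.full((num_nodes, num_features), -np.inf)
        np.maximum.at(out, dst, messages)
        out[np.isinf(out)] = 0.0  # nodes without edges are set to 0
    else:
        raise ValueError(f"unsupported aggr: {aggr!r}")
    return out


messages = np.array([[1.0, 4.0], [3.0, 2.0], [5.0, 6.0]])
dst = np.array([0, 0, 1])  # two messages go to node 0, one to node 1
for aggr in ("add", "mean", "max"):
    print(aggr, aggregate(messages, dst, num_nodes=3, aggr=aggr).tolist())
# add [[4.0, 6.0], [5.0, 6.0], [0.0, 0.0]]
# mean [[2.0, 3.0], [5.0, 6.0], [0.0, 0.0]]
# max [[3.0, 4.0], [5.0, 6.0], [0.0, 0.0]]
```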

10.1.2 __repr__()

Returns a string representation of the module.

Returns:

  • str: String with the class name, channels, and edge feature dimension.

Source code in illia/nn/pyg/cg_conv.py
def __repr__(self) -> str:
    """
    Returns a string representation of the module.

    Returns:
        String with class name, channels, and edge feature dimension.
    """

    return f"{self.__class__.__name__}({self.channels}, dim={self.dim})"
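Because self.channels is stored before __init__ converts an int into a tuple, the representation shows channels exactly as passed, and it does not include aggr. The hypothetical helper below reproduces the string without needing torch.

```python
def cgconv_repr(channels, dim=0):
    """Hypothetical helper that rebuilds the string returned by CGConv.__repr__."""
    # channels is shown as passed (int or tuple); aggr is not part of the string
    return f"CGConv({channels}, dim={dim})"


print(cgconv_repr(64))           # CGConv(64, dim=0)
print(cgconv_repr((16, 32), 8))  # CGConv((16, 32), dim=8)
```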

10.1.3 forward(x, edge_index, edge_attr=None)

Performs a forward pass of the convolutional layer.

Parameters:

  • x (Tensor | PairTensor, required): Input node features, either a single tensor or a (source, target) pair for bipartite graphs.
  • edge_index (Adj, required): Edge indices, shape (2, |E|).
  • edge_attr (OptTensor, default None): Optional edge features, shape (|E|, D).

Returns:

  • Tensor: Node features after applying the convolution, shape (|V|, F), or (|Vt|, Ft) if bipartite.

Source code in illia/nn/pyg/cg_conv.py
def forward(
    self, x: Tensor | PairTensor, edge_index: Adj, edge_attr: OptTensor = None
) -> Tensor:
    """
    Performs a forward pass of the convolutional layer.

    Args:
        x: Input node features, as a single tensor or a pair if bipartite.
        edge_index: Edge indices.
        edge_attr: Optional edge features.

    Returns:
        Node features after applying the convolution.
    """

    if isinstance(x, Tensor):
        x = (x, x)

    # propagate_type: (x: PairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
    # Residual connection with the target-node features
    out = out + x[1]
    return out
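forward normalizes a single tensor x into the pair (x, x) and, after propagation, adds x[1]. Under PyG's default flow="source_to_target", propagate fills x_j from x[0] indexed by edge_index[0] and x_i from x[1] indexed by edge_index[1]. The hypothetical helper lift below sketches that gathering step in NumPy; it illustrates the convention and is not PyG's code.

```python
import numpy as np


def lift(x, edge_index):
    """Hypothetical helper: gather per-edge (x_i, x_j) as in flow="source_to_target"."""
    if isinstance(x, np.ndarray):
        x = (x, x)  # same normalization as forward(): single tensor -> pair
    src, dst = edge_index
    x_j = x[0][src]  # neighbor (source) features, taken from x[0]
    x_i = x[1][dst]  # central (target) features, taken from x[1]
    return x_i, x_j


x = np.arange(6, dtype=float).reshape(3, 2)  # node k has features [2k, 2k + 1]
edge_index = np.array([[0, 2], [1, 1]])  # edges 0->1 and 2->1
x_i, x_j = lift(x, edge_index)
print(x_i.tolist())  # [[2.0, 3.0], [2.0, 3.0]]
print(x_j.tolist())  # [[0.0, 1.0], [4.0, 5.0]]
```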

10.1.4 message(x_i, x_j, edge_attr)

Constructs messages passed to neighboring nodes.

Parameters:

  • x_i (Tensor, required): Features of the central (target) node of each edge.
  • x_j (Tensor, required): Features of the neighboring (source) node of each edge.
  • edge_attr (OptTensor, required): Optional edge features; None when edges carry no features.

Returns:

  • Tensor: Per-edge messages, which are then aggregated at the target nodes.

Source code in illia/nn/pyg/cg_conv.py
def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
    """
    Constructs messages passed to neighboring nodes.

    Args:
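        x_i: Features of the central (target) node of each edge.
        x_j: Features of the neighboring (source) node of each edge.
        edge_attr: Optional edge features.

    Returns:
        Per-edge messages, which are then aggregated at the target nodes.
    """

    # z_ij = [x_i || x_j || e_ij]; edge features are omitted when absent
    if edge_attr is None:
        z = torch.cat([x_i, x_j], dim=-1)
    else:
        z = torch.cat([x_i, x_j, edge_attr], dim=-1)

    # Element-wise gate (sigmoid) times core (softplus)
    # pylint: disable=E1102
    return self.lin_f(z).sigmoid() * F.softplus(input=self.lin_s(z))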
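message returns one vector per edge: a sigmoid gate in (0, 1) scales a softplus term that is always positive, so every message is strictly positive and never exceeds its softplus term. The NumPy mirror below (hypothetical name cg_message, weights stored as (in, out), zero biases assumed) exercises the edge_attr=None branch, where z_ij has only 2F columns because dim=0.

```python
import numpy as np


def cg_message(x_i, x_j, edge_attr, W_f, b_f, W_s, b_s):
    """Hypothetical NumPy mirror of CGConv.message; weights stored as (in, out)."""
    parts = [x_i, x_j] if edge_attr is None else [x_i, x_j, edge_attr]
    z = np.concatenate(parts, axis=-1)
    gate = 1.0 / (1.0 + np.exp(-(z @ W_f + b_f)))  # sigmoid, in (0, 1)
    core = np.logaddexp(0.0, z @ W_s + b_s)  # softplus, always > 0
    return gate * core  # one message row per edge, not yet aggregated


rng = np.random.default_rng(2)
num_edges, F = 4, 3
x_i = rng.normal(size=(num_edges, F))
x_j = rng.normal(size=(num_edges, F))
# dim=0 and edge_attr=None: z_ij has only 2F columns
W_f = rng.normal(size=(2 * F, F))
W_s = rng.normal(size=(2 * F, F))
b = np.zeros(F)

m = cg_message(x_i, x_j, None, W_f, b, W_s, b)
print(m.shape)  # (4, 3)
```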

10.1.5 reset_parameters()

Resets the parameters of the linear layers.

Returns:

  • None

Source code in illia/nn/pyg/cg_conv.py
def reset_parameters(self) -> None:
    """
    Resets the parameters of the linear layers.

    Returns:
        None.
    """

    # Only the two gate projections hold learnable parameters
    self.lin_f.reset_parameters()
    self.lin_s.reset_parameters()