9. Neural Network

9.1 cg_conv

This module contains the code for the Crystal Graph Convolutional operator.

9.1.1 CGConv(channels, dim=0, aggr='add', **kwargs)

Implementation of the Crystal Graph Convolutional operator from the paper "Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties" (https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301).

The operation is defined as:

\[ \mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)}\sigma ( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f ) \odot g ( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s ) \]

where \(\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]\) denotes the concatenation of the central node features, the neighboring node features, and the edge features, and \(\mathcal{N}(i)\) is the set of neighbors of node \(i\). In addition, \(\sigma\) and \(g\) denote the sigmoid and softplus functions, respectively, and \(\odot\) denotes element-wise multiplication.
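To make the update rule concrete, the following NumPy sketch evaluates it edge by edge for a small graph. It is a reference for the formula only, not the module's implementation: the function name `cgconv_reference`, the weight layout (`W_f` and `W_s` of shape \((2F + D, F)\)), and the convention that column `e` of `edge_index` is an edge from node `edge_index[0, e]` (\(j\)) to node `edge_index[1, e]` (\(i\)) are assumptions of this example.

```python
import numpy as np


def sigmoid(a: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-a))


def softplus(a: np.ndarray) -> np.ndarray:
    # log(1 + exp(a)) without overflow for large a
    return np.logaddexp(0.0, a)


def cgconv_reference(x, edge_index, edge_attr, W_f, b_f, W_s, b_s):
    """Evaluate x'_i = x_i + sum_j sigmoid(z W_f + b_f) * softplus(z W_s + b_s) edge by edge."""
    out = x.astype(float).copy()  # residual term x_i
    for e in range(edge_index.shape[1]):
        j, i = edge_index[0, e], edge_index[1, e]  # edge j -> i (assumed convention)
        parts = [x[i], x[j]] if edge_attr is None else [x[i], x[j], edge_attr[e]]
        z = np.concatenate(parts)  # z_{i,j} = [x_i, x_j, e_{i,j}]
        out[i] += sigmoid(z @ W_f + b_f) * softplus(z @ W_s + b_s)
    return out


rng = np.random.default_rng(0)
num_nodes, F, D = 4, 3, 2
x = rng.normal(size=(num_nodes, F))
edge_index = np.array([[0, 1, 2],   # source nodes j
                       [1, 0, 1]])  # target nodes i
edge_attr = rng.normal(size=(edge_index.shape[1], D))
W_f, W_s = rng.normal(size=(2 * F + D, F)), rng.normal(size=(2 * F + D, F))
b_f, b_s = np.zeros(F), np.zeros(F)

x_new = cgconv_reference(x, edge_index, edge_attr, W_f, b_f, W_s, b_s)
print(x_new.shape)  # (4, 3)
```

Nodes 2 and 3 have no incoming edges, so they keep their input features. Because the sigmoid and softplus are both strictly positive, every message component is positive, and each feature of nodes 0 and 1 increases.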

Parameters:

  • channels (int or tuple, required): The size of each input sample. A tuple gives the source and target dimensionalities, in that order.
  • dim (int, default 0): The edge feature dimensionality; leave it at 0 when no edge features are passed.
  • aggr (str, default 'add'): The aggregation operator to use ("add", "mean", or "max").
  • **kwargs (optional, default {}): Additional arguments for torch_geometric.nn.conv.MessagePassing.
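The `aggr` argument only changes how the per-edge messages arriving at a node are combined. The NumPy sketch below illustrates the three options on precomputed messages; the `aggregate` helper is hypothetical, and mapping empty neighborhoods to zero is a choice made for this example rather than a documented guarantee of the library.

```python
import numpy as np


def aggregate(messages: np.ndarray, index: np.ndarray, num_nodes: int, aggr: str = "add") -> np.ndarray:
    """Combine per-edge messages (E, C) onto their target nodes given by index (E,)."""
    counts = np.bincount(index, minlength=num_nodes)  # incoming edges per node
    if aggr == "add":
        out = np.zeros((num_nodes, messages.shape[1]))
        np.add.at(out, index, messages)
    elif aggr == "mean":
        out = np.zeros((num_nodes, messages.shape[1]))
        np.add.at(out, index, messages)
        out /= np.maximum(counts, 1)[:, None]  # avoid dividing by zero for isolated nodes
    elif aggr == "max":
        out = np.full((num_nodes, messages.shape[1]), -np.inf)
        np.maximum.at(out, index, messages)
        out[counts == 0] = 0.0  # nodes without incoming edges get zeros
    else:
        raise ValueError(f"unsupported aggr: {aggr!r}")
    return out


messages = np.array([[1.0, 2.0], [3.0, 0.0], [5.0, 4.0]])
index = np.array([0, 0, 1])  # target node of each message
for aggr in ("add", "mean", "max"):
    print(aggr, aggregate(messages, index, num_nodes=3, aggr=aggr).tolist())
# add [[4.0, 2.0], [5.0, 4.0], [0.0, 0.0]]
# mean [[2.0, 1.0], [5.0, 4.0], [0.0, 0.0]]
# max [[3.0, 2.0], [5.0, 4.0], [0.0, 0.0]]
```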
Shapes:
  • input: node features \((|\mathcal{V}|, F)\) or \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if bipartite, edge indices \((2, |\mathcal{E}|)\), edge features \((|\mathcal{E}|, D)\) (optional)
  • output: node features \((|\mathcal{V}|, F)\) or \((|\mathcal{V_t}|, F_{t})\) if bipartite
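Written out per edge, the shapes above fix the sizes of the learnable terms in the update rule. In the bipartite case \(\mathbf{x}_i\) is a target feature vector and \(\mathbf{x}_j\) a source feature vector; the following sketches the resulting dimensions, with \(D = 0\) when no edge features are used and \(F_s = F_t = F\) in the non-bipartite case:

\[
\begin{aligned}
\mathbf{z}_{i,j} = [\mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j}] &\in \mathbb{R}^{F_t + F_s + D}, \\
\mathbf{W}_f, \mathbf{W}_s \in \mathbb{R}^{(F_t + F_s + D) \times F_t}, \quad \mathbf{b}_f, \mathbf{b}_s &\in \mathbb{R}^{F_t}, \\
\sigma(\mathbf{z}_{i,j}\mathbf{W}_f + \mathbf{b}_f) \odot g(\mathbf{z}_{i,j}\mathbf{W}_s + \mathbf{b}_s) &\in \mathbb{R}^{F_t}.
\end{aligned}
\]

Every message therefore has the same width as \(\mathbf{x}_i\), which is what allows the residual term \(\mathbf{x}_i\) to be added; in the constructor this corresponds to `Linear(sum(channels) + dim, channels[1])`.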
Source code in illia/nn/pyg/cg_conv.py
def __init__(
    self,
    channels: Union[int, tuple[int, int]],
    dim: int = 0,
    aggr: str = "add",
    **kwargs,
):
    r"""
    "Crystal Graph Convolutional Neural Networks for an Accurate
    and Interpretable Prediction of Material Properties"
    (https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301).

    The operation is defined as:

    $$
    \mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in
    \mathcal{N}(i)}\sigma ( \mathbf{z}_{i,j} \mathbf{W}_f +
    \mathbf{b}_f )
    \odot g ( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s )
    $$

    where \(\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j,
    \mathbf{e}_{i,j} ]\)
    denotes the concatenation of central node features, neighboring
    node features, and edge features. In addition, \(\sigma\) and
    \(g\) denote the sigmoid and softplus functions, respectively.

    Args:
        channels (int or tuple): The size of each input sample. A
            tuple corresponds to the sizes of source and target
            dimensionalities.
        dim (int, optional): The edge feature dimensionality.
        aggr (str, optional): The aggregation operator to use
            ("add", "mean", "max").
        **kwargs (optional): Additional arguments for
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:** node features \((|\mathcal{V}|, F)\) or
            \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) if
            bipartite, edge indices \((2, |\mathcal{E}|)\), edge features
            \((|\mathcal{E}|, D)\) *(optional)*
        - **output:** node features \((|\mathcal{V}|, F)\) or
            \((|\mathcal{V_t}|, F_{t})\) if bipartite
    """

    # Call super class constructor
    super().__init__(aggr=aggr, **kwargs)

    self.channels = channels
    self.dim = dim

    if isinstance(channels, int):
        channels = (channels, channels)

    self.lin_f = Linear(sum(channels) + dim, channels[1])
    self.lin_s = Linear(sum(channels) + dim, channels[1])
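The constructor sizes both gate layers from `channels` and `dim` alone. The hypothetical helper below repeats that arithmetic in plain Python, which can be handy for checking that `dim` agrees with the edge features you intend to pass:

```python
from typing import Tuple, Union


def gate_layer_shape(channels: Union[int, Tuple[int, int]], dim: int = 0) -> Tuple[int, int]:
    """Return (in_features, out_features) of lin_f and lin_s, mirroring __init__."""
    if isinstance(channels, int):
        channels = (channels, channels)  # same size for source and target nodes
    in_features = sum(channels) + dim  # width of z = [x_i, x_j, e_ij]
    out_features = channels[1]  # target feature size, so the residual x[1] can be added
    return in_features, out_features


print(gate_layer_shape(16))          # (32, 16)
print(gate_layer_shape(16, dim=4))   # (36, 16)
print(gate_layer_shape((8, 16), 4))  # (28, 16)
```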

9.1.1.1 __repr__()

Returns a string representation of the module.

Source code in illia/nn/pyg/cg_conv.py
def __repr__(self) -> str:
    """
    Returns a string representation of the module.
    """

    return f"{self.__class__.__name__}({self.channels}, dim={self.dim})"

9.1.1.2 forward(x, edge_index, edge_attr=None)

Performs a forward pass of the convolutional layer.

Parameters:

  • x (Union[Tensor, PairTensor], required): Input node features, either as a single tensor or, if bipartite, as a (source, target) pair of tensors.
  • edge_index (Adj, required): Edge indices.
  • edge_attr (OptTensor, default None): Optional edge features.

Returns:

  • Tensor: The output node features, i.e. the aggregated messages plus the residual target features x[1].

Source code in illia/nn/pyg/cg_conv.py
def forward(
    self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor = None
) -> Tensor:
    """
    Performs a forward pass of the convolutional layer.

    Args:
        x: Input node features, either as a single tensor or a pair
            of tensors if bipartite.
        edge_index: Edge indices.
        edge_attr: Optional edge features.

    Returns:
        The output node features.
    """

    if isinstance(x, Tensor):
        x = (x, x)

    # propagate_type: (x: PairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
    out = out + x[1]
    return out
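The same control flow can be mimicked without PyTorch Geometric. This NumPy sketch assumes `aggr="add"`, PyG's default source-to-target edge direction, and a hypothetical `params` tuple holding \(\mathbf{W}_f, \mathbf{b}_f, \mathbf{W}_s, \mathbf{b}_s\); it shows how a single feature matrix is promoted to a pair and why the output has the target nodes' shape.

```python
import numpy as np


def cgconv_forward_sketch(x, edge_index, params, edge_attr=None):
    """NumPy mimic of forward() with sum aggregation (hypothetical helper)."""
    if isinstance(x, np.ndarray):
        x = (x, x)  # non-bipartite: source and target features coincide
    x_src, x_dst = x
    src, dst = edge_index  # edge j -> i, with j in src and i in dst
    parts = [x_dst[dst], x_src[src]]  # x_i (target side), x_j (source side)
    if edge_attr is not None:
        parts.append(edge_attr)
    z = np.concatenate(parts, axis=-1)
    W_f, b_f, W_s, b_s = params
    gate = 1.0 / (1.0 + np.exp(-(z @ W_f + b_f)))  # sigmoid
    core = np.logaddexp(0.0, z @ W_s + b_s)  # softplus
    out = np.zeros((x_dst.shape[0], W_f.shape[1]))
    np.add.at(out, dst, gate * core)  # sum messages per target node
    return out + x_dst  # residual, as in `out = out + x[1]`


rng = np.random.default_rng(0)
F_s, F_t, D = 3, 5, 2
x_src = rng.normal(size=(4, F_s))  # 4 source nodes
x_dst = rng.normal(size=(2, F_t))  # 2 target nodes
edge_index = np.array([[0, 1, 2, 3],   # source indices
                       [0, 0, 1, 1]])  # target indices
edge_attr = rng.normal(size=(4, D))
params = (rng.normal(size=(F_t + F_s + D, F_t)), np.zeros(F_t),
          rng.normal(size=(F_t + F_s + D, F_t)), np.zeros(F_t))

out = cgconv_forward_sketch((x_src, x_dst), edge_index, params, edge_attr)
print(out.shape)  # (2, 5)
```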

9.1.1.3 message(x_i, x_j, edge_attr)

Constructs messages to be passed to neighboring nodes.

Parameters:

  • x_i (Tensor, required): Central (target) node features.
  • x_j (Tensor, required): Neighboring (source) node features.
  • edge_attr (OptTensor, required): Edge features, or None when the layer is used without them.

Returns:

  • Tensor: The messages to be aggregated.

Source code in illia/nn/pyg/cg_conv.py
def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
    """
    Constructs messages to be passed to neighboring nodes.

    Args:
        x_i: Central node features.
        x_j: Neighboring node features.
        edge_attr: Optional edge features.

    Returns:
        The messages to be aggregated.
    """

    if edge_attr is None:
        z = torch.cat([x_i, x_j], dim=-1)
    else:
        z = torch.cat([x_i, x_j, edge_attr], dim=-1)

    return self.lin_f(z).sigmoid() * F.softplus(self.lin_s(z))
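`message()` never checks that the width of \(\mathbf{z}_{i,j}\) matches the `sum(channels) + dim` inputs expected by `lin_f` and `lin_s`, so a mismatch between `dim` and the edge features actually passed surfaces as a shape error inside the linear layer. The NumPy sketch below reproduces the gated message and adds an explicit check; the guard and the name `gated_message` are hypothetical additions for illustration, not part of the module.

```python
from typing import Optional

import numpy as np


def gated_message(x_i: np.ndarray, x_j: np.ndarray, edge_attr: Optional[np.ndarray],
                  W_f: np.ndarray, b_f: np.ndarray, W_s: np.ndarray, b_s: np.ndarray) -> np.ndarray:
    """sigmoid(z W_f + b_f) * softplus(z W_s + b_s) per edge, with z = [x_i, x_j(, e_ij)]."""
    parts = [x_i, x_j] if edge_attr is None else [x_i, x_j, edge_attr]
    z = np.concatenate(parts, axis=-1)
    if z.shape[-1] != W_f.shape[0]:
        # Hypothetical guard: the original message() performs no such check.
        raise ValueError(
            f"z has width {z.shape[-1]} but the gates expect {W_f.shape[0]}; "
            "check that dim equals edge_attr.shape[-1] (or 0 without edge features)"
        )
    gate = 1.0 / (1.0 + np.exp(-(z @ W_f + b_f)))  # sigmoid
    return gate * np.logaddexp(0.0, z @ W_s + b_s)  # times softplus


rng = np.random.default_rng(0)
E, F, D = 5, 3, 2
x_i, x_j = rng.normal(size=(E, F)), rng.normal(size=(E, F))
edge_attr = rng.normal(size=(E, D))
W_f, W_s = rng.normal(size=(2 * F + D, F)), rng.normal(size=(2 * F + D, F))
b_f, b_s = np.zeros(F), np.zeros(F)

msg = gated_message(x_i, x_j, edge_attr, W_f, b_f, W_s, b_s)
print(msg.shape)  # (5, 3)
```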

9.1.1.4 reset_parameters()

Resets the parameters of the linear layers.

Source code in illia/nn/pyg/cg_conv.py
def reset_parameters(self):
    """
    Resets the parameters of the linear layers.
    """

    self.lin_f.reset_parameters()
    self.lin_s.reset_parameters()