
Graph neural networks

There are several kinds of graph neural network (GNN) in the literature. We adopt the unifying framework described by Battaglia et al. [1], since it subsumes many of the different approaches. Accordingly, we will sometimes refer to a GNN as a Graph Network (GN), although this term isn't in common use.

Graph structure

A GN graph is made up of a set of \(N^v\) nodes (vertices) \(V\), a set of \(N^e\) directed edges \(E\) connecting pairs of nodes (including self-connections), and a global attribute \(\myvec{u}\). Each node has an associated attribute vector \(\myvec{v}_i\), so we write the set of nodes as \(V=\{(\myvec{v}_i,i)\}_{i=1:N^v}\). Each directed edge has an associated attribute \(\myvec{e}_k\), so we write the set of edges as \(E=\{(\myvec{e}_k,r_k,s_k)\}_{k=1:N^e}\), where \(r_k\) is the index of the receiver node (i.e. the endpoint of the directed edge) and \(s_k\) is the index of the sender node (i.e. the start-point of the directed edge). This terminology comes from the idea of sending messages along edges. The attributes \(\myvec{u}, \myvec{v}_i, \myvec{e}_k\) could be scalars, vectors, tensors (e.g. a colour image), or something different altogether. We will simply denote them as vectors.


Note

The set notation in the previous paragraph is used for consistency with Battaglia et al. [1]. However, we differ from Battaglia et al. [1] by including the node index in the specification of each node, just as the sender and receiver indices are included in the specification of each edge. It is interesting to observe that this notation scales to a hypergraph by simply listing the indices of all the nodes involved in each hyperedge, including the special case of a hyperedge involving a single node, which plays the role of an ordinary node.


In standard terms, a GN graph is a directed, attributed multigraph with a global attribute. A multigraph is a graph in which there may be more than one edge between any pair of nodes. In what follows, we will refer to it simply as a 'graph' for brevity.

In the following diagram of a graph with four nodes and six edges, we show only the attribute values at each node and on each edge.

an example graph

Figure 1. An example graph.

The definition of this graph is as follows:

\[ V = \{(\myvec{v}_1,1), (\myvec{v}_2,2),(\myvec{v}_3,3),(\myvec{v}_4,4)\} \]
\[ E = \{(\myvec{e}_1,1,3),(\myvec{e}_2,2,1),(\myvec{e}_3,2,1),(\myvec{e}_4,2,3),(\myvec{e}_5,1,1),(\myvec{e}_6,4,2)\}\]
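As a concrete sketch, this graph could be encoded with plain NumPy arrays. The field names, the 3-dimensional attribute vectors and the zero initialisation are our own illustrative choices, not part of [1]; indices are 0-based in code:

```python
import numpy as np

# Attribute values; here every attribute is (arbitrarily) a 3-vector.
V = np.zeros((4, 3))                 # node attributes v_1..v_4, one row each
E = np.zeros((6, 3))                 # edge attributes e_1..e_6, one row each
u = np.zeros(3)                      # global attribute u

# Connectivity of the example graph (0-based): edge k points s[k] -> r[k].
r = np.array([0, 1, 1, 1, 0, 3])     # receiver indices r_k
s = np.array([2, 0, 0, 2, 0, 1])     # sender indices s_k

graph = {"V": V, "E": E, "u": u, "r": r, "s": s}
```

Note that edge 5 has `r[4] == s[4] == 0`: it is the self-connection at node 1.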

GN Block

The active part of a GNN is a function that updates the edge, node and global attribute values of a graph. The function is applied iteratively to produce a final output graph. We are interested in functions that are implemented using neural networks and refer to the function as a GN block. Thus, a GN block takes a graph as input and produces a graph with updated attribute values as output. The output graph has the same structure of nodes and edges as the input graph. The transformation of attribute values involves six steps, all of which depend on the structure of the graph. Most of the computations are local to a node or edge in the graph. For example, node attribute values are updated based on the values of attributes at that node, and at the edges for which the node is a receiver:

The GN block transforms the attribute values of a graph

Figure 2. The GN block transforms the attribute values of a graph.

The six steps are as follows:

(1) Update edge attribute values

The attribute value on each edge is updated based on the attribute values at the nodes on either end, and on the global attribute. As a result, the attribute value on edge \(k\) is updated as follows:

\[ \myvec{e}'_k = \phi^e(\myvec{e}_k,\myvec{v}_{r_k},\myvec{v}_{s_k},\myvec{u})\]

The following diagram shows the attributes involved in updating the attribute value on edge 6:

Step 1: updating the edge attribute value on edge 6

Figure 3. Step 1: updating the edge attribute value on edge 6.
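A minimal sketch of this step in NumPy, assuming 3-dimensional attributes and taking \(\phi^e\) to be a single linear layer applied to the concatenation of \(\myvec{e}_k\), \(\myvec{v}_{r_k}\), \(\myvec{v}_{s_k}\) and \(\myvec{u}\). The random weights stand in for learned parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(size=(4, 3))          # node attributes v_i
E = rng.normal(size=(6, 3))          # edge attributes e_k
u = rng.normal(size=3)               # global attribute u
r = np.array([0, 1, 1, 1, 0, 3])     # receiver indices r_k (0-based)
s = np.array([2, 0, 0, 2, 0, 1])     # sender indices s_k (0-based)

# phi^e as one linear layer: input [e_k, v_rk, v_sk, u] (12-d), output 3-d.
W_e = rng.normal(size=(12, 3))

def update_edges(E, V, u, r, s, W_e):
    u_rep = np.tile(u, (len(E), 1))                        # copy u to every edge
    inputs = np.concatenate([E, V[r], V[s], u_rep], axis=1)
    return inputs @ W_e                                    # all e'_k in one product

E_new = update_edges(E, V, u, r, s, W_e)
```

All six edges are updated in a single matrix product, which is how the per-edge locality of step 1 translates into parallel computation.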

(2) Aggregate edge attribute values at receiving nodes

At each node \(i\), we aggregate the (updated) attribute values on the edges \(E'_i\) for which this node is the receiver (i.e. all incoming edges). Because we don't want the result to depend on an ordering of the incoming edges, the aggregation function \(\rho^{e\rightarrow v}\) must be invariant to permutations of the edges. Typically we use the sum, mean or maximum, applying the function component-wise to its inputs:

\[E'_i=\{(\myvec{e}'_k,r_k,s_k)\}_{r_k=i,k=1:N^e}\]
\[\myvec{\bar{e}}'_i=\rho^{e\rightarrow v}(E'_i)\]

In our running example, using the mean as the aggregation function, the aggregate of edge attribute values at node 2 is given by \(\myvec{\bar{e}}'_2=(\myvec{e}'_2+\myvec{e}'_3+\myvec{e}'_4)/3\):

Step 2: computing the aggregate of edge attribute values at node 2

Figure 4. Step 2: computing the aggregate of edge attribute values at node 2.
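Step 2 is a 'scatter' of the updated edge attributes to their receiver nodes. A sketch using the mean, as in the running example (`E_new` stands in for the updated attributes \(\myvec{e}'_k\); random values for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
E_new = rng.normal(size=(6, 3))      # stands in for the updated e'_k
r = np.array([0, 1, 1, 1, 0, 3])     # receiver indices r_k (0-based)

def aggregate_incoming(E_new, r, num_nodes):
    """Mean of incoming e'_k at each node (permutation-invariant)."""
    agg = np.zeros((num_nodes, E_new.shape[1]))
    counts = np.zeros(num_nodes)
    for k, i in enumerate(r):
        agg[i] += E_new[k]
        counts[i] += 1
    # A node with no incoming edges (node 3 here) keeps a zero aggregate.
    return agg / np.maximum(counts, 1)[:, None]

e_bar = aggregate_incoming(E_new, r, num_nodes=4)
```

With 0-based indices, node 2 is row 1, so `e_bar[1]` is exactly `(E_new[1] + E_new[2] + E_new[3]) / 3`, matching the example above.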

(3) Update node attribute values

The attribute value at each node is updated based on its current value, the aggregate edge attribute value for the node, computed in step 2, and the global attribute value:

\[ \myvec{v}'_i = \phi^v(\myvec{\bar{e}}'_i,\myvec{v}_i,\myvec{u})\]

Here we show the update to the attribute value at node 2:

Step 3: updating the node attribute value at node 2

Figure 5. Step 3: updating the node attribute value at node 2.
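A sketch of the node update, again taking \(\phi^v\) to be a single linear layer, here acting on the concatenation of \(\myvec{\bar{e}}'_i\), \(\myvec{v}_i\) and \(\myvec{u}\) (random stand-in weights and values):

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(size=(4, 3))          # current node attributes v_i
e_bar = rng.normal(size=(4, 3))      # stands in for the aggregates from step 2
u = rng.normal(size=3)               # global attribute u

# phi^v as one linear layer: input [e_bar'_i, v_i, u] (9-d), output 3-d.
W_v = rng.normal(size=(9, 3))

def update_nodes(e_bar, V, u, W_v):
    u_rep = np.tile(u, (len(V), 1))  # copy u to every node
    inputs = np.concatenate([e_bar, V, u_rep], axis=1)
    return inputs @ W_v              # all v'_i at once

V_new = update_nodes(e_bar, V, u, W_v)
```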

(4) Aggregate edge attribute values globally

We aggregate the updated edge attribute values across all edges \(E'\). Again, the aggregation function needs to be invariant to permutations of the edges, so we normally choose one of sum, mean or max:

\[ E'=\{(\myvec{e}'_k,r_k,s_k)\}_{k=1:N^e} \]
\[ \myvec{\bar{e}}'=\rho^{e\rightarrow u}(E') \]

Using max as the aggregation function, the aggregate of edge attribute values is:

\[\myvec{\bar{e}}'=\max(\myvec{e}'_1,\myvec{e}'_2,\myvec{e}'_3,\myvec{e}'_4,\myvec{e}'_5,\myvec{e}'_6)\]
Step 4: computing the global aggregate of updated edge attribute values

Figure 6. Step 4: computing the global aggregate of updated edge attribute values.

(5) Aggregate node attribute values globally

We aggregate the updated node attribute values across all nodes \(V'\):

\[ V'=\{\myvec{v}'_i\}_{i=1:N^v} \]
\[ \myvec{\bar{v}}'=\rho^{v\rightarrow u}(V')\]

Using sum as the aggregation function, the aggregate of updated node attribute values is:

\[\myvec{\bar{v}}'=\myvec{v}'_1+\myvec{v}'_2+\myvec{v}'_3+\myvec{v}'_4\]
Step 5: computing the global aggregate of updated node attribute values

Figure 7. Step 5: computing the global aggregate of updated node attribute values.

(6) Update global attribute value

Finally, we update the global attribute value as a function of the global aggregate values across edges (step 4) and nodes (step 5), and the current global attribute value:

\[ \myvec{u}'=\phi^u(\myvec{\bar{e}}',\myvec{\bar{v}}',\myvec{u})\]
Step 6: updating the global attribute value

Figure 8. Step 6: updating the global attribute value.
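The three global steps reduce the whole graph to fixed-size vectors before updating \(\myvec{u}\). A sketch using max over edges (step 4), sum over nodes (step 5) and a single linear layer for \(\phi^u\) (step 6); the shapes and weights are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
E_new = rng.normal(size=(6, 3))      # stands in for the updated e'_k
V_new = rng.normal(size=(4, 3))      # stands in for the updated v'_i
u = rng.normal(size=3)               # current global attribute u

# Step 4: component-wise max over all edges (rho^{e->u}).
e_bar_global = E_new.max(axis=0)

# Step 5: component-wise sum over all nodes (rho^{v->u}).
v_bar_global = V_new.sum(axis=0)

# Step 6: phi^u as one linear layer on [e_bar', v_bar', u] (9-d input).
W_u = rng.normal(size=(9, 3))
u_new = np.concatenate([e_bar_global, v_bar_global, u]) @ W_u
```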

Steps 1-3 can be carried out in parallel at each node or edge as appropriate.

Composing GN blocks

The three functions at the heart of the GN block, \(\phi^e,\phi^v,\phi^u\), can take many forms. Typically, each is a single linear layer, but it could also be a multi-layer (deep) network, or a CNN when the attributes are images.

The GN block operates either locally (steps 1-3) or globally (steps 4-6). Just as for a CNN, we increase the scale of the local updates by composing (layering) GN blocks, as illustrated below, where \(M\) GN blocks are composed together, taking the graph \(G_0\) as input and producing the graph \(G_M\) as output:

Composing GN blocks together

Figure 9. Composing GN blocks together.
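Composition then just threads a graph through successive blocks, with the connectivity held fixed. A toy sketch in which every \(\phi\) is a single linear layer and the step-2 aggregation is a sum (the random weights stand in for learned parameters):

```python
import numpy as np

rng = np.random.default_rng(0)
r = np.array([0, 1, 1, 1, 0, 3])     # receiver indices (fixed across blocks)
s = np.array([2, 0, 0, 2, 0, 1])     # sender indices

def make_gn_block(rng, d=3):
    """One GN block with linear phi^e, phi^v, phi^u (illustrative only)."""
    W_e = rng.normal(size=(4 * d, d), scale=0.1)
    W_v = rng.normal(size=(3 * d, d), scale=0.1)
    W_u = rng.normal(size=(3 * d, d), scale=0.1)

    def block(V, E, u):
        u_e = np.tile(u, (len(E), 1))
        E2 = np.concatenate([E, V[r], V[s], u_e], axis=1) @ W_e        # step 1
        agg = np.zeros_like(V)
        np.add.at(agg, r, E2)                                          # step 2 (sum)
        u_v = np.tile(u, (len(V), 1))
        V2 = np.concatenate([agg, V, u_v], axis=1) @ W_v               # step 3
        u2 = np.concatenate([E2.max(axis=0), V2.sum(axis=0), u]) @ W_u # steps 4-6
        return V2, E2, u2

    return block

# Compose M = 3 blocks: G_0 -> G_1 -> G_2 -> G_3.
V, E, u = rng.normal(size=(4, 3)), rng.normal(size=(6, 3)), rng.normal(size=3)
for block in [make_gn_block(rng) for _ in range(3)]:
    V, E, u = block(V, E, u)
```

Each block moves node-to-node information by one edge hop (with the global attribute acting as a shortcut), so composing blocks lets information propagate along longer paths in the graph.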

Physics-based example

Battaglia et al. [1] give a physics-based example to motivate the six steps of the GNN.

Imagine we have a set of rubber balls connected by springs and subject to gravity. We can represent this as a graph as below:

Balls connected by springs, represented as a graph

Figure 10. Balls connected by springs, represented as a graph.

A simulation of the balls in motion could be achieved via the six steps of a GN block as follows:

| GN block step | Physical process |
| --- | --- |
| Step 1: Update edge attribute values | Forces between balls |
| Step 2: Aggregate edge attribute values at receiving nodes | Sum of forces (net force) on each ball |
| Step 3: Update node attribute values | Position, velocity and kinetic energy of each ball |
| Step 4: Aggregate edge attribute values globally | Sum of per-node forces (should be zero) and potential energy stored in the springs |
| Step 5: Aggregate node attribute values globally | Total kinetic energy of the system |
| Step 6: Update global attribute value | Net forces and total energy of the system |
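As a toy illustration of step 1 in this setting (our own sketch, not code from [1]), the edge update can compute the Hooke's-law force that a spring exerts on its receiver ball from the two ball positions:

```python
import numpy as np

def spring_force(x_receiver, x_sender, k=10.0, rest_length=1.0):
    """Force on the receiver ball from a spring to the sender ball (Hooke's law).

    Assumes the two balls are not at the same position."""
    d = x_receiver - x_sender
    length = np.linalg.norm(d)
    return -k * (length - rest_length) * d / length

# Two balls joined by one spring, modelled as two directed edges (one each way).
x1 = np.array([0.0, 0.0])
x2 = np.array([1.5, 0.0])            # spring stretched 0.5 beyond rest length
f_on_1 = spring_force(x1, x2)        # pulls ball 1 towards ball 2
f_on_2 = spring_force(x2, x1)        # pulls ball 2 towards ball 1
```

Step 2 would then sum these per-edge forces at each ball to obtain its net force, and the step-4 global sum of forces should come out to zero, as the table notes.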

References

[1] Battaglia, P.W., et al., Relational inductive biases, deep learning, and graph networks. arXiv:1806.01261, 2018.