2. Message Passing

Message passing is the core operation underneath many graph neural network (GNN) layers. This section focuses on how neighbor aggregation can be implemented directly with PyTorch tensor operations.

[ ]:
import torch
from torch import nn

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Node feature matrix: one row per node (4 nodes), two features per node.
x = torch.tensor(
    [[1.0, 0.0], [0.8, 0.2], [0.0, 1.0], [0.2, 0.9]],
    device=device,
)

# Edge list with one column per directed edge.  Each edge is listed in both
# directions, so the graph is the undirected path 0 - 1 - 2 - 3.
edge_index = torch.tensor(
    [
        [0, 1, 1, 2, 2, 3],  # source node of each edge
        [1, 0, 2, 1, 3, 2],  # destination node of each edge
    ],
    dtype=torch.long,
    device=device,
)

2.1. Sum and mean aggregation

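Collect source-node messages and add them into destination-node slots. For the mean, divide each node's sum by its number of incoming edges (clamped to at least 1, so nodes without incoming edges keep an all-zero result).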

[ ]:
def aggregate_sum(x, edge_index):
    """Sum the features arriving at each node over its incoming edges.

    Every edge ``source -> destination`` carries ``x[source]`` as its
    message; all messages that share a destination are added together.

    Args:
        x: Node features, indexed by node along dimension 0.
        edge_index: Two-row index tensor with one column per edge; row 0
            holds the source node and row 1 the destination node.

    Returns:
        A tensor shaped like ``x`` whose entry ``i`` is the sum of the
        messages sent to node ``i``.  Nodes without incoming edges stay zero.
    """
    senders, receivers = edge_index
    messages = x[senders]
    # Scatter-add along the node dimension into a zero-initialised buffer;
    # index_add_ accumulates repeated receiver indices instead of overwriting.
    return torch.zeros_like(x).index_add_(0, receivers, messages)

def aggregate_mean(x, edge_index):
    """Average the features arriving at each node over its incoming edges.

    Args:
        x: Node features, indexed by node along dimension 0.  Any number of
            trailing feature dimensions is supported.
        edge_index: Two-row index tensor with one column per edge; row 0
            holds the source node and row 1 the destination node.

    Returns:
        A tensor shaped like ``x`` whose entry ``i`` is the mean of the
        messages sent to node ``i``.  Nodes without incoming edges get zeros
        (their degree is clamped to 1 to avoid dividing by zero).
    """
    _, destination = edge_index
    summed = aggregate_sum(x, edge_index)
    # Count incoming edges per node.  The counter is created with x's dtype
    # because index_add_ requires the accumulator and the added values to
    # share a dtype; a default-dtype buffer fails for e.g. float64 features.
    degree = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    degree.index_add_(0, destination, torch.ones_like(destination, dtype=x.dtype))
    # Reshape to (num_nodes, 1, ..., 1) so the division broadcasts over every
    # trailing feature dimension, not only over a single one.
    degree = degree.clamp_min(1).reshape(-1, *([1] * (x.dim() - 1)))
    return summed / degree

# Show both aggregations on the toy graph: each label is printed on its own
# line, followed by the resulting tensor.
print('sum', aggregate_sum(x, edge_index), sep='\n')
print('mean', aggregate_mean(x, edge_index), sep='\n')
# Sanity check: aggregation keeps the shape of the node feature matrix.
assert x.shape == aggregate_sum(x, edge_index).shape

2.2. A message-passing layer

A learnable layer can combine each node’s old state with its aggregated neighbor state.

[ ]:
class MessagePassingLayer(nn.Module):
    """Learnable layer mixing a node's own features with its neighbor mean.

    Two independent linear maps are used: one for the node's current features
    and one for the mean of the features received over incoming edges.  Their
    outputs are added and passed through a ReLU.
    """

    def __init__(self, in_features, out_features):
        """Create the two linear maps from ``in_features`` to ``out_features``."""
        super().__init__()
        self.self_linear = nn.Linear(in_features, out_features)
        self.neighbor_linear = nn.Linear(in_features, out_features)

    def forward(self, x, edge_index):
        """Compute updated node features.

        Args:
            x: Node features, one row per node.
            edge_index: Two-row edge index tensor passed on to
                ``aggregate_mean``.

        Returns:
            The ReLU of the summed self and neighbor projections, one row
            per node.
        """
        neighbor_mean = aggregate_mean(x, edge_index)
        own_term = self.self_linear(x)
        neighbor_term = self.neighbor_linear(neighbor_mean)
        return (own_term + neighbor_term).relu()

# Map the 2 input features of every node to 5 output features.
layer = MessagePassingLayer(in_features=2, out_features=5)
layer = layer.to(device)
new_x = layer(x, edge_index)
print(new_x.shape)
# One output row per node (4 nodes), each with 5 features.
assert tuple(new_x.shape) == (4, 5)

2.3. Add self loops

Self loops let a node send its current representation to itself through the same edge aggregation path as neighbor messages.

[ ]:
def add_self_loops(edge_index, num_nodes):
    """Append one self loop ``i -> i`` per node to an edge list.

    Args:
        edge_index: Two-row index tensor with one column per edge.
        num_nodes: Number of nodes; loops are added for nodes
            ``0 .. num_nodes - 1``.

    Returns:
        A new tensor of shape ``(2, num_edges + num_nodes)`` holding the
        original edges followed by the self loops, with the same dtype and
        device as ``edge_index``.  Self loops already present in
        ``edge_index`` are not detected, so they end up duplicated.
    """
    # Build the node ids in the edge list's own dtype: torch.arange defaults
    # to int64, and concatenating that with e.g. an int32 edge list would
    # silently promote the result to int64.
    nodes = torch.arange(num_nodes, dtype=edge_index.dtype, device=edge_index.device)
    loops = torch.stack([nodes, nodes])
    return torch.cat([edge_index, loops], dim=1)

# Extend the toy graph with one self loop per node and show the result.
edge_index_with_loops = add_self_loops(edge_index, num_nodes=x.shape[0])
print(edge_index_with_loops)
# The edge count must grow by exactly the number of nodes.
assert edge_index_with_loops.shape[1] - edge_index.shape[1] == x.shape[0]