2. Message Passing
This matters because message passing is the core operation underlying most GNN layers. Focus on how neighbor aggregation can be implemented directly with plain PyTorch tensor operations, without any graph-specific library.
[ ]:
import torch
from torch import nn
# Run on GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Node feature matrix: 4 nodes with 2 features each.
x = torch.tensor(
    [[1.0, 0.0],
     [0.8, 0.2],
     [0.0, 1.0],
     [0.2, 0.9]],
    device=device,
)

# Edges as (row 0 = source node, row 1 = destination node); each edge
# appears in both directions, so the toy graph is effectively undirected.
edge_index = torch.tensor(
    [[0, 1, 1, 2, 2, 3],
     [1, 0, 2, 1, 3, 2]],
    dtype=torch.long,
    device=device,
)
2.1. Sum and mean aggregation
Collect source-node messages and add them into destination-node slots.
[ ]:
def aggregate_sum(x, edge_index):
    """Sum-aggregate neighbor features along directed edges.

    For every edge (s, d), feature row x[s] is added into row d of the
    output, so each node receives the sum of its in-neighbors' features.
    Nodes with no incoming edges keep a zero row.
    """
    src = edge_index[0]
    dst = edge_index[1]
    # Gather one message per edge, then scatter-add into destination rows.
    aggregated = torch.zeros_like(x)
    aggregated.index_add_(0, dst, x[src])
    return aggregated
def aggregate_mean(x, edge_index):
    """Mean-aggregate neighbor features along directed edges.

    Divides the per-node sum of incoming messages (from aggregate_sum) by
    each node's in-degree. Nodes with no incoming edges keep a zero row:
    the degree is clamped to 1, avoiding division by zero.
    """
    # Only the destination row is needed here; the source indices are
    # consumed inside aggregate_sum (the old unpack left `source` unused).
    destination = edge_index[1]
    summed = aggregate_sum(x, edge_index)
    # Accumulate the in-degree in x's dtype so the division below does not
    # mix dtypes (the previous torch.zeros hard-coded the default float32,
    # which failed to match for e.g. float64 inputs).
    degree = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    degree.index_add_(0, destination, torch.ones_like(destination, dtype=x.dtype))
    # clamp_min(1) keeps isolated nodes at zero instead of NaN/inf.
    return summed / degree.clamp_min(1).unsqueeze(1)
# Sanity-check both aggregators on the toy graph defined above.
print('sum')
print(aggregate_sum(x, edge_index))
print('mean')
print(aggregate_mean(x, edge_index))
# Aggregation is row-wise over nodes, so the feature shape must be preserved.
assert aggregate_sum(x, edge_index).shape == x.shape
2.2. A message-passing layer
A learnable layer can combine each node’s old state with its aggregated neighbor state.
[ ]:
class MessagePassingLayer(nn.Module):
    """One round of message passing with separate self/neighbor transforms.

    Each node's new state is ReLU(W_self * x_v + W_nbr * mean(x_u for
    in-neighbors u)), using the mean aggregation defined above.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Independent linear maps for a node's own state and for the
        # aggregated neighborhood summary.
        self.self_linear = nn.Linear(in_features, out_features)
        self.neighbor_linear = nn.Linear(in_features, out_features)

    def forward(self, x, edge_index):
        # Mean-aggregate incoming neighbor features per node.
        aggregated = aggregate_mean(x, edge_index)
        combined = self.self_linear(x) + self.neighbor_linear(aggregated)
        return torch.relu(combined)
# Map the 2-dim node features to 5-dim with one message-passing step.
layer = MessagePassingLayer(2, 5).to(device)
new_x = layer(x, edge_index)
print(new_x.shape)
# 4 nodes in, 4 nodes out; only the feature dimension changes (2 -> 5).
assert new_x.shape == (4, 5)
2.3. Add self loops
Self loops let a node send its current representation to itself through the same edge aggregation path as neighbor messages.
[ ]:
def add_self_loops(edge_index, num_nodes):
    """Return edge_index with one (i, i) edge appended for every node i.

    The original edge columns are kept first; the num_nodes self-loop
    columns are concatenated after them.
    """
    node_ids = torch.arange(num_nodes, device=edge_index.device)
    # A self loop uses the same node as both source and destination.
    self_loops = node_ids.unsqueeze(0).repeat(2, 1)
    return torch.cat((edge_index, self_loops), dim=1)
# Append one self-loop per node to the toy graph's edge list.
edge_index_with_loops = add_self_loops(edge_index, x.shape[0])
print(edge_index_with_loops)
# Exactly one new edge column per node is appended after the originals.
assert edge_index_with_loops.shape[1] == edge_index.shape[1] + x.shape[0]