3. Graph Convolution

A graph convolution is just message passing with degree normalization plus a linear projection. This section focuses on the normalization and on how a node classifier is trained from the resulting features.

[ ]:
import os

import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(23)  # reproducible synthetic data and initial weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'  # shortens training for automated checks

3.1. Synthetic node classification graph

The two classes have node features drawn around different centers. Within each class the nodes form a ring, so most edges connect same-class nodes; two bridge edges (stored in both directions) connect the rings.

[ ]:
num_nodes_per_class = 10
class0 = torch.randn(num_nodes_per_class, 4) * 0.15 + torch.tensor([1.0, 0.0, 0.0, 0.0])
class1 = torch.randn(num_nodes_per_class, 4) * 0.15 + torch.tensor([0.0, 1.0, 0.0, 0.0])
x = torch.cat([class0, class1]).to(device)
y = torch.cat([
    torch.zeros(num_nodes_per_class, dtype=torch.long),
    torch.ones(num_nodes_per_class, dtype=torch.long),
]).to(device)

edges = []
for start in [0, num_nodes_per_class]:
    # Connect each node to the next one in its class, forming a ring (both directions).
    for node in range(start, start + num_nodes_per_class):
        neighbor = start + ((node - start + 1) % num_nodes_per_class)
        edges.append((node, neighbor))
        edges.append((neighbor, node))
edges += [(4, 14), (14, 4), (7, 17), (17, 7)]  # two cross-class bridges, both directions
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().to(device)
train_mask = torch.zeros(y.shape[0], dtype=torch.bool, device=device)
train_mask[[0, 1, 10, 11]] = True  # only two labeled nodes per class

print('x:', x.shape, 'edge_index:', edge_index.shape)
assert edge_index.shape[0] == 2
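
The bridge edges are the only cross-class connections. As a quick structural check (a sketch reusing only the tensors above), count how many directed edges cross the class boundary.

[ ]:
source, destination = edge_index
cross_class = (y[source] != y[destination]).sum().item()
print('directed edges:', edge_index.shape[1], 'cross-class:', cross_class)
assert cross_class == 4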

3.2. GCN layer

This layer first adds self-loops so each node also receives its own features, then applies symmetric degree normalization: each message is scaled by 1 / sqrt(degree[source] * degree[destination]). In this graph, self-loops give every ring node degree 3, so most messages are scaled by 1/3.

[ ]:
def add_self_loops(edge_index, num_nodes):
    # Append one (i, i) edge per node so every node also messages itself.
    nodes = torch.arange(num_nodes, device=edge_index.device)
    loops = torch.stack([nodes, nodes])
    return torch.cat([edge_index, loops], dim=1)

class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x, edge_index):
        num_nodes = x.shape[0]
        edge_index = add_self_loops(edge_index, num_nodes)
        source, destination = edge_index
        # In-degree per node, counting self-loops.
        degree = torch.zeros(num_nodes, device=x.device)
        degree.index_add_(0, destination, torch.ones_like(destination, dtype=x.dtype))
        # Per-edge symmetric normalization: 1 / sqrt(degree[source] * degree[destination]).
        norm = degree[source].clamp_min(1).rsqrt() * degree[destination].clamp_min(1).rsqrt()

        # Project once, gather per edge, scale, and sum messages into destination nodes.
        messages = self.linear(x)[source] * norm.unsqueeze(1)
        out = torch.zeros(num_nodes, messages.shape[1], device=x.device)
        out.index_add_(0, destination, messages)
        return out

class GCN(nn.Module):
    def __init__(self, in_features, hidden_features, num_classes):
        super().__init__()
        self.conv1 = GCNLayer(in_features, hidden_features)
        self.conv2 = GCNLayer(hidden_features, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)
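
In matrix form this layer computes D^-1/2 (A + I) D^-1/2 X W. As a cross-check (a sketch, not part of the model), the sparse index_add_ aggregation should match the dense product; check_layer below is a throwaway instance used only for the comparison.

[ ]:
check_layer = GCNLayer(in_features=4, out_features=3).to(device)
src, dst = add_self_loops(edge_index, x.shape[0])
adjacency = torch.zeros(x.shape[0], x.shape[0], device=device)
adjacency[dst, src] = 1.0  # A + I, with rows as destinations and columns as sources
inv_sqrt_degree = adjacency.sum(dim=1).clamp_min(1).rsqrt()
dense = inv_sqrt_degree.unsqueeze(1) * adjacency * inv_sqrt_degree.unsqueeze(0)
with torch.inference_mode():
    assert torch.allclose(check_layer(x, edge_index), dense @ check_layer.linear(x), atol=1e-5)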
[ ]:
model = GCN(in_features=4, hidden_features=12, num_classes=2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()

steps = 20 if check_mode else 200  # short run when only checking that the code executes
for step in range(steps):
    model.train()
    optimizer.zero_grad(set_to_none=True)
    logits = model(x, edge_index)
    loss = criterion(logits[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()
    if step in {0, steps - 1}:
        # Accuracy over all nodes, using the logits from this step's forward pass.
        accuracy = (logits.argmax(dim=1) == y).float().mean().item()
        print(step, round(loss.item(), 4), 'all-node accuracy', round(accuracy, 3))

with torch.inference_mode():
    logits = model(x, edge_index)
assert logits.shape == (y.shape[0], 2)
assert torch.isfinite(logits).all()
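
Only four nodes carry training labels, so accuracy on the remaining nodes shows how far label information propagates along the edges. This evaluation is a sketch reusing the masks defined above.

[ ]:
with torch.inference_mode():
    predictions = model(x, edge_index).argmax(dim=1)
held_out = ~train_mask
held_out_accuracy = (predictions[held_out] == y[held_out]).float().mean().item()
print('held-out accuracy:', round(held_out_accuracy, 3))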