3. Graph Convolution
A graph convolution is just message passing with degree normalization plus a learned projection. Focus on the normalization and on how a node classifier is trained from the resulting features.
[ ]:
import os
import torch
from torch import nn
import torch.nn.functional as F
# Fixed seed so the synthetic data and model initialization are reproducible.
torch.manual_seed(23)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# PYTORCH_INTRO_CHECK_MODE=1 switches the notebook to a short smoke-test run
# (the training cell uses fewer steps when check_mode is set).
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'
3.1. Synthetic node classification graph
Two communities have different node-feature centers. The graph has many same-class edges and a few bridge edges.
[ ]:
# Two Gaussian feature clusters, one per class, centered on different axes.
num_nodes_per_class = 10
class0 = torch.randn(num_nodes_per_class, 4) * 0.15 + torch.tensor([1.0, 0.0, 0.0, 0.0])
class1 = torch.randn(num_nodes_per_class, 4) * 0.15 + torch.tensor([0.0, 1.0, 0.0, 0.0])
x = torch.cat([class0, class1]).to(device)
y = torch.cat([
    torch.zeros(num_nodes_per_class, dtype=torch.long),
    torch.ones(num_nodes_per_class, dtype=torch.long),
]).to(device)
# Wire each community into a ring, storing every edge in both directions.
edges = []
for start in (0, num_nodes_per_class):
    for offset in range(num_nodes_per_class):
        node = start + offset
        neighbor = start + (offset + 1) % num_nodes_per_class
        edges.extend([(node, neighbor), (neighbor, node)])
# A handful of bridge edges connect the two communities.
edges += [(4, 14), (14, 4), (7, 17), (17, 7)]
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous().to(device)
# Only two labeled nodes per class are revealed to the training loss.
train_mask = torch.zeros(y.shape[0], dtype=torch.bool, device=device)
train_mask[[0, 1, 10, 11]] = True
print('x:', x.shape, 'edge_index:', edge_index.shape)
assert edge_index.shape[0] == 2
3.2. GCN layer
This layer adds a self-loop to every node, then applies symmetric degree normalization: each message is scaled by 1 / sqrt(degree[source] * degree[destination]).
[ ]:
def add_self_loops(edge_index, num_nodes):
    """Return edge_index with one (i, i) edge appended per node.

    edge_index is the usual 2 x E tensor (row 0 = sources, row 1 =
    destinations); the result is 2 x (E + num_nodes) on the same device.
    """
    idx = torch.arange(num_nodes, device=edge_index.device)
    self_loops = torch.stack((idx, idx), dim=0)
    return torch.cat((edge_index, self_loops), dim=1)
class GCNLayer(nn.Module):
    """One graph-convolution layer: linear projection followed by message
    passing with symmetric 1/sqrt(deg_src * deg_dst) normalization."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Shared projection applied to every node before messages are sent;
        # no bias, matching the standard GCN formulation.
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x, edge_index):
        n = x.shape[0]
        # Self-loops let each node retain a share of its own features.
        edge_index = add_self_loops(edge_index, n)
        src, dst = edge_index
        # Degree = number of incoming edges (self-loop included).
        deg = torch.zeros(n, device=x.device)
        deg.index_add_(0, dst, torch.ones_like(dst, dtype=x.dtype))
        # Symmetric normalization per edge; clamp guards against zero degree.
        scale = deg[src].clamp_min(1).rsqrt() * deg[dst].clamp_min(1).rsqrt()
        msg = self.linear(x)[src] * scale.unsqueeze(1)
        # Sum the scaled messages at each destination node.
        agg = torch.zeros(n, msg.shape[1], device=x.device)
        agg.index_add_(0, dst, msg)
        return agg
class GCN(nn.Module):
    """Two-layer GCN: a ReLU hidden layer, then a layer producing class scores."""

    def __init__(self, in_features, hidden_features, num_classes):
        super().__init__()
        self.conv1 = GCNLayer(in_features, hidden_features)
        self.conv2 = GCNLayer(hidden_features, num_classes)

    def forward(self, x, edge_index):
        hidden = F.relu(self.conv1(x, edge_index))
        logits = self.conv2(hidden, edge_index)
        return logits
[ ]:
# Build the model and optimization objects.
model = GCN(in_features=4, hidden_features=12, num_classes=2).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.05)
criterion = nn.CrossEntropyLoss()
# Short run under the automated check harness, full run otherwise.
steps = 20 if check_mode else 200
for step in range(steps):
    model.train()
    optimizer.zero_grad(set_to_none=True)
    logits = model(x, edge_index)
    # Loss is computed on the four labeled nodes only.
    loss = criterion(logits[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()
    if step == 0 or step == steps - 1:
        # Accuracy over ALL nodes (pre-step logits), to show label propagation.
        accuracy = (logits.argmax(dim=1) == y).float().mean().item()
        print(step, round(loss.item(), 4), 'all-node accuracy', round(accuracy, 3))
# Final sanity checks on a fresh forward pass with gradients disabled.
with torch.inference_mode():
    logits = model(x, edge_index)
assert logits.shape == (y.shape[0], 2)
assert torch.isfinite(logits).all()