4. Graph Classification

Some graph problems call for one prediction per whole graph rather than one per node. The key step to focus on here is how the per-node embeddings are pooled into a single graph-level representation.

[ ]:
import os
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Seed both PyTorch's and Python's RNGs so runs are reproducible.
torch.manual_seed(29)
random.seed(29)
# Prefer the GPU when one is available; otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# When this environment flag is '1', training is shortened to a quick smoke run.
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'

4.1. A tiny graph dataset

Class 0 graphs are chains; class 1 graphs are stars. Each node carries two features: its normalized degree and a constant feature.

[ ]:
def undirected(edges):
    """Return a (2, 2E) long edge-index tensor with both directions of each edge.

    Each (source, destination) pair is emitted twice, once per direction, so the
    graph behaves as undirected. An empty edge list yields a (2, 0) tensor so
    downstream indexing (edge_index[0] / edge_index[1]) stays valid.
    """
    if not edges:
        # torch.tensor([]) would produce shape (0,), which breaks edge_index[1].
        return torch.empty((2, 0), dtype=torch.long)
    both = []
    for source, destination in edges:
        both.append((source, destination))
        both.append((destination, source))
    return torch.tensor(both, dtype=torch.long).t().contiguous()

def make_chain(num_nodes):
    """Build a chain graph (label 0): nodes 0..num_nodes-1 linked in a line."""
    edge_index = undirected([(node, node + 1) for node in range(num_nodes - 1)])
    destinations = edge_index[1]
    # In-degree over the undirected edges equals the node degree.
    degree = torch.zeros(num_nodes, 1)
    degree.index_add_(0, destinations, torch.ones(destinations.shape[0], 1))
    # Normalize by the maximum possible degree so features stay in [0, 1].
    normalized_degree = degree / max(1, num_nodes - 1)
    features = torch.cat([normalized_degree, torch.ones(num_nodes, 1)], dim=1)
    return {'x': features, 'edge_index': edge_index, 'y': torch.tensor(0, dtype=torch.long)}

def make_star(num_nodes):
    """Build a star graph (label 1): node 0 connected to every other node."""
    edge_index = undirected([(0, leaf) for leaf in range(1, num_nodes)])
    destinations = edge_index[1]
    # In-degree over the undirected edges equals the node degree.
    degree = torch.zeros(num_nodes, 1)
    degree.index_add_(0, destinations, torch.ones(destinations.shape[0], 1))
    # Normalize by the maximum possible degree so features stay in [0, 1].
    normalized_degree = degree / max(1, num_nodes - 1)
    features = torch.cat([normalized_degree, torch.ones(num_nodes, 1)], dim=1)
    return {'x': features, 'edge_index': edge_index, 'y': torch.tensor(1, dtype=torch.long)}

# One chain (class 0) and one star (class 1) for every size from 5 to 10 nodes.
graphs = [
    builder(size)
    for size in range(5, 11)
    for builder in (make_chain, make_star)
]

class GraphDataset(Dataset):
    """Map-style dataset over a list of graph dicts.

    By default it serves the module-level ``graphs`` list (so ``GraphDataset()``
    keeps working as before); pass ``items`` to wrap any other sequence of
    graphs, e.g. for train/validation splits or tests.
    """

    def __init__(self, items=None):
        # Look up the module-level list only when no explicit data is given.
        self.items = graphs if items is None else items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]

def collate_graphs(batch):
    """Merge a list of graph dicts into one disjoint-union batch.

    Node features are concatenated, edge indices are shifted by the running
    node offset so they stay valid in the merged graph, and a batch-id vector
    records which graph each node belongs to. Labels are stacked into a 1-D
    tensor.
    """
    features, edges, membership, targets = [], [], [], []
    node_offset = 0
    for graph_id, graph in enumerate(batch):
        node_count = graph['x'].shape[0]
        features.append(graph['x'])
        edges.append(graph['edge_index'] + node_offset)
        membership.append(torch.full((node_count,), graph_id, dtype=torch.long))
        targets.append(graph['y'])
        node_offset += node_count
    x = torch.cat(features)
    edge_index = torch.cat(edges, dim=1)
    batch_ids = torch.cat(membership)
    labels = torch.stack(targets)
    return x, edge_index, batch_ids, labels

# Batch four graphs at a time; the custom collate builds their disjoint union.
loader = DataLoader(GraphDataset(), batch_size=4, shuffle=True, collate_fn=collate_graphs)

4.2. Graph classifier

First update node features with message passing. Then average node embeddings per graph with a readout.

[ ]:
def mean_aggregate(x, edge_index):
    """Average each node's incoming neighbor features.

    Nodes with no incoming edges keep a zero vector (the clamp below only
    protects the division).
    """
    sources, destinations = edge_index
    summed = torch.zeros_like(x)
    summed.index_add_(0, destinations, x[sources])
    in_degree = torch.zeros(x.shape[0], device=x.device)
    in_degree.index_add_(0, destinations, torch.ones_like(destinations, dtype=x.dtype))
    # Clamp to 1 so isolated nodes do not trigger a division by zero.
    return summed / in_degree.clamp_min(1).unsqueeze(1)

def global_mean_pool(x, batch_ids, num_graphs):
    """Readout: average node embeddings per graph into (num_graphs, dim)."""
    totals = torch.zeros(num_graphs, x.shape[1], device=x.device)
    totals.index_add_(0, batch_ids, x)
    node_counts = torch.zeros(num_graphs, device=x.device)
    node_counts.index_add_(0, batch_ids, torch.ones_like(batch_ids, dtype=x.dtype))
    # Clamp to 1 so a graph id with no nodes does not divide by zero.
    return totals / node_counts.clamp_min(1).unsqueeze(1)

class GraphClassifier(nn.Module):
    """One round of message passing, a mean readout, then a linear head."""

    def __init__(self):
        super().__init__()
        # Separate transforms for a node's own features and its neighborhood mean.
        self.self_linear = nn.Linear(2, 16)
        self.neighbor_linear = nn.Linear(2, 16)
        self.classifier = nn.Sequential(nn.ReLU(), nn.Linear(16, 2))

    def forward(self, x, edge_index, batch_ids):
        aggregated = mean_aggregate(x, edge_index)
        combined = self.self_linear(x) + self.neighbor_linear(aggregated)
        node_embeddings = torch.relu(combined)
        # batch_ids are contiguous 0..B-1, so the max id gives the batch size.
        num_graphs = int(batch_ids.max().item()) + 1
        pooled = global_mean_pool(node_embeddings, batch_ids, num_graphs)
        return self.classifier(pooled)

model = GraphClassifier().to(device)
# Fairly high learning rate; the task is tiny and converges quickly.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.03)
criterion = nn.CrossEntropyLoss()

# Short run under check mode, full training otherwise.
epochs = 5 if check_mode else 80
for epoch in range(epochs):
    hits, seen = 0, 0
    for x, edge_index, batch_ids, labels in loader:
        # Move the whole batch to the training device.
        x = x.to(device)
        edge_index = edge_index.to(device)
        batch_ids = batch_ids.to(device)
        labels = labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(x, edge_index, batch_ids)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Track training accuracy alongside the optimization step.
        predictions = logits.argmax(dim=1)
        hits += (predictions == labels).sum().item()
        seen += labels.numel()
    print(epoch, round(hits / seen, 3))

# Sanity check: one batch through the model yields one row of logits per graph.
x, edge_index, batch_ids, labels = next(iter(loader))
with torch.inference_mode():
    logits = model(x.to(device), edge_index.to(device), batch_ids.to(device))
assert logits.shape == (labels.shape[0], 2)