4. Graph Classification
This matters because some graph problems require one prediction per graph rather than per node. Focus on how node embeddings are pooled into one graph representation.
[ ]:
# Reproducible setup for the graph-classification demo.
import os
import random
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
# Seed both PyTorch and Python RNGs so results repeat across runs.
torch.manual_seed(29)
random.seed(29)
# Prefer GPU when available; the model and every batch are moved here below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# CI hook: when PYTORCH_INTRO_CHECK_MODE=1, training is shortened (see `epochs`).
check_mode = os.getenv('PYTORCH_INTRO_CHECK_MODE') == '1'
4.1. A tiny graph dataset
Class 0 graphs are chains. Class 1 graphs are stars. Node features include degree and a constant feature.
[ ]:
def undirected(edges):
    """Build a (2, 2*E) long edge-index tensor with both directions per edge.

    `edges` is an iterable of (source, destination) node-id pairs; each pair
    is emitted twice, once per direction, so message passing is symmetric.
    """
    pairs = []
    for source, destination in edges:
        pairs.append((source, destination))
        pairs.append((destination, source))
    # view(-1, 2) keeps the empty case well-formed: torch.tensor([]) is 1-D,
    # so without it an empty edge list would not transpose to shape (2, 0)
    # and downstream code indexing edge_index[1] would fail.
    return torch.tensor(pairs, dtype=torch.long).view(-1, 2).t().contiguous()
def make_chain(num_nodes):
    """Return a chain (path) graph as a dict: features `x`, `edge_index`, label 0.

    Node features are [normalized degree, constant 1]; degree is normalized
    by the maximum possible degree in a graph of this size.
    """
    chain_edges = [(node, node + 1) for node in range(num_nodes - 1)]
    edge_index = undirected(chain_edges)
    # Count each node's degree from the destination column of edge_index.
    degree = torch.zeros(num_nodes, 1)
    degree.index_add_(0, edge_index[1], torch.ones(edge_index.shape[1], 1))
    normalized_degree = degree / max(1, num_nodes - 1)
    features = torch.cat([normalized_degree, torch.ones(num_nodes, 1)], dim=1)
    return {'x': features, 'edge_index': edge_index, 'y': torch.tensor(0, dtype=torch.long)}
def make_star(num_nodes):
    """Return a star graph (node 0 is the hub) as a dict with label 1.

    Node features are [normalized degree, constant 1]; degree is normalized
    by the maximum possible degree in a graph of this size.
    """
    spokes = [(0, leaf) for leaf in range(1, num_nodes)]
    edge_index = undirected(spokes)
    # Count each node's degree from the destination column of edge_index.
    degree = torch.zeros(num_nodes, 1)
    degree.index_add_(0, edge_index[1], torch.ones(edge_index.shape[1], 1))
    normalized_degree = degree / max(1, num_nodes - 1)
    features = torch.cat([normalized_degree, torch.ones(num_nodes, 1)], dim=1)
    return {'x': features, 'edge_index': edge_index, 'y': torch.tensor(1, dtype=torch.long)}
# Twelve graphs total: for each size 5..10, one chain (class 0) then one star (class 1).
graphs = [
    builder(size)
    for size in range(5, 11)
    for builder in (make_chain, make_star)
]
class GraphDataset(Dataset):
    """Thin Dataset wrapper over the module-level `graphs` list."""

    def __len__(self):
        # Size mirrors the prebuilt list of graph dicts.
        return len(graphs)

    def __getitem__(self, position):
        # Graphs are returned as-is; batching is done by collate_graphs.
        return graphs[position]
def collate_graphs(batch):
    """Merge a list of graph dicts into one batched (disjoint-union) graph.

    Node indices in each graph's edge_index are shifted by the number of
    nodes that precede it, so all graphs live in one shared index space.
    Returns (node features, edge_index, per-node graph ids, stacked labels).
    """
    node_features = []
    shifted_edges = []
    graph_assignments = []
    labels = []
    node_offset = 0
    for graph_id, graph in enumerate(batch):
        num_nodes = graph['x'].shape[0]
        node_features.append(graph['x'])
        # Shift node ids so this graph's edges point into the batched space.
        shifted_edges.append(graph['edge_index'] + node_offset)
        graph_assignments.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
        labels.append(graph['y'])
        node_offset += num_nodes
    return (
        torch.cat(node_features),
        torch.cat(shifted_edges, dim=1),
        torch.cat(graph_assignments),
        torch.stack(labels),
    )
loader = DataLoader(GraphDataset(), batch_size=4, shuffle=True, collate_fn=collate_graphs)
4.2. Graph classifier
First update node features with message passing. Then average node embeddings per graph with a readout.
[ ]:
def mean_aggregate(x, edge_index):
    """Mean of each node's in-neighbor features.

    x: (num_nodes, feature_dim) float tensor.
    edge_index: (2, num_edges) long tensor of (source, destination) pairs;
    features flow from source to destination.
    Returns a tensor shaped like x; nodes with no incoming edges get zeros
    (clamp_min(1) avoids division by zero rather than producing NaNs).
    """
    source, destination = edge_index
    out = torch.zeros_like(x)
    out.index_add_(0, destination, x[source])
    # Match x's dtype (not just device): index_add_ requires matching dtypes,
    # so a plain float32 accumulator would fail for e.g. float64 inputs.
    degree = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    degree.index_add_(0, destination, torch.ones_like(destination, dtype=x.dtype))
    return out / degree.clamp_min(1).unsqueeze(1)
def global_mean_pool(x, batch_ids, num_graphs):
    """Average node embeddings per graph (the readout step).

    x: (num_nodes, dim) float tensor; batch_ids: (num_nodes,) long tensor
    assigning each node to a graph in [0, num_graphs).
    Returns (num_graphs, dim); a graph with no nodes gets a zero row
    (clamp_min(1) avoids dividing by zero).
    """
    # Allocate accumulators in x's dtype (not just device): index_add_
    # requires matching dtypes, so float32 buffers would fail for e.g.
    # float64 inputs.
    out = torch.zeros(num_graphs, x.shape[1], dtype=x.dtype, device=x.device)
    out.index_add_(0, batch_ids, x)
    counts = torch.zeros(num_graphs, dtype=x.dtype, device=x.device)
    counts.index_add_(0, batch_ids, torch.ones_like(batch_ids, dtype=x.dtype))
    return out / counts.clamp_min(1).unsqueeze(1)
class GraphClassifier(nn.Module):
    """One round of message passing, mean readout, then a 2-class linear head."""

    def __init__(self):
        super().__init__()
        # Separate transforms for a node's own features and its neighborhood mean.
        self.self_linear = nn.Linear(2, 16)
        self.neighbor_linear = nn.Linear(2, 16)
        self.classifier = nn.Sequential(nn.ReLU(), nn.Linear(16, 2))

    def forward(self, x, edge_index, batch_ids):
        # Message passing: combine self features with averaged neighbor features.
        aggregated = mean_aggregate(x, edge_index)
        combined = self.self_linear(x) + self.neighbor_linear(aggregated)
        node_embeddings = torch.relu(combined)
        # Readout: one embedding per graph, then classify.
        num_graphs = int(batch_ids.max().item()) + 1
        graph_embeddings = global_mean_pool(node_embeddings, batch_ids, num_graphs)
        return self.classifier(graph_embeddings)
# Model, optimizer, and loss for the 2-class graph-classification task.
model = GraphClassifier().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.03)
criterion = nn.CrossEntropyLoss()
# Short run under check mode (CI hook set at the top of the notebook).
epochs = 5 if check_mode else 80
for epoch in range(epochs):
    correct = 0
    total = 0
    for x, edge_index, batch_ids, labels in loader:
        # Move the batched graph and its labels to the training device.
        x = x.to(device)
        edge_index = edge_index.to(device)
        batch_ids = batch_ids.to(device)
        labels = labels.to(device)
        optimizer.zero_grad(set_to_none=True)
        logits = model(x, edge_index, batch_ids)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Accumulate training accuracy over the epoch.
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    print(epoch, round(correct / total, 3))
# Sanity check: one inference-mode forward pass yields one logit row per graph.
x, edge_index, batch_ids, labels = next(iter(loader))
with torch.inference_mode():
    logits = model(x.to(device), edge_index.to(device), batch_ids.to(device))
assert logits.shape == (labels.shape[0], 2)