1. Graph Data

Graph models become confusing when the data representation is treated as magic. This section shows that node features, edge lists, labels, and graph batches are all just structured tensors.

1.1. A graph as tensors

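edge_index has shape [2, num_edges]: each column is one directed edge, with row 0 storing the source node and row 1 storing the destination node. An undirected connection is therefore stored twice, once per direction (for example 0 → 1 and 1 → 0).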

[ ]:
import torch

# Run on the GPU when one is available; otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One row per node, one column per feature (5 nodes, 3 features).
node_features = torch.tensor(
    [
        [1.0, 0.0, 0.2],
        [0.9, 0.1, 0.1],
        [0.0, 1.0, 0.3],
        [0.1, 0.8, 0.4],
        [0.2, 0.2, 1.0],
    ],
    dtype=torch.float32,
).to(device)

# Column k is one directed edge: row 0 holds its source, row 1 its destination.
# Every connection is listed in both directions (e.g. 0 -> 1 and 1 -> 0).
edge_index = torch.tensor(
    [
        [0, 1, 1, 2, 2, 3, 3, 4],
        [1, 0, 2, 1, 3, 2, 4, 3],
    ],
    dtype=torch.long,
).to(device)

# One integer label per node.
node_labels = torch.tensor([0, 0, 1, 1, 2], dtype=torch.long).to(device)

print('nodes:', node_features.shape)
print('edges:', edge_index.shape)
# Sanity check: an edge list always has exactly two rows (source, destination).
assert edge_index.size(0) == 2

1.2. Edge list to adjacency

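Dense adjacency matrices are easy to inspect. Edge lists are usually cheaper for sparse graphs. In the function below, adjacency[i, j] = 1 means there is an edge from node j (source) to node i (destination), so each row lists the nodes that send an edge into that row's node.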

[ ]:
def to_dense_adjacency(edge_index, num_nodes):
    """Expand an edge list into a dense ``num_nodes x num_nodes`` matrix.

    Args:
        edge_index: Integer tensor with two rows; row 0 holds the source
            node of each edge and row 1 the destination node.
        num_nodes: Number of nodes, i.e. the side length of the result.

    Returns:
        A float32 tensor on the same device as ``edge_index`` in which
        entry ``[i, j]`` is 1.0 when ``edge_index`` contains an edge from
        node ``j`` to node ``i``, and 0.0 otherwise. Repeated edges still
        produce a single 1.0.
    """
    src_nodes, dst_nodes = edge_index
    # new_zeros inherits the device of edge_index; dtype is forced to float32.
    dense = edge_index.new_zeros((num_nodes, num_nodes), dtype=torch.float32)
    # Rows are indexed by destination, columns by source.
    dense[dst_nodes, src_nodes] = 1.0
    return dense

# Build the dense view of the example graph; move it to the CPU for printing.
adjacency = to_dense_adjacency(edge_index, num_nodes=node_features.size(0))
print(adjacency.cpu())
# With no repeated edges, every column of edge_index sets exactly one entry.
assert edge_index.size(1) == adjacency.sum().item()

1.3. Batch multiple graphs

Concatenate the node features, shift each graph’s node ids by the number of nodes that come before it (so the first graph is unchanged), concatenate the edges, and keep a batch vector recording which graph each node came from.

[ ]:
# Two tiny example graphs with 4 features per node and one graph-level label each.
# graph_a: 3 nodes in a chain 0 - 1 - 2 (each connection stored in both directions).
graph_a = dict(
    x=torch.randn(3, 4),
    edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    y=torch.tensor([0]),
)
# graph_b: 2 nodes joined by a single connection.
graph_b = dict(
    x=torch.randn(2, 4),
    edge_index=torch.tensor([[0, 1], [1, 0]]),
    y=torch.tensor([1]),
)

def batch_graphs(graphs):
    """Merge several graphs into one batch of concatenated tensors.

    The graphs are stacked into a single disconnected graph: node ids of
    each graph are shifted by the number of nodes that precede it, so the
    combined edge list still points at the right rows of the combined
    feature matrix.

    Args:
        graphs: Iterable of dicts, each with the keys
            ``'x'`` (node features; dim 0 indexes nodes),
            ``'edge_index'`` (integer node ids; edges run along dim 1) and
            ``'y'`` (label tensor).

    Returns:
        A dict with:
            ``'x'``: all node features concatenated along dim 0.
            ``'edge_index'``: all edges concatenated along dim 1, with
                node ids offset per graph.
            ``'batch'``: long tensor with one entry per node giving the
                position of its graph in ``graphs``; created on the same
                device as that graph's ``'x'``.
            ``'y'``: all labels concatenated along dim 0.
    """
    xs = []
    edge_indices = []
    batch = []
    labels = []
    node_offset = 0
    for graph_id, graph in enumerate(graphs):
        x = graph['x']
        num_nodes = x.shape[0]
        xs.append(x)
        # Shift node ids so they index into the concatenated feature matrix.
        edge_indices.append(graph['edge_index'] + node_offset)
        # Build the batch vector on the features' device so that 'batch'
        # and 'x' can be used together without a device mismatch.
        batch.append(
            torch.full((num_nodes,), graph_id, dtype=torch.long, device=x.device)
        )
        labels.append(graph['y'])
        node_offset += num_nodes
    return {
        'x': torch.cat(xs),
        'edge_index': torch.cat(edge_indices, dim=1),
        'batch': torch.cat(batch),
        'y': torch.cat(labels),
    }

# Batch the two example graphs and show the shape of every resulting tensor.
batched = batch_graphs([graph_a, graph_b])
print({name: tuple(tensor.size()) for name, tensor in batched.items()})
# The first three nodes come from graph 0 and the last two from graph 1.
assert [0, 0, 0, 1, 1] == batched['batch'].tolist()