1. Graph Data
Graph models become confusing when the data representation is treated as magic. This section shows that node features, edge lists, labels, and graph batches are all just structured tensors.
1.1. A graph as tensors
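edge_index has shape [2, num_edges]. Row 0 stores source nodes and row 1 stores destination nodes, so column k describes the directed edge from edge_index[0, k] to edge_index[1, k]. An undirected edge is stored as two directed edges, one in each direction, as in the example below.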
[ ]:
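import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One row per node, one column per feature: shape [num_nodes, num_features].
node_features = torch.tensor([
    [1.0, 0.0, 0.2],
    [0.9, 0.1, 0.1],
    [0.0, 1.0, 0.3],
    [0.1, 0.8, 0.4],
    [0.2, 0.2, 1.0],
], dtype=torch.float32)

# Path graph 0 - 1 - 2 - 3 - 4; each undirected edge appears in both directions.
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 4],  # source nodes
    [1, 0, 2, 1, 3, 2, 4, 3],  # destination nodes
], dtype=torch.long)

# One class label per node.
node_labels = torch.tensor([0, 0, 1, 1, 2], dtype=torch.long)

node_features = node_features.to(device)
edge_index = edge_index.to(device)
node_labels = node_labels.to(device)

print('nodes:', node_features.shape)
print('edges:', edge_index.shape)
assert edge_index.shape[0] == 2
# Every node id in edge_index must refer to an existing row of node_features.
assert int(edge_index.max()) < node_features.shape[0]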
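Because edge_index is an ordinary integer tensor, common graph queries reduce to indexing and counting. The following is a minimal sketch on the same toy graph; the names in_degree and neighbors_of_2 are illustrative and not part of any library.

```python
import torch

# Same toy path graph: 0 - 1 - 2 - 3 - 4, each undirected edge stored twice.
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 4],  # row 0: source nodes
    [1, 0, 2, 1, 3, 2, 4, 3],  # row 1: destination nodes
], dtype=torch.long)
num_nodes = 5

source, destination = edge_index

# In-degree: count how often each node appears as a destination.
in_degree = torch.bincount(destination, minlength=num_nodes)

# Nodes that send an edge into node 2: sources of edges whose destination is 2.
neighbors_of_2 = source[destination == 2]

print('in-degree:', in_degree.tolist())                 # [1, 2, 2, 2, 1]
print('neighbors of node 2:', neighbors_of_2.tolist())  # [1, 3]
```

The end nodes 0 and 4 have a single neighbor, which matches the path layout of the graph.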
1.2. Edge list to adjacency
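Dense adjacency matrices are easy to inspect, but they store num_nodes * num_nodes entries no matter how many edges exist. An edge list stores only 2 * num_edges entries, so it is usually cheaper for sparse graphs.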
[ ]:
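def to_dense_adjacency(edge_index, num_nodes):
    """Return a [num_nodes, num_nodes] matrix with 1.0 wherever an edge exists."""
    adjacency = torch.zeros((num_nodes, num_nodes), dtype=torch.float32, device=edge_index.device)
    source, destination = edge_index
    # Convention: row = destination, column = source, so adjacency @ x
    # sums the features of each node's incoming neighbors.
    adjacency[destination, source] = 1.0
    return adjacency


adjacency = to_dense_adjacency(edge_index, num_nodes=node_features.shape[0])
print(adjacency.cpu())
# Holds only when edge_index contains no duplicate edges.
assert adjacency.sum().item() == edge_index.shape[1]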
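The conversion also works in the other direction, which makes the storage trade-off concrete. The sketch below assumes the same row = destination, column = source convention as to_dense_adjacency and rebuilds the toy graph so that it runs on its own; recovered, same_edges, dense_entries, and edge_list_entries are illustrative names.

```python
import torch

edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 4],
    [1, 0, 2, 1, 3, 2, 4, 3],
], dtype=torch.long)
num_nodes = 5

# Build the dense matrix with adjacency[destination, source] = 1.
adjacency = torch.zeros((num_nodes, num_nodes))
adjacency[edge_index[1], edge_index[0]] = 1.0

# nonzero() yields one (row, column) pair per edge, i.e. (destination, source).
destination, source = adjacency.nonzero(as_tuple=True)
recovered = torch.stack([source, destination])

# Edge order may differ, so compare the edge sets rather than the raw tensors.
original_edges = set(map(tuple, edge_index.t().tolist()))
recovered_edges = set(map(tuple, recovered.t().tolist()))
same_edges = original_edges == recovered_edges

# Storage: the dense matrix holds num_nodes**2 entries, the edge list 2 * num_edges.
dense_entries = adjacency.numel()
edge_list_entries = edge_index.numel()
print(same_edges, dense_entries, edge_list_entries)  # True 25 16
```

On five nodes the difference is small, but the dense size grows quadratically with the number of nodes while the edge list grows only with the number of edges.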
1.3. Batch multiple graphs
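To batch graphs, concatenate the node features, shift each graph's node ids by the number of nodes that come before it, concatenate the edges, and keep a batch vector that records which graph each node came from. The first graph has offset 0, so its node ids are unchanged.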
[ ]:
graph_a = {
    'x': torch.randn(3, 4),                                    # 3 nodes, 4 features
    'edge_index': torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    'y': torch.tensor([0]),                                    # one label per graph
}
graph_b = {
    'x': torch.randn(2, 4),                                    # 2 nodes, 4 features
    'edge_index': torch.tensor([[0, 1], [1, 0]]),
    'y': torch.tensor([1]),
}


def batch_graphs(graphs):
    """Merge several graphs into one large graph with disconnected components."""
    xs = []
    edge_indices = []
    batch = []
    labels = []
    node_offset = 0
    for graph_id, graph in enumerate(graphs):
        num_nodes = graph['x'].shape[0]
        xs.append(graph['x'])
        # Shift node ids so they index into the concatenated feature matrix.
        edge_indices.append(graph['edge_index'] + node_offset)
        # Record the owning graph for every node of this graph.
        batch.append(torch.full((num_nodes,), graph_id, dtype=torch.long))
        labels.append(graph['y'])
        node_offset += num_nodes
    return {
        'x': torch.cat(xs),
        'edge_index': torch.cat(edge_indices, dim=1),
        'batch': torch.cat(batch),
        'y': torch.cat(labels),
    }


batched = batch_graphs([graph_a, graph_b])
print({key: tuple(value.shape) for key, value in batched.items()})
assert batched['batch'].tolist() == [0, 0, 0, 1, 1]
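The batch vector is what lets a model turn per-node values back into per-graph values. As one possible illustration, the sketch below computes a mean over the nodes of each graph with index_add_. The feature values are made up so that the result is easy to check by hand, and graph_means is an illustrative name.

```python
import torch

# A batched pair of graphs: 3 nodes from graph 0 followed by 2 nodes from graph 1.
x = torch.tensor([
    [1.0, 0.0],
    [3.0, 2.0],
    [5.0, 4.0],
    [2.0, 2.0],
    [4.0, 6.0],
])
batch = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long)
num_graphs = int(batch.max().item()) + 1

# Sum node features per graph: row i of x is added to row batch[i] of sums.
sums = torch.zeros((num_graphs, x.shape[1]), dtype=x.dtype)
sums.index_add_(0, batch, x)

# Divide by the number of nodes in each graph to get the mean.
counts = torch.bincount(batch, minlength=num_graphs).unsqueeze(1).to(x.dtype)
graph_means = sums / counts

print(graph_means.tolist())  # [[3.0, 2.0], [3.0, 4.0]]
```

The output has one row per graph, which is the shape needed to compare against the graph-level labels in y.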