6. PyTorch Geometric

Domain libraries help, but only once you understand the tensor-level picture. This section focuses on the boundary between the pure PyTorch representation and what a graph library automates.

6.1. Dependency boundary

Keep optional domain libraries out of the critical path until you need them. The cell below uses importlib.util.find_spec to detect PyG without importing it, so the notebook still runs in the lightweight checker and reports which code path will run when PyG is installed.

[ ]:
import importlib.util

import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
has_pyg = importlib.util.find_spec('torch_geometric') is not None
print('torch_geometric installed:', has_pyg)
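
The same check can be wrapped in a small helper when several optional packages are involved. This is a minimal sketch using only the standard library; the name module_available is hypothetical and not part of PyTorch or PyG. It also catches the exceptions that find_spec can raise, for example when a dotted name's parent package is missing.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Hypothetical helper: report whether a module can be found without importing it."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec imports parent packages for dotted names and raises
        # ModuleNotFoundError (an ImportError) if a parent is missing.
        return False


print("json:", module_available("json"))
print("missing:", module_available("surely_not_a_real_module_xyz"))
```

Checking availability instead of wrapping the import in try/except keeps the decision in one boolean that later cells can branch on.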

6.2. The common data shape

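PyG uses the same convention as this book: x holds node features with shape [num_nodes, num_features], and edge_index holds directed edges as a [2, num_edges] integer tensor whose first row is the source node and whose second row is the destination node. An undirected edge is stored as two directed edges, one in each direction. If PyG is unavailable, these tensors still work with the pure-PyTorch layers in this book.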

[ ]:
x = torch.tensor([[1.0, 0.0], [0.8, 0.1], [0.0, 1.0]], device=device)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long, device=device)
y = torch.tensor([0, 0, 1], dtype=torch.long, device=device)

if has_pyg:
    from torch_geometric.data import Data
    data = Data(x=x, edge_index=edge_index, y=y)
    print(data)
    assert data.num_nodes == 3
else:
    data = {'x': x, 'edge_index': edge_index, 'y': y}
    print({key: tuple(value.shape) for key, value in data.items()})
    assert data['edge_index'].shape[0] == 2
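
Because the convention is only a pair of tensors, it can be checked with plain PyTorch. The sketch below is illustrative: describe_graph is a hypothetical helper, not a PyG function. It validates the shape, dtype, and index range of edge_index, then reports each node's in-degree and whether every edge has a reverse edge.

```python
import torch


def describe_graph(x, edge_index):
    """Hypothetical helper: sanity-check (x, edge_index) and summarize the graph."""
    num_nodes = x.shape[0]
    if edge_index.dim() != 2 or edge_index.shape[0] != 2:
        raise ValueError("edge_index must have shape [2, num_edges]")
    if edge_index.dtype != torch.long:
        raise ValueError("edge_index must be an int64 (long) tensor")
    if edge_index.numel() > 0:
        if int(edge_index.min()) < 0 or int(edge_index.max()) >= num_nodes:
            raise ValueError("edge_index refers to a node that does not exist")

    source, destination = edge_index
    # In-degree: how many edges point at each node.
    in_degree = torch.bincount(destination, minlength=num_nodes)
    # The graph is undirected if every (s, d) edge has a matching (d, s) edge.
    pairs = set(zip(source.tolist(), destination.tolist()))
    undirected = all((d, s) in pairs for s, d in pairs)
    return in_degree, undirected


x = torch.tensor([[1.0, 0.0], [0.8, 0.1], [0.0, 1.0]])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
in_degree, undirected = describe_graph(x, edge_index)
print(in_degree.tolist(), undirected)
```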

6.3. Layer boundary

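If PyG is installed, you can swap this pure PyTorch message-passing layer for torch_geometric.nn.GCNConv, SAGEConv, GATConv, or many others. The call layer(x, edge_index) stays the same, but the outputs will not match the fallback numerically: GCNConv, for example, adds self-loops, uses symmetric degree normalization, and applies no activation of its own.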

[ ]:
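class FallbackGraphLayer(nn.Module):
    """Mean-aggregate incoming neighbor features, add the node's own features, then project."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, edge_index):
        source, destination = edge_index
        # Sum the features of every source node into its destination node.
        out = torch.zeros_like(x)
        out.index_add_(0, destination, x[source])
        # Count incoming edges per node; match x's dtype so index_add_ accepts the ones.
        degree = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
        degree.index_add_(0, destination, torch.ones_like(destination, dtype=x.dtype))
        # Divide by in-degree (at least 1, so isolated nodes do not divide by zero).
        out = out / degree.clamp_min(1).unsqueeze(1)
        # Add the node's own features as a self-connection, then project and activate.
        return torch.relu(self.linear(out + x))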

if has_pyg:
    from torch_geometric.nn import GCNConv
    layer = GCNConv(2, 4).to(device)
else:
    layer = FallbackGraphLayer(2, 4).to(device)

node_embeddings = layer(x, edge_index)
print(node_embeddings.shape)
assert node_embeddings.shape == (3, 4)
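
To see what the fallback layer's aggregation step computes, it helps to compare it with the dense matrix form. The sketch below is a standalone check on the same three-node graph, run on the CPU: it repeats the index_add_ mean aggregation and confirms that it equals a row-normalized adjacency matrix multiplied by x. The dense version is for illustration only, since it needs memory quadratic in the number of nodes.

```python
import torch

x = torch.tensor([[1.0, 0.0], [0.8, 0.1], [0.0, 1.0]])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
source, destination = edge_index
num_nodes = x.shape[0]

# Sparse form: scatter-add neighbor features, then divide by in-degree.
summed = torch.zeros_like(x)
summed.index_add_(0, destination, x[source])
degree = torch.zeros(num_nodes, dtype=x.dtype)
degree.index_add_(0, destination, torch.ones_like(destination, dtype=x.dtype))
mean_sparse = summed / degree.clamp_min(1).unsqueeze(1)

# Dense form: adjacency[i, j] = 1 when there is an edge j -> i.
adjacency = torch.zeros(num_nodes, num_nodes)
adjacency[destination, source] = 1.0
mean_dense = (adjacency @ x) / adjacency.sum(dim=1, keepdim=True).clamp_min(1)

print(mean_sparse)
assert torch.allclose(mean_sparse, mean_dense)
```

Node 1 has two neighbors (nodes 0 and 2), so its aggregated row is the average of their features; nodes 0 and 2 each receive node 1's features unchanged.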

6.4. When to install it

Install PyG when you need production graph datasets, sampled neighbor loaders, heterogeneous graphs, temporal graphs, or established layers. Keep a tensor-level mental model so Data, Batch, and MessagePassing are not magic.
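
As one example of that mental model, mini-batching several graphs amounts to building one large disconnected graph: node features are concatenated, each graph's edge indices are shifted by the number of nodes that came before it, and a batch vector records which graph each node belongs to. The sketch below shows the idea in plain PyTorch; collate_graphs is a hypothetical helper, not PyG's implementation, and it ignores labels and edge attributes.

```python
import torch


def collate_graphs(graphs):
    """Hypothetical helper: merge (x, edge_index) pairs into one disjoint graph."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift node ids so they point at this graph's rows in the merged x.
        edge_indices.append(edge_index + offset)
        # Record the owning graph for every node.
        batch.append(torch.full((x.shape[0],), graph_id, dtype=torch.long))
        offset += x.shape[0]
    return torch.cat(xs, dim=0), torch.cat(edge_indices, dim=1), torch.cat(batch)


graph_a = (torch.ones(3, 2), torch.tensor([[0, 1], [1, 2]]))  # 3 nodes, 2 edges
graph_b = (torch.zeros(2, 2), torch.tensor([[0], [1]]))       # 2 nodes, 1 edge
x, edge_index, batch = collate_graphs([graph_a, graph_b])
print(tuple(x.shape), edge_index.tolist(), batch.tolist())
```

Because no edge crosses between graphs, a message-passing layer applied to the merged tensors processes each graph independently, and the batch vector is what a pooling step would use to reduce node embeddings per graph.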