6. PyTorch Geometric
Domain libraries such as PyTorch Geometric (PyG) help most once you understand the tensor-level picture. This section focuses on the boundary between the pure PyTorch representation and what a graph library automates.
6.1. Dependency boundary
Keep optional domain libraries out of the critical path until you need them. The check below uses importlib.util.find_spec to detect PyG without importing it, so the notebook still runs in the lightweight checker, and the later cells show which code path runs when PyG is installed.
[ ]:
import importlib.util
import torch
from torch import nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
has_pyg = importlib.util.find_spec('torch_geometric') is not None
print('torch_geometric installed:', has_pyg)
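The same detection step can be wrapped in a small helper that returns the module when it is available. This is a minimal sketch using only the standard library; `optional_import` is a hypothetical name chosen for illustration, not part of PyTorch or PyG.

```python
import importlib
import importlib.util


def optional_import(module_name):
    """Return the imported top-level module, or None if it is not installed.

    Hypothetical helper for illustration; not a PyTorch or PyG API.
    """
    # find_spec looks a top-level module up without importing it.
    if importlib.util.find_spec(module_name) is None:
        return None
    return importlib.import_module(module_name)


# A standard-library module is always found; a made-up name is not.
json_module = optional_import('json')
missing_module = optional_import('surely_not_an_installed_module_xyz')
print(json_module is not None, missing_module is None)
```

Code that depends on the optional library can then branch on `module is None`, just as the cells in this section branch on `has_pyg`.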
6.2. The common data shape
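PyG also uses x for node features (shape [num_nodes, num_features]) and edge_index for directed edges (shape [2, num_edges], with source nodes in the first row and target nodes in the second). An undirected edge is stored as two directed edges, as in the example below. If PyG is unavailable, these tensors still work with the pure-PyTorch layers in this book.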
[ ]:
x = torch.tensor([[1.0, 0.0], [0.8, 0.1], [0.0, 1.0]], device=device)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long, device=device)
y = torch.tensor([0, 0, 1], dtype=torch.long, device=device)
if has_pyg:
    from torch_geometric.data import Data
    data = Data(x=x, edge_index=edge_index, y=y)
    print(data)
    assert data.num_nodes == 3
else:
    data = {'x': x, 'edge_index': edge_index, 'y': y}
    print({key: tuple(value.shape) for key, value in data.items()})
    assert data['edge_index'].shape[0] == 2
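Because edge_index stores directed edges, an undirected graph needs every edge listed in both directions. The sketch below builds such a tensor from a one-direction edge list in pure PyTorch. The input layout (one column per undirected edge) is an assumption made for this example.

```python
import torch

# Undirected edges 0-1 and 1-2, one column per edge (assumed input layout).
undirected = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)

# Append the reversed copy of every edge: flip(0) swaps the source and target rows.
edge_index = torch.cat([undirected, undirected.flip(0)], dim=1)

print(edge_index)
print('num_edges:', edge_index.shape[1])
```

The result contains the same four directed edges as the cell above, in a different column order. Column order does not affect sum or mean aggregation.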
6.3. Layer boundary
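If PyG is installed, you can swap this pure PyTorch message-passing layer for torch_geometric.nn.GCNConv, SAGEConv, GATConv, and many others. These layers accept the same (x, edge_index) call, but they are not numerically equivalent to the fallback: the fallback averages incoming neighbor features, adds the node's own features, and applies a linear map followed by ReLU, whereas GCNConv, for example, uses symmetric degree normalization and applies no activation.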
[ ]:
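class FallbackGraphLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x, edge_index):
        source, destination = edge_index
        # Sum the features of each node's incoming neighbors.
        out = torch.zeros_like(x)
        out.index_add_(0, destination, x[source])
        # Count incoming edges per node; match x's dtype so index_add_ accepts the source.
        degree = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
        degree.index_add_(0, destination, torch.ones_like(destination, dtype=x.dtype))
        # Mean over neighbors; the clamp avoids division by zero for isolated nodes.
        out = out / degree.clamp_min(1).unsqueeze(1)
        # Add the node's own features, then project and apply ReLU.
        return torch.relu(self.linear(out + x))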
if has_pyg:
    from torch_geometric.nn import GCNConv
    layer = GCNConv(2, 4).to(device)
else:
    layer = FallbackGraphLayer(2, 4).to(device)
node_embeddings = layer(x, edge_index)
print(node_embeddings.shape)
assert node_embeddings.shape == (3, 4)
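To see what the index_add_ calls in the fallback layer compute, it can help to compare them with a dense adjacency matrix. The sketch below repeats the aggregation step on the toy graph and checks it against a matrix product. It assumes edge_index contains no duplicate edges, since the dense assignment would collapse duplicates.

```python
import torch

x = torch.tensor([[1.0, 0.0], [0.8, 0.1], [0.0, 1.0]])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
source, destination = edge_index

# Sparse route: scatter-add neighbor features, then divide by in-degree.
summed = torch.zeros_like(x).index_add_(0, destination, x[source])
degree = torch.zeros(x.shape[0]).index_add_(0, destination, torch.ones(destination.shape[0]))
sparse_mean = summed / degree.clamp_min(1).unsqueeze(1)

# Dense route: row i of the adjacency matrix selects the in-neighbors of node i.
adjacency = torch.zeros(x.shape[0], x.shape[0])
adjacency[destination, source] = 1.0
dense_mean = (adjacency @ x) / adjacency.sum(dim=1, keepdim=True).clamp_min(1)

print(sparse_mean)
assert torch.allclose(sparse_mean, dense_mean)
```

The dense version needs memory quadratic in the number of nodes, which is one reason graph layers work from edge_index instead.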
6.4. When to install it
Install PyG when you need production graph datasets, sampled neighbor loaders, heterogeneous graphs, temporal graphs, or established layers. Keep a tensor-level mental model so Data, Batch, and MessagePassing are not magic.
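As one example of that mental model, batching several graphs amounts to concatenating their node features, shifting the node indices in edge_index, and recording which graph each node came from. The sketch below does this by hand in pure PyTorch. It illustrates the idea rather than PyG's actual implementation, and the variable names are chosen for this example.

```python
import torch

# Two small graphs, each given as (node features, edge_index).
x_a = torch.tensor([[1.0], [2.0]])
edges_a = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
x_b = torch.tensor([[3.0], [4.0], [5.0]])
edges_b = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)

# Stack node features; shift the second graph's node ids past the first graph's nodes.
x = torch.cat([x_a, x_b], dim=0)
edge_index = torch.cat([edges_a, edges_b + x_a.shape[0]], dim=1)

# batch[i] records which graph node i came from.
batch = torch.cat([
    torch.zeros(x_a.shape[0], dtype=torch.long),
    torch.ones(x_b.shape[0], dtype=torch.long),
])

# Per-graph sum pooling with the same index_add_ pattern used for message passing.
graph_sums = torch.zeros(2, x.shape[1]).index_add_(0, batch, x)
print(edge_index)
print(batch.tolist(), graph_sums.squeeze(1).tolist())
```

Since no edge crosses between the two graphs, a message-passing layer applied to the combined tensors treats each graph independently.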