5. Heterogeneous Graphs

Many real graphs have several node types and several relation types. This section focuses on the bookkeeping that keeps those relations separate before aggregation.

5.1. Separate tensors by node type and edge type

Pure PyTorch keeps the structure visible: one feature matrix per node type and one edge list per relation type.

[ ]:
import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Three users and four items, each with 4-dimensional features.
user_x = torch.randn(3, 4, device=device)
item_x = torch.randn(4, 4, device=device)
# Each edge list stores two rows: [source_indices, destination_indices].
click_edge_index = torch.tensor([[0, 1, 2, 2], [0, 1, 2, 3]], dtype=torch.long, device=device)
similar_edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long, device=device)
print(user_x.shape, item_x.shape)
assert click_edge_index.shape[0] == 2
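
As the number of types grows, plain dictionaries keyed by type name keep this bookkeeping manageable. The keys below (`'user'`, `('user', 'click', 'item')`, and so on) are an illustrative convention, not a library API:

```python
import torch

# One feature matrix per node type.
node_x = {
    'user': torch.randn(3, 4),
    'item': torch.randn(4, 4),
}
# One edge list per relation, keyed by (source_type, relation, destination_type).
edge_index = {
    ('user', 'click', 'item'): torch.tensor([[0, 1, 2, 2], [0, 1, 2, 3]]),
    ('item', 'similar', 'item'): torch.tensor([[0, 1, 2], [1, 2, 3]]),
}

# Each relation records its source and destination node types,
# so index validity can be checked against the right feature matrix.
for (src_type, rel, dst_type), ei in edge_index.items():
    assert ei.shape[0] == 2
    assert ei[0].max() < node_x[src_type].shape[0]
    assert ei[1].max() < node_x[dst_type].shape[0]
```

Keying relations by a (source, relation, destination) triple also disambiguates relations that share a name but connect different node types.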

5.2. Relation-specific projections

Each relation can use its own linear layer before aggregation, so the model learns a relation-specific transformation rather than sharing parameters across edge types.

[ ]:
def aggregate_messages(source_x, edge_index, num_destination_nodes):
    """Mean-aggregate source-node features onto destination nodes."""
    source, destination = edge_index
    # Sum each source node's features into its destination's slot.
    out = torch.zeros(num_destination_nodes, source_x.shape[1], device=source_x.device)
    out.index_add_(0, destination, source_x[source])
    # Count incoming edges per destination; clamp_min(1) avoids
    # division by zero for nodes that receive no messages.
    counts = torch.zeros(num_destination_nodes, device=source_x.device)
    counts.index_add_(0, destination, torch.ones_like(destination, dtype=source_x.dtype))
    return out / counts.clamp_min(1).unsqueeze(1)

user_to_item = nn.Linear(4, 6).to(device)
item_to_item = nn.Linear(4, 6).to(device)

click_messages = aggregate_messages(user_to_item(user_x), click_edge_index, item_x.shape[0])
similar_messages = aggregate_messages(item_to_item(item_x), similar_edge_index, item_x.shape[0])
item_update = torch.relu(click_messages + similar_messages)
print(item_update.shape)
assert item_update.shape == (4, 6)
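
The same machinery updates user nodes if the click relation is reversed: swapping the two rows of the edge list turns user→item edges into item→user edges. A minimal sketch, self-contained for clarity; the reverse relation and the `item_to_user` projection are illustrative additions, not part of the cells above:

```python
import torch
from torch import nn

def aggregate_messages(source_x, edge_index, num_destination_nodes):
    """Mean-aggregate source-node features onto destination nodes."""
    source, destination = edge_index
    out = torch.zeros(num_destination_nodes, source_x.shape[1])
    out.index_add_(0, destination, source_x[source])
    counts = torch.zeros(num_destination_nodes)
    counts.index_add_(0, destination, torch.ones_like(destination, dtype=source_x.dtype))
    return out / counts.clamp_min(1).unsqueeze(1)

user_x = torch.randn(3, 4)
item_x = torch.randn(4, 4)
click_edge_index = torch.tensor([[0, 1, 2, 2], [0, 1, 2, 3]])

# flip(0) swaps the source and destination rows: items now send
# messages back to the users that clicked them.
clicked_by_edge_index = click_edge_index.flip(0)
item_to_user = nn.Linear(4, 6)
user_update = torch.relu(
    aggregate_messages(item_to_user(item_x), clicked_by_edge_index, user_x.shape[0])
)
print(user_update.shape)  # torch.Size([3, 6])
```

User 2 clicked items 2 and 3, so its update is the mean of those two projected item features; users 0 and 1 each receive a single message.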