5. Heterogeneous Graphs
Many real graphs mix several node types and several relation types. This section focuses on the bookkeeping that keeps those relations separate until aggregation.
5.1. Separate tensors by node type and edge type
Pure PyTorch keeps the structure visible: one feature matrix per node type and one edge list per relation type.
[ ]:
import torch
from torch import nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
user_x = torch.randn(3, 4, device=device)  # 3 users, 4 features each
item_x = torch.randn(4, 4, device=device)  # 4 items, 4 features each
# Convention: row 0 holds source ids, row 1 holds destination ids.
click_edge_index = torch.tensor([[0, 1, 2, 2], [0, 1, 2, 3]], dtype=torch.long, device=device)  # user -> item
similar_edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long, device=device)  # item -> item
print(user_x.shape, item_x.shape)
assert click_edge_index.shape[0] == 2
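With more than two relations, named variables stop scaling. A common alternative, sketched below with illustrative names (`node_features`, `edge_indices` are assumptions, not part of the code above), is to key tensors by node type and by `(source_type, relation, destination_type)` triples, which also makes it easy to sanity-check that edge ids stay inside each node-type tensor:

```python
import torch

# Hypothetical dictionary layout: one feature matrix per node type,
# one edge list per (source_type, relation, destination_type) triple.
node_features = {
    'user': torch.randn(3, 4),
    'item': torch.randn(4, 4),
}
edge_indices = {
    ('user', 'click', 'item'): torch.tensor([[0, 1, 2, 2], [0, 1, 2, 3]]),
    ('item', 'similar', 'item'): torch.tensor([[0, 1, 2], [1, 2, 3]]),
}

# Every source/destination id must index into the matching node-type tensor.
for (src_type, _, dst_type), edge_index in edge_indices.items():
    assert edge_index[0].max() < node_features[src_type].shape[0]
    assert edge_index[1].max() < node_features[dst_type].shape[0]
```

This keyed layout mirrors what dedicated heterogeneous-graph libraries store internally, while keeping everything as plain tensors.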
5.2. Relation-specific projections
Each relation gets its own linear projection before aggregation, so the model can transform user-to-item clicks and item-to-item similarity differently.
[ ]:
def aggregate_messages(source_x, edge_index, num_destination_nodes):
    # Gather each edge's source features, then mean-pool per destination node.
    source, destination = edge_index
    out = torch.zeros(num_destination_nodes, source_x.shape[1], device=source_x.device)
    out.index_add_(0, destination, source_x[source])
    counts = torch.zeros(num_destination_nodes, device=source_x.device)
    counts.index_add_(0, destination, torch.ones_like(destination, dtype=source_x.dtype))
    # clamp_min(1) avoids division by zero for destinations with no incoming edges
    return out / counts.clamp_min(1).unsqueeze(1)
user_to_item = nn.Linear(4, 6).to(device)
item_to_item = nn.Linear(4, 6).to(device)
click_messages = aggregate_messages(user_to_item(user_x), click_edge_index, item_x.shape[0])
similar_messages = aggregate_messages(item_to_item(item_x), similar_edge_index, item_x.shape[0])
item_update = torch.relu(click_messages + similar_messages)
print(item_update.shape)
assert item_update.shape == (4, 6)