4. Compiled Models

torch.compile can speed up a model, but only if you benchmark it honestly and handle environment-dependent fallbacks. Validate outputs first, and treat compilation as an optimization, not a correctness requirement.

[ ]:
import time

import torch
from torch import nn

# Build a small MLP for the compile experiment and wrap it with
# torch.compile.  NOTE: torch.compile is lazy -- only the wrapper is
# created here; actual graph compilation happens on the first call, so
# this try/except catches setup-time failures only.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

layers = [
    nn.Linear(256, 512),
    nn.GELU(),
    nn.Linear(512, 128),
    nn.GELU(),
    nn.Linear(128, 10),
]
model = nn.Sequential(*layers).to(device).eval()

compile_error = None
try:
    compiled_model = torch.compile(model)
except Exception as exc:
    # Fall back to the eager model so the rest of the notebook still runs.
    compiled_model = model
    compile_error = repr(exc)

x = torch.randn(256, 256, device=device)
print('compile setup error:', compile_error)

4.1. Warmup and compare outputs

[ ]:
# Run both models once inside inference_mode.  This serves two purposes:
# it warms up the compiled model (graph compilation is triggered by the
# first call) and it produces outputs to compare against eager execution.
with torch.inference_mode():
    eager_output = model(x)
    runtime_compile_error = None
    try:
        compiled_output = compiled_model(x)
        compiled_fn = compiled_model
    except Exception as exc:
        # Backend compilation failed at call time; benchmark eager instead.
        compiled_output = eager_output
        compiled_fn = model
        runtime_compile_error = repr(exc)
print('runtime compile error:', runtime_compile_error)
# Largest elementwise deviation between eager and compiled outputs.
print(torch.max(torch.abs(eager_output - compiled_output)).item())
assert compiled_output.shape == eager_output.shape

4.2. Micro-benchmark

[ ]:
def benchmark(fn, iterations=10, inputs=None):
    """Time ``iterations`` forward passes of ``fn`` and return total seconds.

    Parameters
    ----------
    fn : callable
        Model or function to time; called as ``fn(inputs)``.
    iterations : int, optional
        Number of forward passes to time (default 10).
    inputs : torch.Tensor, optional
        Input batch passed to ``fn``.  Defaults to the module-level
        tensor ``x`` for backward compatibility with existing call sites.

    Returns
    -------
    float
        Total wall-clock seconds across all iterations (not per-iteration).
    """
    if inputs is None:
        # Preserve the original behavior of timing the notebook-global batch.
        inputs = x
    # CUDA kernels launch asynchronously; synchronize before and after
    # timing so perf_counter brackets the actual GPU work.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.inference_mode():
        for _ in range(iterations):
            fn(inputs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start

# Compare eager vs compiled timings.  With only 10 iterations this is a
# rough signal, not a rigorous benchmark.
timings = {label: benchmark(fn) for label, fn in (('eager', model), ('compiled', compiled_fn))}
eager_time = timings['eager']
compiled_time = timings['compiled']
print('eager:', round(eager_time, 4), 'compiled:', round(compiled_time, 4))