4. Compiled Models
This matters because torch.compile can speed up inference, but only if you benchmark it honestly and handle environment-dependent fallbacks (e.g. older PyTorch builds or unsupported backends). Validate that compiled outputs match eager outputs first, and treat compilation as an optimization, not a correctness requirement.
[ ]:
import time
import torch
from torch import nn

# Use the GPU when one is available; everything below is device-agnostic.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Small MLP in eval mode: 256 -> 512 -> 128 -> 10 with GELU activations.
layers = [
    nn.Linear(256, 512),
    nn.GELU(),
    nn.Linear(512, 128),
    nn.GELU(),
    nn.Linear(128, 10),
]
model = nn.Sequential(*layers).to(device).eval()

# torch.compile may be unavailable in this environment; fall back to the
# eager model so the rest of the notebook still runs.
try:
    compiled_model = torch.compile(model)
except Exception as exc:
    compiled_model = model
    compile_error = repr(exc)
else:
    compile_error = None

# Fixed random batch reused by the comparison and benchmark cells.
x = torch.randn(256, 256, device=device)
print('compile setup error:', compile_error)
4.1. Warmup and compare outputs
[ ]:
# Run one eager and one compiled forward pass on the same batch and compare.
# torch.compile fails lazily (at the first call), so the compiled invocation
# is guarded and falls back to the eager model on any error.
with torch.inference_mode():
    eager_output = model(x)
    try:
        compiled_output = compiled_model(x)
    except Exception as exc:
        # Fall back to eager so the comparison and benchmark below still run.
        compiled_output = eager_output
        compiled_fn = model
        runtime_compile_error = repr(exc)
    else:
        compiled_fn = compiled_model
        runtime_compile_error = None

print('runtime compile error:', runtime_compile_error)
# Largest element-wise deviation between the eager and compiled outputs.
print((eager_output - compiled_output).abs().max().item())
assert compiled_output.shape == eager_output.shape
4.2. Micro-benchmark
[ ]:
def benchmark(fn, iterations=10, inputs=None, warmup=0):
    """Time `iterations` calls of `fn` and return total wall-clock seconds.

    Args:
        fn: Callable taking a single tensor argument.
        iterations: Number of timed calls.
        inputs: Tensor passed to `fn`. Defaults to the module-level batch
            `x`, matching the original behavior.
        warmup: Untimed calls made before the clock starts, so one-time
            costs (lazy compilation, autotuning) are not charged to `fn`.
            Defaults to 0, i.e. no warmup.

    Returns:
        float: Elapsed seconds for the timed loop only.
    """
    batch = x if inputs is None else inputs
    if warmup:
        with torch.inference_mode():
            for _ in range(warmup):
                fn(batch)
    # Synchronize so previously queued GPU work is not attributed to fn.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.inference_mode():
        for _ in range(iterations):
            fn(batch)
    # Flush the CUDA queue before stopping the clock; kernel launches are
    # asynchronous, so without this the timing would be meaningless on GPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start
# Measure steady-state cost of both variants. Warmup for the compiled path
# already happened in section 4.1, and compiled_fn is the eager model when
# compilation failed at call time, so both timings are always valid.
timings = {label: benchmark(fn) for label, fn in (('eager', model), ('compiled', compiled_fn))}
eager_time = timings['eager']
compiled_time = timings['compiled']
print(f'eager: {round(eager_time, 4)} compiled: {round(compiled_time, 4)}')