8. Quantization

Quantization matters because deployment constraints are often about latency and memory, not just accuracy. This section focuses on the simplest quantization path first, and on how to compare float and quantized outputs without pretending they will be identical.

[ ]:
import copy

import torch
from torch import nn

torch.manual_seed(41)

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
example = torch.randn(8, 16)
float_output = model(example)

8.1. Dynamic quantization

Dynamic quantization converts the weights of supported layers to int8 ahead of time, while activations are quantized on the fly at runtime. Here only the nn.Linear layers are converted.

[ ]:
quantized_model = torch.ao.quantization.quantize_dynamic(copy.deepcopy(model), {nn.Linear}, dtype=torch.qint8)
quantized_output = quantized_model(example)
print(type(quantized_model[0]).__name__)
print(torch.max(torch.abs(float_output - quantized_output)).item())
assert quantized_output.shape == float_output.shape
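
As a quick sanity check, you can inspect what a dynamically quantized linear layer actually stores. This is a standalone sketch (it rebuilds a small layer rather than reusing the model above): on the converted module, `weight()` returns a quantized tensor, and `int_repr()` exposes the raw stored integers.

[ ]:
import torch
from torch import nn

# Standalone sketch: inspect the weights of a dynamically quantized Linear.
float_linear = nn.Sequential(nn.Linear(16, 32)).eval()
q_linear = torch.ao.quantization.quantize_dynamic(float_linear, {nn.Linear}, dtype=torch.qint8)

w = q_linear[0].weight()  # a quantized tensor, not a plain float tensor
print(w.dtype)  # torch.qint8
print(w.int_repr().dtype)  # torch.int8: the raw stored integer values
print(w.int_repr().element_size())  # 1 byte per weight instead of 4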

8.2. Parameter footprint

Quantized weights usually shrink the serialized checkpoint even when outputs stay numerically close; comparing the saved file sizes makes this concrete.

[ ]:
import os

os.makedirs('./output', exist_ok=True)  # make sure the target directory exists before saving
float_path = './output/float-model.pt'
quantized_path = './output/quantized-model.pt'
torch.save(model.state_dict(), float_path)
torch.save(quantized_model.state_dict(), quantized_path)
print(os.path.getsize(float_path), os.path.getsize(quantized_path))
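
Reloading the quantized checkpoint works slightly differently from the float case: the state dict must be loaded into a model that has already been quantized with the same settings, not into a plain float model. A self-contained round-trip sketch, using a temporary directory instead of the paths above:

[ ]:
import os
import tempfile

import torch
from torch import nn

# Sketch: round-trip a dynamically quantized state dict through disk.
base = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
quantized = torch.ao.quantization.quantize_dynamic(base, {nn.Linear}, dtype=torch.qint8)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'quantized-model.pt')
    torch.save(quantized.state_dict(), path)

    # Rebuild a quantized skeleton first, then load the saved parameters into it.
    fresh = torch.ao.quantization.quantize_dynamic(
        nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval(),
        {nn.Linear},
        dtype=torch.qint8,
    )
    fresh.load_state_dict(torch.load(path, weights_only=False))

example = torch.randn(8, 16)
print(torch.allclose(quantized(example), fresh(example)))  # True: same weights, same outputs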