8. Quantization
Quantization matters because deployment constraints are often about latency and memory, not just accuracy. This section takes the simplest quantization path first and shows how to compare quantized outputs to the float baseline without pretending they will be identical.
[ ]:
import copy
import torch
from torch import nn
torch.manual_seed(41)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
example = torch.randn(8, 16)
float_output = model(example)
8.1. Dynamic quantization
Only the linear layers are quantized here: their weights are converted to int8 ahead of time, while activations are quantized dynamically at runtime.
[ ]:
quantized_model = torch.ao.quantization.quantize_dynamic(copy.deepcopy(model), {nn.Linear}, dtype=torch.qint8)
quantized_output = quantized_model(example)
print(type(quantized_model[0]).__name__)
print(torch.max(torch.abs(float_output - quantized_output)).item())
assert quantized_output.shape == float_output.shape
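To see where the output gap comes from, a sketch like the following inspects the stored int8 weights of the first quantized linear layer. The `0.05` tolerance is an illustrative choice for this layer's weight scale, not a general guarantee.

```python
import copy
import torch
from torch import nn

torch.manual_seed(41)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
quantized_model = torch.ao.quantization.quantize_dynamic(
    copy.deepcopy(model), {nn.Linear}, dtype=torch.qint8
)

# A dynamically quantized Linear stores its weight as a quantized tensor:
# .weight() returns it, and .int_repr() exposes the raw int8 values.
qweight = quantized_model[0].weight()
print(qweight.int_repr().dtype)  # torch.int8

# Dequantizing recovers only an approximation of the original float weight;
# the residual is the rounding error introduced by the int8 grid.
weight_error = (qweight.dequantize() - model[0].weight).abs().max()
print(weight_error.item())
```

The per-weight rounding error is small here, so most of the output difference printed above comes from the runtime quantization of activations.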
8.2. Parameter footprint
Quantized weights usually shrink the serialized checkpoint (int8 weights need a quarter of the bytes of float32) while outputs stay numerically close.
[ ]:
import os
os.makedirs('./output', exist_ok=True)
float_path = './output/float-model.pt'
quantized_path = './output/quantized-model.pt'
torch.save(model.state_dict(), float_path)
torch.save(quantized_model.state_dict(), quantized_path)
print(float_path, quantized_path)
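To turn the footprint claim into a number, a sketch like this compares the on-disk sizes with `os.path.getsize` (written to a temporary directory here so it is self-contained). Note that for a model this tiny, fixed serialization overhead can mask much of the 4x weight saving; the ratio becomes convincing only at realistic model sizes.

```python
import copy
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(41)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
quantized_model = torch.ao.quantization.quantize_dynamic(
    copy.deepcopy(model), {nn.Linear}, dtype=torch.qint8
)

with tempfile.TemporaryDirectory() as tmp:
    float_path = os.path.join(tmp, 'float-model.pt')
    quantized_path = os.path.join(tmp, 'quantized-model.pt')
    torch.save(model.state_dict(), float_path)
    torch.save(quantized_model.state_dict(), quantized_path)
    # Compare serialized sizes; the quantized checkpoint stores int8
    # weights plus per-layer scale/zero-point metadata and float biases.
    float_bytes = os.path.getsize(float_path)
    quantized_bytes = os.path.getsize(quantized_path)
    print(float_bytes, quantized_bytes, quantized_bytes / float_bytes)
```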