10. Serving API
This matters because the serving path is where training assumptions become production bugs. Focus on preserving preprocessing contracts, making batch behavior explicit, and keeping the inference boundary simple.
[ ]:
import torch
from torch import nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 2)).to(device).eval()
def preprocess(payload):
    features = torch.tensor(payload['features'], dtype=torch.float32, device=device)
    if features.ndim == 1:
        # promote a single example to a batch of one so the model always sees 2-D input
        features = features.unsqueeze(0)
    return features

def postprocess(logits):
    probabilities = torch.softmax(logits, dim=-1)
    predictions = probabilities.argmax(dim=-1)
    return {
        'predictions': predictions.cpu().tolist(),
        'probabilities': probabilities.cpu().tolist(),
    }

def predict(payload):
    features = preprocess(payload)
    with torch.inference_mode():
        logits = model(features)
    return postprocess(logits)
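One way to keep the preprocessing contract explicit is to validate payload shape before any tensor is built. The sketch below is an assumption, not part of the model code above: validate_payload and N_FEATURES are hypothetical names, and the checks are a minimal pure-Python example of what such a guard might look like.

```python
N_FEATURES = 3  # hypothetical constant: must match the model's input width

def validate_payload(payload):
    """Reject malformed payloads before they reach preprocess()."""
    features = payload.get('features')
    if features is None:
        raise ValueError("payload must contain a 'features' key")
    # accept either a single example (flat list) or a batch (list of lists)
    rows = [features] if features and not isinstance(features[0], list) else features
    for row in rows:
        if len(row) != N_FEATURES:
            raise ValueError(f'expected {N_FEATURES} features per example, got {len(row)}')
        if not all(isinstance(x, (int, float)) for x in row):
            raise ValueError('features must be numeric')
    return payload
```

Failing fast here turns a silent shape mismatch into a clear 4xx-style error at the serving boundary instead of a confusing tensor error deep inside the model.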
10.1. Single request and batch request
[ ]:
single = {'features': [0.4, -0.2, 1.1]}
batched = {'features': [[0.4, -0.2, 1.1], [1.0, 0.3, -0.4]]}
print(predict(single))
print(predict(batched))
assert len(predict(single)['predictions']) == 1
assert len(predict(batched)['predictions']) == 2
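Because predict already accepts batches, several queued single requests can be merged into one forward pass and the results split back out. The sketch below shows the idea in plain Python; merge_payloads and split_result are hypothetical helpers, and batch_predict is a stub standing in for the torch-backed predict defined earlier.

```python
def merge_payloads(payloads):
    """Stack the feature rows of many single-example payloads into one batch."""
    return {'features': [p['features'] for p in payloads]}

def split_result(result, n):
    """Split a batched result dict back into n per-request results."""
    return [
        {'predictions': [result['predictions'][i]],
         'probabilities': [result['probabilities'][i]]}
        for i in range(n)
    ]

# stub standing in for the torch-backed predict() defined earlier
def batch_predict(payload):
    n = len(payload['features'])
    return {'predictions': [0] * n, 'probabilities': [[0.5, 0.5]] * n}

queued = [{'features': [0.4, -0.2, 1.1]}, {'features': [1.0, 0.3, -0.4]}]
merged = merge_payloads(queued)
per_request = split_result(batch_predict(merged), len(queued))
```

A real micro-batching server would add a queue and a timeout, but the split/merge contract stays the same.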
10.2. API shape
The same predict function can sit behind FastAPI, Flask, gRPC, or a scheduled batch job; what matters is that every serving path reuses the exact preprocessing and postprocessing logic the model was trained against.
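To make that concrete, one transport-agnostic sketch is a handler that takes a JSON request body as bytes and returns a JSON response body as bytes, so any framework only supplies the transport. handle_request is a hypothetical name, and the stub predict here stands in for the torch-backed one defined earlier.

```python
import json

# stub standing in for the torch-backed predict() defined earlier
def predict(payload):
    features = payload['features']
    n = len(features) if isinstance(features[0], list) else 1
    return {'predictions': [0] * n, 'probabilities': [[0.5, 0.5]] * n}

def handle_request(body: bytes) -> bytes:
    """Decode JSON, run the shared predict path, encode JSON back."""
    payload = json.loads(body)
    result = predict(payload)
    return json.dumps(result).encode('utf-8')
```

A FastAPI route, a Flask view, or a gRPC method would each just pass the raw body to handle_request, so the preprocessing and postprocessing logic never forks across transports.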