10. Serving API

The serving path is where training-time assumptions turn into production bugs. This section focuses on three things: preserving the preprocessing contract used during training, making batch behavior explicit, and keeping the inference boundary simple.
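The preprocess function below only converts types and shapes. Real pipelines usually also normalize features with statistics computed on the training set, and recomputing or hand-copying those statistics at serving time is a common source of skew. A minimal pure-Python sketch of one way to preserve that contract is to fit the statistics once, serialize them next to the model weights, and apply the identical transform in the serving process. `FeatureContract` and its methods are hypothetical names, and z-score normalization is an assumed example transform.

```python
import json
from dataclasses import asdict, dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class FeatureContract:
    """Hypothetical record of the preprocessing fitted at training time."""
    names: Tuple[str, ...]
    means: Tuple[float, ...]
    stds: Tuple[float, ...]

    @classmethod
    def fit(cls, names: Sequence[str], rows: Sequence[Sequence[float]]) -> 'FeatureContract':
        n = len(rows)
        means = tuple(sum(r[i] for r in rows) / n for i in range(len(names)))
        # Population standard deviation; a constant column falls back to 1.0
        # so apply() never divides by zero.
        stds = tuple(
            (sum((r[i] - means[i]) ** 2 for r in rows) / n) ** 0.5 or 1.0
            for i in range(len(names))
        )
        return cls(tuple(names), means, stds)

    def apply(self, row: Sequence[float]) -> List[float]:
        # Reject rows that do not match the training-time feature count.
        if len(row) != len(self.names):
            raise ValueError(f'expected {len(self.names)} features, got {len(row)}')
        return [(x - m) / s for x, m, s in zip(row, self.means, self.stds)]

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, text: str) -> 'FeatureContract':
        data = json.loads(text)
        return cls(tuple(data['names']), tuple(data['means']), tuple(data['stds']))


# Fit once on training rows, ship the JSON alongside the model weights,
# and reload it in the serving process.
train_rows = [[0.0, 2.0, 1.0], [2.0, 2.0, 3.0]]
contract = FeatureContract.fit(['f0', 'f1', 'f2'], train_rows)
served = FeatureContract.from_json(contract.to_json())
normalized = served.apply([3.0, 2.0, 2.0])
print(normalized)  # [2.0, 0.0, 0.0]
```

In a serving process built on this notebook's code, `served.apply` would run on each row before the values are passed to `torch.tensor`, so training and serving share one transform definition instead of two.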

[ ]:
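import torch
from torch import nn

# Use the GPU when available; every tensor below is created on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Toy classifier: 3 input features -> 2 classes. eval() switches off
# training-only behavior such as dropout and batch-norm statistic updates.
model = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 2)).to(device).eval()
N_FEATURES = 3

def preprocess(payload):
    """Convert a JSON-like payload into a (batch, N_FEATURES) float tensor."""
    features = torch.tensor(payload['features'], dtype=torch.float32, device=device)
    # A single example arrives as a 1-D list; promote it to a batch of one
    # so the model always receives the same 2-D shape.
    if features.ndim == 1:
        features = features.unsqueeze(0)
    # Fail fast on malformed input instead of deep inside the model.
    if features.ndim != 2 or features.shape[-1] != N_FEATURES:
        raise ValueError(f'expected shape (batch, {N_FEATURES}), got {tuple(features.shape)}')
    return features

def postprocess(logits):
    """Turn raw logits into JSON-serializable predictions and probabilities."""
    probabilities = torch.softmax(logits, dim=-1)
    predictions = probabilities.argmax(dim=-1)
    # Move results to the CPU and convert them to plain Python lists.
    return {
        'predictions': predictions.cpu().tolist(),
        'probabilities': probabilities.cpu().tolist(),
    }

def predict(payload):
    features = preprocess(payload)
    # inference_mode disables autograd tracking, saving memory and time.
    with torch.inference_mode():
        logits = model(features)
    return postprocess(logits)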
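predict runs whatever batch it receives in a single forward pass, so a very large request becomes one very large allocation. If the service must bound that, one way to make the limit explicit is to split each request into fixed-size chunks and merge the outputs. The sketch below is a hypothetical wrapper (`chunked_predict` and `max_batch_size` are illustrative names) shown with a stand-in predict function so it runs without a model.

```python
from typing import Callable, Dict, List


def chunked_predict(predict_fn: Callable[[dict], Dict[str, list]],
                    payload: dict, max_batch_size: int = 32) -> Dict[str, list]:
    """Hypothetical wrapper: run predict_fn over fixed-size chunks of a batch."""
    if max_batch_size < 1:
        raise ValueError('max_batch_size must be >= 1')
    rows = payload['features']
    # Treat a single flat example as a batch of one, matching preprocess.
    if rows and not isinstance(rows[0], (list, tuple)):
        rows = [rows]
    merged: Dict[str, List] = {'predictions': [], 'probabilities': []}
    for start in range(0, len(rows), max_batch_size):
        chunk = rows[start:start + max_batch_size]
        result = predict_fn({'features': chunk})
        # Outputs are per-row lists, so concatenation preserves row order.
        for key in merged:
            merged[key].extend(result[key])
    return merged


# Stand-in for predict (hypothetical) that records the chunk sizes it receives.
seen_sizes = []

def fake_predict(payload):
    rows = payload['features']
    seen_sizes.append(len(rows))
    return {'predictions': [int(sum(r) > 0) for r in rows],
            'probabilities': [[0.5, 0.5] for _ in rows]}

result = chunked_predict(fake_predict, {'features': [[1.0, 0.0, 0.0]] * 5}, max_batch_size=2)
print(seen_sizes)  # [2, 2, 1]
```

With the notebook's model, `chunked_predict(predict, payload, max_batch_size=...)` should return the same structure as `predict(payload)`. The chunk size is a deployment choice that trades per-request latency against peak memory.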

10.1. Single request and batch request

[ ]:
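single = {'features': [0.4, -0.2, 1.1]}                       # one example, 1-D list
batched = {'features': [[0.4, -0.2, 1.1], [1.0, 0.3, -0.4]]}  # two examples, 2-D list

single_out = predict(single)
batched_out = predict(batched)
print(single_out)
print(batched_out)

# Both payload shapes return one prediction per example.
assert len(single_out['predictions']) == 1
assert len(batched_out['predictions']) == 2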
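Counting outputs confirms the shape contract but not that batching leaves results unchanged. A stronger check is that each row's probabilities are the same whether it is sent alone or inside a batch, which catches, for example, a model accidentally left in training mode with batch normalization. The sketch below uses a hypothetical helper, `check_batch_consistency`, and a pure-Python stand-in model (`toy_predict` with made-up `WEIGHTS`) so it runs without PyTorch. The `atol` tolerance is an assumption that allows for small floating-point differences between batched and single-row kernels.

```python
import math
from typing import Callable, Dict, List


def check_batch_consistency(predict_fn: Callable[[dict], Dict[str, list]],
                            rows: List[List[float]], atol: float = 1e-5) -> bool:
    """Return True if batched probabilities match row-by-row probabilities."""
    batched = predict_fn({'features': rows})['probabilities']
    if len(batched) != len(rows):
        return False
    for row, batch_probs in zip(rows, batched):
        # Send the row alone as a 1-D payload, exactly as a single request would.
        single_probs = predict_fn({'features': row})['probabilities'][0]
        if len(single_probs) != len(batch_probs):
            return False
        if any(abs(a - b) > atol for a, b in zip(batch_probs, single_probs)):
            return False
    return True


# Hypothetical 2-class linear model standing in for the notebook's predict.
WEIGHTS = [[0.5, -1.0, 0.25], [-0.5, 1.0, 0.75]]

def toy_predict(payload):
    rows = payload['features']
    if rows and not isinstance(rows[0], (list, tuple)):
        rows = [rows]
    probs = []
    for row in rows:
        logits = [sum(w * x for w, x in zip(ws, row)) for ws in WEIGHTS]
        m = max(logits)  # subtract the max for a numerically stable softmax
        exps = [math.exp(z - m) for z in logits]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return {'predictions': [p.index(max(p)) for p in probs], 'probabilities': probs}

rows = [[0.4, -0.2, 1.1], [1.0, 0.3, -0.4]]
consistent = check_batch_consistency(toy_predict, rows)
print(consistent)  # True
```

Against the notebook's eval-mode model, `check_batch_consistency(predict, batched['features'])` would be expected to pass as well; on a GPU the tolerance may need loosening.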

10.2. API shape

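The same predict function can sit behind FastAPI, Flask, gRPC, or a batch job. What matters is that every serving path calls the same preprocess and postprocess functions rather than reimplementing them, so request handling cannot drift from the logic exercised in this notebook.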
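One way to keep the framework layer thin is to wrap predict in a single adapter that maps a raw request body to a status code and a response body, and then have each framework route call that adapter. The sketch below is hypothetical (`make_json_handler` and `echo_predict` are illustrative names) and uses only the standard library. Mapping `ValueError` to 422 is a design choice that mirrors the status FastAPI uses for request-validation failures.

```python
import json
from typing import Callable, Tuple


def make_json_handler(predict_fn: Callable[[dict], dict]) -> Callable[[str], Tuple[int, str]]:
    """Wrap predict_fn as (request_body) -> (status_code, response_body)."""
    def handle(body: str) -> Tuple[int, str]:
        try:
            payload = json.loads(body)
        except json.JSONDecodeError as exc:
            return 400, json.dumps({'error': f'invalid JSON: {exc.msg}'})
        if not isinstance(payload, dict) or 'features' not in payload:
            return 400, json.dumps({'error': "missing 'features'"})
        try:
            result = predict_fn(payload)
        except ValueError as exc:
            # Shape or value problems in a well-formed payload are client errors.
            return 422, json.dumps({'error': str(exc)})
        return 200, json.dumps(result)
    return handle


# Hypothetical stand-in for predict that enforces a 3-feature contract.
def echo_predict(payload):
    rows = payload['features']
    if rows and not isinstance(rows[0], list):
        rows = [rows]
    if any(len(r) != 3 for r in rows):
        raise ValueError('expected 3 features per row')
    return {'predictions': [0 for _ in rows], 'probabilities': [[1.0, 0.0] for _ in rows]}

handle = make_json_handler(echo_predict)
ok_status, ok_body = handle('{"features": [0.4, -0.2, 1.1]}')
bad_status, bad_body = handle('{"features": [1.0, 2.0]}')
print(ok_status, ok_body)
print(bad_status, bad_body)
```

A FastAPI or Flask route, a gRPC servicer, or a batch loop would then only read the body, call `handle`, and write the status and response. Parsing, validation, and error mapping stay in one tested place.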