10. Saving and Loading Models

[1]:
import os

from torchvision import datasets, models, transforms
import torch.optim as optim
import torch.nn as nn
from torchvision.transforms import *
from torch.utils.data import DataLoader
import torch
import numpy as np

def train(dataloader, model, criterion, optimizer, scheduler, num_epochs=20, device=None):
    """Train ``model`` for ``num_epochs`` epochs, printing per-epoch loss and accuracy.

    Each epoch is one pass over ``dataloader`` with one ``optimizer.step()`` per
    batch; ``scheduler.step()`` is called once at the end of every epoch, after
    that epoch's optimizer updates.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        model: module whose output has the class dimension at dim 1
            (predictions are taken with ``argmax`` over dim 1).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler bound to ``optimizer``.
        num_epochs: number of passes over ``dataloader``.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Raises:
        ValueError: if ``dataloader`` yields no samples in an epoch.
    """
    if device is None:
        # Follow the model rather than relying on a notebook-global ``device``.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        model.train()

        running_loss = 0.0
        running_corrects = 0
        n = 0
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Force grad mode on even if the caller is inside ``torch.no_grad()``.
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            preds = outputs.argmax(dim=1)
            # Weight the batch-mean loss by batch size so the epoch loss is a per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
            n += len(labels)

        if n == 0:
            raise ValueError('dataloader yielded no samples')

        # Advance the LR schedule once per epoch, AFTER this epoch's optimizer
        # updates (the order PyTorch >= 1.1 expects). The old extra
        # optimizer.step() at the top of the epoch re-applied stale gradients.
        scheduler.step()

        epoch_loss = running_loss / n
        epoch_acc = running_corrects / n

        print(f'epoch {epoch}/{num_epochs} : {epoch_loss:.5f}, {epoch_acc:.5f}')

# Reproducibility: seed NumPy and PyTorch and force deterministic cuDNN kernels.
np.random.seed(37)
torch.manual_seed(37)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
pretrained = True  # start from torchvision's pretrained ResNet-18 weights
num_classes = 3
num_epochs = 20

# Use the namespaced transforms instead of names pulled in by the star import.
# NOTE(review): Resize(224) scales only the shorter side; batching assumes the
# images in ./shapes/train share an aspect ratio (e.g. are square) — confirm.
transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])
image_folder = datasets.ImageFolder('./shapes/train', transform=transform)
dataloader = DataLoader(image_folder, batch_size=4, shuffle=True, num_workers=4)

# Swap the final fully-connected layer for a num_classes-way classifier head.
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of `weights=`.
model = models.resnet18(pretrained=pretrained)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Rprop(model.parameters(), lr=0.01)
# gamma=0.1 is StepLR's default, spelled out: the LR is multiplied by 0.1 on every
# scheduler step, i.e. once per epoch given how `train` drives the scheduler.
# NOTE(review): Rprop appears to use `lr` only to initialise its per-parameter step
# sizes, so this schedule may have little effect — confirm against the Rprop docs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

train(dataloader, model, criterion, optimizer, scheduler, num_epochs=num_epochs)
epoch 0/20 : 1.18291, 0.66667
epoch 1/20 : 1.89373, 0.56667
epoch 2/20 : 0.41106, 0.80000
epoch 3/20 : 0.09141, 0.96667
epoch 4/20 : 0.09910, 0.96667
epoch 5/20 : 0.08258, 0.96667
epoch 6/20 : 0.06175, 0.96667
epoch 7/20 : 0.34240, 0.86667
epoch 8/20 : 0.03592, 1.00000
epoch 9/20 : 0.15507, 0.93333
epoch 10/20 : 0.40221, 0.96667
epoch 11/20 : 0.07072, 0.96667
epoch 12/20 : 0.44840, 0.93333
epoch 13/20 : 0.01021, 1.00000
epoch 14/20 : 0.00262, 1.00000
epoch 15/20 : 0.00727, 1.00000
epoch 16/20 : 0.00639, 1.00000
epoch 17/20 : 0.05421, 0.96667
epoch 18/20 : 0.03431, 1.00000
epoch 19/20 : 0.00771, 1.00000

10.1. Saving

10.1.1. Saving just the model

[2]:
# Persist only the learned parameters (the state_dict); the architecture itself
# must be rebuilt in code before these weights can be loaded back.
# Create the output directory first so a fresh checkout can Restart & Run All.
os.makedirs('./output', exist_ok=True)
torch.save(model.state_dict(), './output/resnet18-model.pt')

10.1.2. Saving for later training

[3]:
# Bundle the state of every stateful training component into one file so that
# training can be resumed later; each entry is restored via load_state_dict.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        criterion_state_dict=criterion.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
    ),
    './output/resnet18-checkpoint.pt',
)

10.1.3. Saving to ONNX

[4]:
# Export the fine-tuned ResNet-18 to ONNX using a dummy batch of the training
# input shape (4 images, 3 channels, 224x224) on the model's device.
args = torch.randn(4, 3, 224, 224, device=device)
# Name the file after the model actually exported (it was mislabelled "alexnet").
f = './output/resnet18.onnx'

# Mark dim 0 as a dynamic batch axis so the exported graph is not pinned to the
# dummy batch size of 4.
torch.onnx.export(
    model,
    args,
    f,
    verbose=False,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
)

10.2. Loading

10.2.1. Loading just the model

[5]:
# Rebuild the same architecture (ResNet-18 with a num_classes-way fc head), then
# overwrite its parameters with the saved state_dict. Pretrained weights are not
# requested: load_state_dict (strict by default) replaces every key, so fetching
# them would only cost a download/network access.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

model.load_state_dict(torch.load('./output/resnet18-model.pt', map_location=device))
[5]:
<All keys matched successfully>

10.2.2. Loading for training continuation

[6]:
# Load the bundled training state; map_location puts tensors on the current device.
checkpoint = torch.load('./output/resnet18-checkpoint.pt', map_location=device)

# Rebuild every component with the same configuration as the original run, then
# restore each one's saved state. Pretrained weights are not requested because
# the checkpoint's model_state_dict overwrites all parameters anyway.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Rprop(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

model.load_state_dict(checkpoint['model_state_dict'])
criterion.load_state_dict(checkpoint['criterion_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# Continue training from the restored state (epoch numbering restarts at 0).
train(dataloader, model, criterion, optimizer, scheduler, num_epochs=num_epochs)
epoch 0/20 : 0.14301, 0.96667
epoch 1/20 : 0.00260, 1.00000
epoch 2/20 : 2.75971, 0.76667
epoch 3/20 : 0.19595, 0.96667
epoch 4/20 : 0.11255, 0.96667
epoch 5/20 : 0.24430, 0.96667
epoch 6/20 : 0.49671, 0.93333
epoch 7/20 : 0.49788, 0.90000
epoch 8/20 : 0.44765, 0.86667
epoch 9/20 : 0.03913, 0.96667
epoch 10/20 : 0.01076, 1.00000
epoch 11/20 : 10.49290, 0.83333
epoch 12/20 : 0.03003, 0.96667
epoch 13/20 : 0.22657, 0.96667
epoch 14/20 : 0.00002, 1.00000
epoch 15/20 : 0.00087, 1.00000
epoch 16/20 : 0.20941, 0.96667
epoch 17/20 : 0.00000, 1.00000
epoch 18/20 : 12.09500, 0.76667
epoch 19/20 : 0.02187, 1.00000