10. Saving and Loading Models
[1]:
import os

from torchvision import datasets, models, transforms
import torch.optim as optim
import torch.nn as nn
from torchvision.transforms import *
from torch.utils.data import DataLoader
import torch
import numpy as np
def train(dataloader, model, criterion, optimizer, scheduler, num_epochs=20, device=None):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    After every epoch one line is printed with the sample-weighted mean loss
    and the classification accuracy for that epoch.

    Args:
        dataloader: Iterable yielding ``(inputs, labels)`` batches of tensors.
        model: The ``torch.nn.Module`` to optimise; put into training mode
            at the start of every epoch.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: Optimizer that owns the model parameters.
        scheduler: Learning-rate scheduler; stepped once per epoch, after
            that epoch's optimizer updates.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the batches are moved to. Defaults to the device of
            the model's first parameter (CPU if the model has no parameters).

    Returns:
        None. Progress is reported through ``print`` only.
    """
    if device is None:
        # Follow the model's own placement instead of relying on a global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        model.train()

        running_loss = 0.0
        running_corrects = 0
        n = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Force autograd on even if the caller is inside a no_grad context.
            with torch.set_grad_enabled(True):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

                loss.backward()
                optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample mean (assumes the criterion averages over the batch).
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
            n += len(labels)

        # Advance the schedule only after this epoch's updates. Stepping it
        # first skips the initial learning rate, and an optimizer.step() here
        # would re-apply the stale gradients of the last batch.
        scheduler.step()

        # An empty dataloader yields NaN metrics rather than a crash.
        epoch_loss = running_loss / n if n else float('nan')
        epoch_acc = running_corrects / n if n else float('nan')

        print(f'epoch {epoch}/{num_epochs} : {epoch_loss:.5f}, {epoch_acc:.5f}')
# --- Reproducibility: pin cuDNN to deterministic behaviour and seed the RNGs ---
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
np.random.seed(37)
torch.manual_seed(37)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Run configuration (reused by the loading cells further down) ---
pretrained = True
num_classes = 3
num_epochs = 20

# --- Data: images under ./shapes/train, resized and converted to tensors ---
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])
image_folder = datasets.ImageFolder('./shapes/train', transform=transform)
dataloader = DataLoader(
    image_folder,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

# --- Model: ResNet-18 with its final layer replaced by a num_classes-way head ---
model = models.resnet18(pretrained=pretrained)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

# --- Loss, optimizer and learning-rate schedule ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Rprop(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

train(
    dataloader,
    model,
    criterion,
    optimizer,
    scheduler,
    num_epochs=num_epochs,
)
epoch 0/20 : 1.18291, 0.66667
epoch 1/20 : 1.89373, 0.56667
epoch 2/20 : 0.41106, 0.80000
epoch 3/20 : 0.09141, 0.96667
epoch 4/20 : 0.09910, 0.96667
epoch 5/20 : 0.08258, 0.96667
epoch 6/20 : 0.06175, 0.96667
epoch 7/20 : 0.34240, 0.86667
epoch 8/20 : 0.03592, 1.00000
epoch 9/20 : 0.15507, 0.93333
epoch 10/20 : 0.40221, 0.96667
epoch 11/20 : 0.07072, 0.96667
epoch 12/20 : 0.44840, 0.93333
epoch 13/20 : 0.01021, 1.00000
epoch 14/20 : 0.00262, 1.00000
epoch 15/20 : 0.00727, 1.00000
epoch 16/20 : 0.00639, 1.00000
epoch 17/20 : 0.05421, 0.96667
epoch 18/20 : 0.03431, 1.00000
epoch 19/20 : 0.00771, 1.00000
10.1. Saving
10.1.1. Saving just the model
[2]:
# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing (a fresh checkout may not have it).
os.makedirs('./output', exist_ok=True)

# Persist only the model's state_dict (its parameters and buffers).
torch.save(model.state_dict(), './output/resnet18-model.pt')
10.1.2. Saving for later training
[3]:
# Collect the state of every object needed to resume training into a single
# dictionary, then write it out as one checkpoint file.
training_state = dict(
    model_state_dict=model.state_dict(),
    criterion_state_dict=criterion.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    scheduler_state_dict=scheduler.state_dict(),
)
torch.save(training_state, './output/resnet18-checkpoint.pt')
10.1.3. Saving to ONNX
[4]:
# Dummy input used to run the model once for the export; it lives on the same
# device as the model. NOTE(review): the batch size of 4 is baked into the
# example input -- confirm whether a dynamic batch axis is wanted.
args = torch.randn(4, 3, 224, 224, device=device)

# The exported network is the ResNet-18 trained above, so name the artifact
# accordingly (it was previously written to 'alexnet.onnx').
f = './output/resnet18.onnx'
torch.onnx.export(model, args, f, verbose=False)
10.2. Loading
10.2.1. Loading just the model
[5]:
# Rebuild the architecture only. Every weight is replaced by the saved
# state_dict below, so fetching the pretrained ImageNet weights first would be
# wasted work (and an unnecessary network dependency).
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

# Loading "just the model" is the inference path: switch batch-norm layers to
# evaluation behaviour. Done before the load so the load result stays the
# cell's displayed value.
model.eval()

# map_location lets a file saved on one device be loaded onto another.
model.load_state_dict(torch.load('./output/resnet18-model.pt', map_location=device))
[5]:
<All keys matched successfully>
10.2.2. Loading for training continuation
[6]:
# Read the checkpoint; map_location places the stored tensors on the current device.
checkpoint = torch.load('./output/resnet18-checkpoint.pt', map_location=device)

# Recreate the objects with the same structure they had when the checkpoint
# was written. The model is built without pretrained weights because the
# checkpoint overwrites every parameter and buffer anyway.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# The optimizer is created after the model has been moved to its device, so it
# references the parameters that will actually be trained.
optimizer = optim.Rprop(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

# Restore the saved state into the freshly built objects.
model.load_state_dict(checkpoint['model_state_dict'])
criterion.load_state_dict(checkpoint['criterion_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# Continue training. The epoch counter printed by train() restarts at 0
# because the checkpoint does not record how many epochs were already run.
train(dataloader, model, criterion, optimizer, scheduler, num_epochs=num_epochs)
epoch 0/20 : 0.14301, 0.96667
epoch 1/20 : 0.00260, 1.00000
epoch 2/20 : 2.75971, 0.76667
epoch 3/20 : 0.19595, 0.96667
epoch 4/20 : 0.11255, 0.96667
epoch 5/20 : 0.24430, 0.96667
epoch 6/20 : 0.49671, 0.93333
epoch 7/20 : 0.49788, 0.90000
epoch 8/20 : 0.44765, 0.86667
epoch 9/20 : 0.03913, 0.96667
epoch 10/20 : 0.01076, 1.00000
epoch 11/20 : 10.49290, 0.83333
epoch 12/20 : 0.03003, 0.96667
epoch 13/20 : 0.22657, 0.96667
epoch 14/20 : 0.00002, 1.00000
epoch 15/20 : 0.00087, 1.00000
epoch 16/20 : 0.20941, 0.96667
epoch 17/20 : 0.00000, 1.00000
epoch 18/20 : 12.09500, 0.76667
epoch 19/20 : 0.02187, 1.00000