Saving and Loading Models

In this notebook, I'll show you how to save and load models with PyTorch. This is important because you'll often want to load previously trained models to make predictions or to continue training on new data.

In [ ]:
# Enable inline plotting in notebook
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt  # Import matplotlib for visualization

import torch  # Import PyTorch library
from torch import nn  # Import neural network module
from torch import optim
import torch.nn.functional as F  # Import functional API (F.relu, etc.)
from torchvision import datasets, transforms  # Import datasets and image transforms

import helper  # Import helper visualization functions
import fc_model  # Import pre-built fully connected model
In [ ]:
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),  # Define transform: convert to tensor
                                transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)  # Download/load Fashion-MNIST training data
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)  # Create training data loader

# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)  # Download/load Fashion-MNIST test data
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)  # Create test data loader (no shuffling needed for evaluation)
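For reference, `transforms.ToTensor()` scales pixel values into the range [0, 1], and `transforms.Normalize(mean, std)` then computes `(x - mean) / std` for each channel, so a mean and standard deviation of 0.5 map the pixels onto [-1, 1]. Here's that arithmetic on a few made-up pixel values (the tensor `pixels` is illustrative, not taken from the dataset):

```python
import torch

# Illustrative pixel values after ToTensor(): already scaled to [0, 1]
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])

# The same arithmetic Normalize((0.5,), (0.5,)) applies to each channel
normalized = (pixels - 0.5) / 0.5
print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]
```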

Here we can see one of the images.

In [ ]:
image, label = next(iter(trainloader))  # Get one batch of images and labels
helper.imshow(image[0,:]);

Train a network

To make things more concise here, I moved the model architecture and training code from the last part to a file called fc_model. Importing this, we can easily create a fully-connected network with fc_model.Network, and train the network using fc_model.train. I'll use this model (once it's trained) to demonstrate how we can save and load models.
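The actual `fc_model.py` isn't shown in this notebook. Based only on what the notebook uses (the call `fc_model.Network(784, 10, [512, 256, 128])`, a `hidden_layers` attribute whose entries have `out_features`, and `nn.NLLLoss` as the criterion), a plausible sketch looks like the one below. The class body, the `output` and `dropout` attribute names, the `drop_p` parameter, and flattening inside `forward` are all assumptions, not the file's real contents.

```python
import torch
import torch.nn.functional as F
from torch import nn

class Network(nn.Module):
    """Hypothetical stand-in for fc_model.Network (the real file is not shown)."""

    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        super().__init__()
        # One nn.Linear per hidden layer, held in a ModuleList so the
        # parameters are registered and show up in state_dict()
        sizes = [input_size] + hidden_layers
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])])
        self.output = nn.Linear(hidden_layers[-1], output_size)
        self.dropout = nn.Dropout(p=drop_p)  # assumed regularization

    def forward(self, x):
        x = x.flatten(start_dim=1)  # flatten 1x28x28 images to 784-vectors
        for linear in self.hidden_layers:
            x = self.dropout(F.relu(linear(x)))
        # Log-probabilities, which is what nn.NLLLoss expects
        return F.log_softmax(self.output(x), dim=1)

model = Network(784, 10, [512, 256, 128])
print([layer.out_features for layer in model.hidden_layers])  # [512, 256, 128]
out = model(torch.randn(64, 1, 28, 28))  # a fake batch shaped like Fashion-MNIST
print(out.shape)  # torch.Size([64, 10])
```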

In [ ]:
# Create the network, define the criterion and optimizer

model = fc_model.Network(784, 10, [512, 256, 128])  # 784 inputs, 10 classes, hidden layers of 512, 256, and 128 units
criterion = nn.NLLLoss()  # Negative log-likelihood loss
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer
In [ ]:
fc_model.train(model, trainloader, testloader, criterion, optimizer, epochs=2)  # Train the model
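Similarly, `fc_model.train` isn't reproduced here. A minimal loop that does the same basic job might look like the sketch below. The `train` function and its signature are hypothetical; the notebook's call also passes `testloader`, presumably for validation, which this sketch leaves out. It runs on random tensors and a small `nn.Sequential` so it doesn't depend on `fc_model` or the downloaded data.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

def train(model, trainloader, criterion, optimizer, epochs=1):
    """Hypothetical minimal training loop; not the actual fc_model.train."""
    for epoch in range(epochs):
        model.train()  # enable training-mode behavior such as dropout
        running_loss = 0.0
        for images, labels in trainloader:
            optimizer.zero_grad()                    # clear gradients from the last step
            loss = criterion(model(images), labels)  # forward pass and loss
            loss.backward()                          # backpropagate
            optimizer.step()                         # update the weights
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}/{epochs}, training loss: {running_loss / len(trainloader):.3f}")
    return running_loss / len(trainloader)

# Tiny synthetic run (random data, not Fashion-MNIST)
torch.manual_seed(0)
data = TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=64, shuffle=True)
net = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10), nn.LogSoftmax(dim=1))
final_loss = train(net, loader, nn.NLLLoss(), optim.Adam(net.parameters(), lr=0.001), epochs=2)
```

The four steps inside the inner loop (zero the gradients, compute the loss, backpropagate, step the optimizer) are the core of any PyTorch training loop.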

Saving and loading networks

As you can imagine, it's impractical to train a network every time you need to use it. Instead, we can save trained networks and then load them later to train further or to make predictions.

The parameters for PyTorch networks are stored in a model's state_dict. We can see that the state dict contains the weight and bias tensors for each of our layers.

In [ ]:
print("Our model: \n\n", model, '\n')
print("The state dict keys: \n\n", model.state_dict().keys())
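To see what those names and shapes look like, here's a small illustrative `nn.Sequential` (not the notebook's network). Each key is the module's path followed by `weight` or `bias`; weights have shape `(out_features, in_features)`, biases are 1-D, and modules without parameters, like `nn.ReLU`, don't appear at all.

```python
from torch import nn

# Small illustrative model (not the notebook's network)
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# state_dict() maps each parameter name to its tensor; ReLU has no parameters
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)
```

Modules with registered buffers, such as batch norm's running statistics, add those entries to the state dict as well.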

The simplest approach is to save the state dict with torch.save. For example, we can save it to a file named 'checkpoint.pth'.

In [ ]:
torch.save(model.state_dict(), 'checkpoint.pth')  # Save model checkpoint

Then we can load the state dict with torch.load.

In [ ]:
state_dict = torch.load('checkpoint.pth', weights_only=True)  # Load the saved state dict
print(state_dict.keys())  # Print layer names in state dict
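`torch.load` also accepts a `map_location` argument, which matters if a checkpoint was saved from GPU tensors and you're loading it on a CPU-only machine. The `weights_only=True` flag tells `torch.load` to unpickle only tensors and basic Python containers and primitives, which is safer for files you didn't create yourself. A self-contained sketch, using an in-memory buffer instead of a file and a small `nn.Linear` standing in for the trained model:

```python
import io
import torch
from torch import nn

net = nn.Linear(4, 2)  # stands in for the trained network

# Save to an in-memory buffer so the sketch leaves nothing on disk
buffer = io.BytesIO()
torch.save(net.state_dict(), buffer)
buffer.seek(0)

# map_location="cpu" remaps any tensors saved from a GPU onto the CPU;
# weights_only=True restricts unpickling to tensors and plain containers/primitives
state_dict = torch.load(buffer, map_location="cpu", weights_only=True)
print(all(t.device.type == "cpu" for t in state_dict.values()))  # True
print(torch.equal(state_dict["weight"], net.weight.detach()))    # True
```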

And to load the state dict into the network, you call model.load_state_dict(state_dict).

In [ ]:
model.load_state_dict(state_dict)  # Load saved model weights
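`load_state_dict` copies the saved tensors into the model's existing parameters in place. One way to convince yourself it worked is to compare outputs before and after loading. A small sketch, with `nn.Linear` layers standing in for the trained network and a freshly built copy:

```python
import torch
from torch import nn

torch.manual_seed(0)
trained = nn.Linear(4, 2)  # stands in for a trained model
fresh = nn.Linear(4, 2)    # same architecture, different random initialization

x = torch.randn(3, 4)
print(torch.equal(trained(x), fresh(x)))  # False before loading

result = fresh.load_state_dict(trained.state_dict())
print(result.missing_keys, result.unexpected_keys)  # [] []
print(torch.equal(trained(x), fresh(x)))  # True after loading
```

The returned missing and unexpected key lists are handy when a load doesn't behave as expected.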

Seems pretty straightforward, but as usual it's a bit more complicated. Loading the state dict works only if the model has exactly the same architecture as the model that produced the checkpoint: every parameter name and tensor shape must match. If I create a model with a different architecture, this fails.

In [ ]:
# Try this
wrong_model = fc_model.Network(784, 10, [400, 200, 100])
# This will throw an error because the tensor sizes are wrong!
wrong_model.load_state_dict(state_dict)  # Raises RuntimeError: size mismatch
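You can catch the failure to inspect it. In the sketch below, single `nn.Linear` layers stand in for the two networks. Note that passing `strict=False` doesn't help here: it only tolerates missing or unexpected keys, and shape mismatches still raise a `RuntimeError`.

```python
from torch import nn

saved = nn.Linear(784, 512).state_dict()  # weights from a 784 -> 512 layer
other = nn.Linear(784, 400)               # different architecture

caught = False
try:
    other.load_state_dict(saved)
except RuntimeError as err:
    caught = True
    # The message lists each "size mismatch" between checkpoint and model shapes
    print("size mismatch" in str(err))  # True

# strict=False relaxes only missing/unexpected keys; shape mismatches still raise
still_raises = False
try:
    other.load_state_dict(saved, strict=False)
except RuntimeError:
    still_raises = True
print(still_raises)  # True
```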

This means we need to rebuild the model exactly as it was when trained. Information about the model architecture needs to be saved in the checkpoint, along with the state dict. To do this, you build a dictionary with all the information you need to completely rebuild the model.

In [ ]:
checkpoint = {'input_size': 784,  # Create checkpoint dict
              'output_size': 10,  # Output size
              'hidden_layers': [each.out_features for each in model.hidden_layers],  # Hidden layer sizes
              'state_dict': model.state_dict()}  # Model weights

torch.save(checkpoint, 'checkpoint.pth')  # Save model checkpoint

Now the checkpoint has all the necessary information to rebuild the trained model. You can easily make that a function if you want. Similarly, we can write a function to load checkpoints.
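Turning the checkpoint cell into a function, as suggested above, could look like the sketch below. The `save_checkpoint` helper and the compact `Network` class are hypothetical; the class only mimics the parts of `fc_model.Network` this notebook relies on (the constructor arguments and a `hidden_layers` list of `nn.Linear` layers), plus an `output` layer, which is an extra assumption. The sketch saves to an in-memory buffer so it leaves nothing on disk.

```python
import io
import torch
import torch.nn.functional as F
from torch import nn

class Network(nn.Module):
    """Hypothetical compact stand-in for fc_model.Network."""

    def __init__(self, input_size, output_size, hidden_layers):
        super().__init__()
        sizes = [input_size] + hidden_layers
        # One nn.Linear per hidden layer, e.g. 784 -> 64 -> 32
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])])
        self.output = nn.Linear(hidden_layers[-1], output_size)

    def forward(self, x):
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        return F.log_softmax(self.output(x), dim=1)

def save_checkpoint(model, filepath):
    """Hypothetical helper: save the architecture alongside the weights."""
    checkpoint = {'input_size': model.hidden_layers[0].in_features,
                  'output_size': model.output.out_features,
                  'hidden_layers': [each.out_features for each in model.hidden_layers],
                  'state_dict': model.state_dict()}
    torch.save(checkpoint, filepath)

buffer = io.BytesIO()  # stands in for a path like 'checkpoint.pth'
save_checkpoint(Network(784, 10, [64, 32]), buffer)
buffer.seek(0)

checkpoint = torch.load(buffer, weights_only=True)
print(checkpoint['input_size'], checkpoint['output_size'], checkpoint['hidden_layers'])
# 784 10 [64, 32]
```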

In [ ]:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath, weights_only=True)  # Load the checkpoint dict
    # Rebuild the network with the saved architecture
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])  # Restore the trained weights
    
    return model  # Return loaded model
In [ ]:
model = load_checkpoint('checkpoint.pth')
print(model)  # Display model architecture
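If you want to continue training a loaded model rather than just make predictions, the weights aren't the whole story: optimizers like Adam keep their own per-parameter state (step counts and moment estimates), available through `optimizer.state_dict()`. A common pattern, also used in the official PyTorch tutorials, is to save both state dicts in one checkpoint. In this sketch, the `nn.Linear` model, the dictionary key names, and the epoch counter are illustrative choices, not requirements.

```python
import io
import torch
from torch import nn, optim

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# One training step so Adam has internal state to save
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Save model and optimizer state together, plus an epoch counter
buffer = io.BytesIO()
torch.save({'epoch': 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, buffer)
buffer.seek(0)

# Later: rebuild both objects first, then restore their saved state
checkpoint = torch.load(buffer, weights_only=True)
resumed_model = nn.Linear(4, 2)
resumed_model.load_state_dict(checkpoint['model_state_dict'])
resumed_optimizer = optim.Adam(resumed_model.parameters(), lr=0.001)
resumed_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
resumed_model.train()  # back to training mode before continuing
```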