"""
Small code example of how to save and load checkpoint of a model.
This example doesn't perform any training, so it would be quite useless.
In practice you would save the model as you train, and then load before
continuining training at another point.
Video explanation of code & how to save and load model: https://youtu.be/g6kQl_EFn84
Got any questions leave a comment on youtube :)
Coded by Aladdin Persson <aladdin dot person at hotmail dot com>
* 2020-04-07 Initial programming
* 2022-12-16 Updated with more detailed comments, and checked code still functions as intended.
"""
# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules: nn.Linear, nn.Conv2d, BatchNorm, loss functions
import torch.optim as optim  # All optimization algorithms: SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset management and creates mini-batches
import torchvision.datasets as datasets  # Standard datasets we can import in a convenient way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    # Persist the state dict(s) packed into `state` to disk
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    # Restore model and optimizer parameters from a loaded checkpoint dict
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
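

# A minimal sketch of the pattern described in the docstring: save a checkpoint
# every few epochs while training so that training can later resume from the
# saved state. The loader, loss function, and epoch/save intervals below are
# placeholder assumptions for illustration, not part of the original example.
def train_with_checkpoints(model, optimizer, loader, num_epochs=10, save_every=3):
    criterion = nn.CrossEntropyLoss()  # assumed loss for a classification setup
    for epoch in range(num_epochs):
        if epoch % save_every == 0:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)
        for data, targets in loader:
            scores = model(data)
            loss = criterion(scores, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()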


def main():
    # Initialize network
    model = torchvision.models.vgg16(
        weights=None
    )  # the pretrained=False argument is deprecated; use weights instead
    optimizer = optim.Adam(model.parameters())
    checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}

    # Try saving a checkpoint
    save_checkpoint(checkpoint)

    # Try loading the checkpoint back
    load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)
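    # Note (not in the original example): when resuming on a different device,
    # torch.load accepts a map_location argument, e.g.
    # torch.load("my_checkpoint.pth.tar", map_location="cpu").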


if __name__ == "__main__":
    main()