Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Update lottery ticket example (#2559)
Browse files Browse the repository at this point in the history
  • Loading branch information
chicm-ms authored Jun 29, 2020
1 parent b82bad0 commit e60e183
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 6 deletions.
60 changes: 58 additions & 2 deletions examples/model_compress/lottery_torch_mnist_fc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import argparse
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -53,6 +55,31 @@ def test(model, test_loader, criterion):


if __name__ == '__main__':
"""
THE LOTTERY TICKET HYPOTHESIS: FINDING SPARSE, TRAINABLE NEURAL NETWORKS (https://arxiv.org/pdf/1803.03635.pdf)
The Lottery Ticket Hypothesis. A randomly-initialized, dense neural network contains a subnetwork that is
initialized such that—when trained in isolation—it can match the test accuracy of the original network after
training for at most the same number of iterations.
Identifying winning tickets. We identify a winning ticket by training a network and pruning its
smallest-magnitude weights. The remaining, unpruned connections constitute the architecture of the
winning ticket. Unique to our work, each unpruned connection’s value is then reset to its initialization
from original network before it was trained. This forms our central experiment:
1. Randomly initialize a neural network f(x; θ0) (where θ0 ∼ Dθ).
2. Train the network for j iterations, arriving at parameters θj .
3. Prune p% of the parameters in θj , creating a mask m.
4. Reset the remaining parameters to their values in θ0, creating the winning ticket f(x; m θ0).
As described, this pruning approach is one-shot: the network is trained once, p% of weights are
pruned, and the surviving weights are reset. However, in this paper, we focus on iterative pruning,
which repeatedly trains, prunes, and resets the network over n rounds; each round prunes p**(1/n) % of
the weights that survive the previous round. Our results show that iterative pruning finds winning tickets
that match the accuracy of the original network at smaller sizes than does one-shot pruning.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--train_epochs", type=int, default=10, help="training epochs")
args = parser.parse_args()

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
testdataset = datasets.MNIST('./data', train=False, transform=transform)
Expand All @@ -63,6 +90,20 @@ def test(model, test_loader, criterion):
optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
criterion = nn.CrossEntropyLoss()

# Record the random intialized model weights
orig_state = copy.deepcopy(model.state_dict())

# train the model to get unpruned metrics
for epoch in range(args.train_epochs):
train(model, train_loader, optimizer, criterion)
orig_accuracy = test(model, test_loader, criterion)
print('unpruned model accuracy: {}'.format(orig_accuracy))

# reset model weights and optimizer for pruning
model.load_state_dict(orig_state)
optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)

# Prune the model to find a winning ticket
configure_list = [{
'prune_iterations': 5,
'sparsity': 0.96,
Expand All @@ -71,14 +112,29 @@ def test(model, test_loader, criterion):
pruner = LotteryTicketPruner(model, configure_list, optimizer)
pruner.compress()

best_accuracy = 0.
best_state_dict = None

for i in pruner.get_prune_iterations():
pruner.prune_iteration_start()
loss = 0
accuracy = 0
for epoch in range(10):
for epoch in range(args.train_epochs):
loss = train(model, train_loader, optimizer, criterion)
accuracy = test(model, test_loader, criterion)
print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
if accuracy > best_accuracy:
best_accuracy = accuracy
# state dict of weights and masks
best_state_dict = copy.deepcopy(model.state_dict())
print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
pruner.export_model('model.pth', 'mask.pth')

if best_accuracy > orig_accuracy:
# load weights and masks
pruner.bound_model.load_state_dict(best_state_dict)
# reset weights to original untrained model and keep masks unchanged to export winning ticket
pruner.load_model_state_dict(orig_state)
pruner.export_model('model_winning_ticket.pth', 'mask_winning_ticket.pth')
print('winning ticket has been saved: model_winning_ticket.pth, mask_winning_ticket.pth')
else:
print('winning ticket is not found in this run, you can run it again.')
4 changes: 2 additions & 2 deletions src/sdk/pynni/nni/compression/torch/pruning/lottery_ticket.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ def _calc_sparsity(self, sparsity):
return max(1 - curr_keep_ratio, 0)

def _calc_mask(self, wrapper, sparsity):
weight = wrapper.weight.data
weight = wrapper.module.weight.data
if self.curr_prune_iteration == 0:
mask = {'weight_mask': torch.ones(weight.shape).type_as(weight)}
else:
curr_sparsity = self._calc_sparsity(sparsity)
mask = self.masker.calc_mask(wrapper, curr_sparsity)
mask = self.masker.calc_mask(sparsity=curr_sparsity, wrapper=wrapper)
return mask

def calc_mask(self, wrapper, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions test/scripts/model_compression.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ python3 model_prune_torch.py --pruner_name agp --pretrain_epochs 1 --prune_epoch
echo 'testing mean_activation pruning'
python3 model_prune_torch.py --pruner_name mean_activation --pretrain_epochs 1 --prune_epochs 1

#echo "testing lottery ticket pruning..."
#python3 lottery_torch_mnist_fc.py
echo "testing lottery ticket pruning..."
python3 lottery_torch_mnist_fc.py --train_epochs 1

echo ""
echo "===========================Testing: quantizers==========================="
Expand Down

0 comments on commit e60e183

Please sign in to comment.