Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
update exclude example
Browse files Browse the repository at this point in the history
  • Loading branch information
J-shang committed Aug 9, 2021
1 parent 299d01d commit 1c98049
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions examples/model_compress/pruning/basic_pruners_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
sys.path.append('../models')
from mnist.lenet import LeNet
from cifar10.vgg import VGG
from cifar10.resnet import ResNet18

from nni.compression.pytorch.utils.counter import count_flops_params

Expand Down Expand Up @@ -119,6 +120,12 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = MultiStepLR(
optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
elif args.model == 'resnet18':
model = ResNet18().to(device)
if args.pretrained_model_dir is None:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = MultiStepLR(
optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
else:
raise ValueError("model not recognized")

Expand Down Expand Up @@ -253,14 +260,19 @@ def trainer(model, optimizer, criterion, epoch):
'sparsity': args.sparsity,
'op_types': ['BatchNorm2d'],
}]
else:
elif args.model == 'resnet18':
config_list = [{
'sparsity': args.sparsity,
'op_types': ['Conv2d'],
'op_names': ['feature.0', 'feature.10', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
'op_types': ['Conv2d']
}, {
'exclude': True,
'op_names': ['feature.10']
'op_names': ['layer1.0.conv1', 'layer1.0.conv2']
}]
else:
config_list = [{
'sparsity': args.sparsity,
'op_types': ['Conv2d'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]

pruner = pruner_cls(model, config_list, **kw_args)
Expand Down

0 comments on commit 1c98049

Please sign in to comment.