
Compressor updates #2136

Merged 9 commits on Mar 11, 2020
8 changes: 4 additions & 4 deletions examples/model_compress/model_prune_torch.py
@@ -142,10 +142,10 @@ def create_model(model_name='naive'):
     else:
         return VGG(19)
 
-def create_pruner(model, pruner_name):
+def create_pruner(model, pruner_name, optimizer=None):
     pruner_class = prune_config[pruner_name]['pruner_class']
     config_list = prune_config[pruner_name]['config_list']
-    return pruner_class(model, config_list)
+    return pruner_class(model, config_list, optimizer)
 
 def train(model, device, train_loader, optimizer):
     model.train()
@@ -179,6 +179,8 @@ def test(model, device, test_loader):
 
 def main(args):
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    if not os.path.exists(args.checkpoints_dir):
+        os.makedirs(args.checkpoints_dir)
Contributor: makedirs(..., exist_ok=True)

Contributor Author: thanks, updated.
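
For reference, a minimal sketch of the suggested form (same args.checkpoints_dir as in the example script): the existence check and the directory creation collapse into one call.

    import os
    os.makedirs(args.checkpoints_dir, exist_ok=True)  # no error if the directory already exists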


     model_name = prune_config[args.pruner_name]['model_name']
     dataset_name = prune_config[args.pruner_name]['dataset_name']
@@ -203,8 +205,6 @@ def main(args):
 
     print('start model pruning...')
 
-    if not os.path.exists(args.checkpoints_dir):
-        os.makedirs(args.checkpoints_dir)
     model_path = os.path.join(args.checkpoints_dir, 'pruned_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
     mask_path = os.path.join(args.checkpoints_dir, 'mask_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))

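
A hedged sketch of the updated call path in this example script (names taken from the snippet above; the optimizer only matters for iterative pruners and may be omitted for one-shot ones):

    model = create_model(model_name)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    pruner = create_pruner(model, args.pruner_name, optimizer)
    model = pruner.compress()  # wraps the selected modules and applies the masks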
7 changes: 7 additions & 0 deletions examples/trials/mnist-pytorch/mnist.py
@@ -15,6 +15,13 @@
 import torch.optim as optim
 from torchvision import datasets, transforms
 
+# Temporarily patch this example until the MNIST dataset download issue gets resolved
+# https://github.com/pytorch/vision/issues/1938
+import urllib.request
+
+opener = urllib.request.build_opener()
+opener.addheaders = [('User-agent', 'Mozilla/5.0')]
+urllib.request.install_opener(opener)
 
 logger = logging.getLogger('mnist_AutoML')
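
Why the patch helps, as a minimal sketch (standard torchvision API; paths are illustrative): with the opener installed, the MNIST download request carries a browser-like User-agent instead of urllib's default, which the hosting server had started rejecting.

    from torchvision import datasets, transforms

    # This download call now succeeds because urllib sends the patched headers.
    train_data = datasets.MNIST('data', train=True, download=True,
                                transform=transforms.ToTensor())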
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
@@ -16,7 +16,7 @@ class ActivationRankFilterPruner(Pruner):
     to achieve a preset level of network sparsity.
     """
 
-    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -25,6 +25,8 @@ def __init__(self, model, config_list, optimizer, activation='relu', statistics_
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         activation : str
             Activation function
         statistics_batch_num : int
@@ -105,7 +107,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
     https://arxiv.org/abs/1607.03250
     """
 
-    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -114,6 +116,8 @@ def __init__(self, model, config_list, optimizer, activation='relu', statistics_
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         activation : str
             Activation function
         statistics_batch_num : int
@@ -177,7 +181,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
     https://arxiv.org/abs/1611.06440
     """
 
-    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -186,6 +190,8 @@ def __init__(self, model, config_list, optimizer, activation='relu', statistics_
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         activation : str
             Activation function
         statistics_batch_num : int
10 changes: 6 additions & 4 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -27,7 +27,7 @@ class Compressor:
     Abstract base PyTorch compressor
     """
 
-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Record necessary info in class members
 
@@ -235,7 +235,8 @@ def new_step(_, *args, **kwargs):
                     task()
                 return output
             return new_step
-        self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
+        if self.optimizer is not None:
+            self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
 
 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
Expand Down Expand Up @@ -290,9 +291,10 @@ class Pruner(Compressor):

"""

def __init__(self, model, config_list, optimizer):
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
self.patch_optimizer(self.update_mask)
if optimizer is not None:
self.patch_optimizer(self.update_mask)

def compress(self):
self.update_mask()
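
A hedged usage sketch of the two modes this change enables (LevelPruner and the config keys follow the NNI examples of this period; treat the exact import path as an assumption):

    import torch
    from nni.compression.torch import LevelPruner

    model = torch.nn.Linear(10, 10)
    config_list = [{'sparsity': 0.5, 'op_types': ['default']}]

    # One-shot pruning: no optimizer; compress() computes the masks directly.
    pruner = LevelPruner(model, config_list)
    model = pruner.compress()

    # Iterative pruning: pass an optimizer so update_mask() runs on every step().
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # pruner = LevelPruner(model, config_list, optimizer)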
16 changes: 13 additions & 3 deletions src/sdk/pynni/nni/compression/torch/pruners.py
@@ -16,14 +16,16 @@ class LevelPruner(Pruner):
     Prune to an exact pruning level specification
     """
 
-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
         model : torch.nn.module
             Model to be pruned
         config_list : list
             List of pruning configs
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -78,9 +80,13 @@ def __init__(self, model, config_list, optimizer):
             Model to be pruned
         config_list : list
             List of pruning configs
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         """
 
         super().__init__(model, config_list, optimizer)
+        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass the optimizer used to train the model"
+
         self.now_epoch = 0
         self.set_wrappers_attribute("if_calculated", False)
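
A hedged sketch of the fail-fast behavior the assert gives (the AGP_Pruner name and config keys are assumed from NNI docs of this period):

    import torch
    from nni.compression.torch import AGP_Pruner

    model = torch.nn.Linear(10, 10)
    config_list = [{'initial_sparsity': 0.0, 'final_sparsity': 0.8,
                    'start_epoch': 0, 'end_epoch': 10,
                    'frequency': 1, 'op_types': ['default']}]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    pruner = AGP_Pruner(model, config_list, optimizer)  # iterative pruner gets its optimizer
    # AGP_Pruner(model, config_list)  # would fail fast with the message above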

@@ -176,13 +182,17 @@ class SlimPruner(Pruner):
     https://arxiv.org/pdf/1708.06519.pdf
     """
 
-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
         model : torch.nn.module
             Model to be pruned
         config_list : list
+            support key for each list item:
+                - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -244,7 +254,7 @@ class LotteryTicketPruner(Pruner):
     5. Repeat step 2, 3, and 4.
     """
 
-    def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
+    def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True):
         """
         Parameters
         ----------
15 changes: 12 additions & 3 deletions src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
@@ -15,7 +15,7 @@ class WeightRankFilterPruner(Pruner):
     importance criterion in convolution layers to achieve a preset level of network sparsity.
     """
 
-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
        """
        Parameters
        ----------
@@ -24,6 +24,8 @@ def __init__(self, model, config_list, optimizer):
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -83,7 +85,7 @@ class L1FilterPruner(WeightRankFilterPruner):
     https://arxiv.org/abs/1608.08710
     """
 
-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -92,6 +94,8 @@ def __init__(self, model, config_list, optimizer):
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -131,7 +135,7 @@ class L2FilterPruner(WeightRankFilterPruner):
     smallest L2 norm of the weights.
     """
 
-    def __init__(self, model, config_list, optimizer):
+    def __init__(self, model, config_list, optimizer=None):
         """
         Parameters
         ----------
@@ -140,6 +144,8 @@ def __init__(self, model, config_list, optimizer):
         config_list : list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -187,8 +193,11 @@ def __init__(self, model, config_list, optimizer):
         config_list: list
             support key for each list item:
                 - sparsity: percentage of convolutional filters to be pruned.
+        optimizer : torch.optim.Optimizer
+            Optimizer used to train the model
         """
         super().__init__(model, config_list, optimizer)
+        assert isinstance(optimizer, torch.optim.Optimizer), "FPGM pruner is an iterative pruner, please pass the optimizer used to train the model"
 
     def get_mask(self, base_mask, weight, num_prune):
         """
src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py
@@ -4,6 +4,7 @@
 import logging
 import datetime
 from nni.assessor import Assessor, AssessResult
+from nni.utils import extract_scalar_history
 from .model_factory import CurveModel
 
 logger = logging.getLogger('curvefitting_Assessor')
@@ -91,10 +92,11 @@ def assess_trial(self, trial_job_id, trial_history):
         Exception
             unrecognized exception in curvefitting_assessor
         """
-        self.trial_history = trial_history
+        scalar_trial_history = extract_scalar_history(trial_history)
+        self.trial_history = scalar_trial_history
         if not self.set_best_performance:
             return AssessResult.Good
-        curr_step = len(trial_history)
+        curr_step = len(scalar_trial_history)
         if curr_step < self.start_step:
             return AssessResult.Good

@@ -106,7 +108,7 @@ def assess_trial(self, trial_job_id, trial_history):
         start_time = datetime.datetime.now()
         # Predict the final result
         curvemodel = CurveModel(self.target_pos)
-        predict_y = curvemodel.predict(trial_history)
+        predict_y = curvemodel.predict(scalar_trial_history)
         logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
         if predict_y is None:
             logger.info('wait for more information to predict precisely')
17 changes: 5 additions & 12 deletions src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
@@ -3,6 +3,7 @@
 
 import logging
 from nni.assessor import Assessor, AssessResult
+from nni.utils import extract_scalar_history
 
 logger = logging.getLogger('medianstop_Assessor')
 
@@ -91,20 +92,12 @@ def assess_trial(self, trial_job_id, trial_history):
         if curr_step < self._start_step:
             return AssessResult.Good
 
-        try:
-            num_trial_history = [float(ele) for ele in trial_history]
-        except (TypeError, ValueError) as error:
-            logger.warning('incorrect data type or value:')
-            logger.exception(error)
-        except Exception as error:
-            logger.warning('unrecognized exception in medianstop_assessor:')
-            logger.exception(error)
-
-        self._update_data(trial_job_id, num_trial_history)
+        scalar_trial_history = extract_scalar_history(trial_history)
+        self._update_data(trial_job_id, scalar_trial_history)
         if self._high_better:
-            best_history = max(trial_history)
+            best_history = max(scalar_trial_history)
         else:
-            best_history = min(trial_history)
+            best_history = min(scalar_trial_history)
 
         avg_array = []
         for id_ in self._completed_avg_history:
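
A hedged illustration of what this rewrite fixes: intermediate results reported as dicts (e.g. {'default': accuracy}) broke the old float() conversion, while the shared helper reduces them to scalars before max()/min().

    from nni.utils import extract_scalar_history

    trial_history = [{'default': 0.71}, {'default': 0.74}]
    print(max(extract_scalar_history(trial_history)))  # 0.74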
1 change: 1 addition & 0 deletions src/sdk/pynni/nni/msg_dispatcher.py
@@ -234,4 +234,5 @@ def _earlystop_notify_tuner(self, data):
         if multi_thread_enabled():
             self._handle_final_metric_data(data)
         else:
+            data['value'] = to_json(data['value'])
             self.enqueue_command(CommandType.ReportMetricData, data)
31 changes: 30 additions & 1 deletion src/sdk/pynni/nni/utils.py
@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
     """
     Extract scalar reward from trial result.
 
+    Parameters
+    ----------
+    value : int, float, dict
+        the reported final metric data
+    scalar_key : str
+        the key name that identifies the scalar value
+
     Raises
     ------
     RuntimeError
@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
     return reward
 
 
+def extract_scalar_history(trial_history, scalar_key='default'):
+    """
+    Extract scalar values from a list of intermediate results.
+
+    Parameters
+    ----------
+    trial_history : list
+        accumulated intermediate results of a trial
+    scalar_key : str
+        the key name that identifies the scalar value
+
+    Raises
+    ------
+    RuntimeError
+        Incorrect final result: the final result should be float/int,
+        or a dict which has a key named "default" whose value is float/int.
+    """
+    return [extract_scalar_reward(ele, scalar_key) for ele in trial_history]
+
+
 def convert_dict2tuple(value):
     """
     convert dict type to tuple to solve unhashable problem.
@@ -90,7 +117,9 @@ def convert_dict2tuple(value):
 
 
 def init_dispatcher_logger():
-    """ Initialize dispatcher logging configuration"""
+    """
+    Initialize dispatcher logging configuration
+    """
     logger_file_path = 'dispatcher.log'
     if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
         logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
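
A hedged usage sketch of the new helper: entries may be plain numbers or dicts carrying a 'default' key, and both reduce to a flat list of scalars.

    from nni.utils import extract_scalar_history

    history = [0.51, {'default': 0.62, 'loss': 1.3}, 0.70]
    print(extract_scalar_history(history))  # [0.51, 0.62, 0.7]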