This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Rename x-axis label of Intermediate result "Trial No.x" #2145

Closed · wants to merge 9 commits
7 changes: 4 additions & 3 deletions examples/model_compress/model_prune_torch.py
@@ -142,10 +142,10 @@ def create_model(model_name='naive'):
     else:
         return VGG(19)
 
-def create_pruner(model, pruner_name, optimizer=None):
+def create_pruner(model, pruner_name):
     pruner_class = prune_config[pruner_name]['pruner_class']
     config_list = prune_config[pruner_name]['config_list']
-    return pruner_class(model, config_list, optimizer)
+    return pruner_class(model, config_list)
 
 def train(model, device, train_loader, optimizer):
     model.train()
@@ -179,7 +179,6 @@ def test(model, device, test_loader):
 
 def main(args):
     device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-    os.makedirs(args.checkpoints_dir, exist_ok=True)
 
     model_name = prune_config[args.pruner_name]['model_name']
     dataset_name = prune_config[args.pruner_name]['dataset_name']
@@ -204,6 +203,8 @@ def main(args):
 
     print('start model pruning...')
 
+    if not os.path.exists(args.checkpoints_dir):
+        os.makedirs(args.checkpoints_dir)
     model_path = os.path.join(args.checkpoints_dir, 'pruned_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
     mask_path = os.path.join(args.checkpoints_dir, 'mask_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
 
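A sketch of how the updated example script fits together after this change; names come from this file, and 'level' is assumed to be one of the keys in its prune_config table, so this is illustrative rather than code from the PR:

    # Hypothetical driver for the functions above; not part of this PR.
    model = create_model(model_name='naive')
    pruner = create_pruner(model, 'level')   # the optimizer is no longer threaded through
    model = pruner.compress()

    # The checkpoint directory is now created lazily, right before saving:
    if not os.path.exists(args.checkpoints_dir):
        os.makedirs(args.checkpoints_dir)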
7 changes: 0 additions & 7 deletions examples/trials/mnist-pytorch/mnist.py
@@ -15,13 +15,6 @@
 import torch.optim as optim
 from torchvision import datasets, transforms
 
-# Temporary patch this example until the MNIST dataset download issue get resolved
-# https://github.com/pytorch/vision/issues/1938
-import urllib
-
-opener = urllib.request.build_opener()
-opener.addheaders = [('User-agent', 'Mozilla/5.0')]
-urllib.request.install_opener(opener)
 
 logger = logging.getLogger('mnist_AutoML')
 
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
@@ -16,7 +16,7 @@ class ActivationRankFilterPruner(Pruner):
     to achieve a preset level of network sparsity.
     """
 
-    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -25,8 +25,6 @@ def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
@@ -107,7 +105,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
     https://arxiv.org/abs/1607.03250
     """
 
-    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -116,8 +114,6 @@ def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
@@ -181,7 +177,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
     https://arxiv.org/abs/1611.06440
     """
 
-    def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
+    def __init__(self, model, config_list, optimizer, activation='relu', statistics_batch_num=1):
         """
         Parameters
         ----------
@@ -190,8 +186,6 @@ def __init__(self, model, config_list, optimizer=None, activation='relu', statistics_batch_num=1):
         config_list : list
             support key for each list item:
            - sparsity: percentage of convolutional filters to be pruned.
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         activation : str
             Activation function
         statistics_batch_num : int
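With this signature, the activation-based filter pruners take the optimizer as a required third positional argument. A minimal sketch of constructing one under that contract; the model, sparsity, and batch count are illustrative, not taken from this PR:

    import torch.nn as nn
    import torch.optim as optim
    from nni.compression.torch import ActivationMeanRankFilterPruner

    model = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3), nn.ReLU())
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    # Rank filters by mean activation statistics gathered over one batch.
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    pruner = ActivationMeanRankFilterPruner(model, config_list, optimizer,
                                            activation='relu', statistics_batch_num=1)
    model = pruner.compress()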
10 changes: 4 additions & 6 deletions src/sdk/pynni/nni/compression/torch/compressor.py
@@ -27,7 +27,7 @@ class Compressor:
     Abstract base PyTorch compressor
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
         """
         Record necessary info in class members
 
@@ -235,8 +235,7 @@ def new_step(_, *args, **kwargs):
                     task()
                 return output
             return new_step
-        if self.optimizer is not None:
-            self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
+        self.optimizer.step = types.MethodType(patch_step(self.optimizer.step), self.optimizer)
 
 class PrunerModuleWrapper(torch.nn.Module):
     def __init__(self, module, module_name, module_type, config, pruner):
@@ -291,10 +290,9 @@ class Pruner(Compressor):
 
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
         super().__init__(model, config_list, optimizer)
-        if optimizer is not None:
-            self.patch_optimizer(self.update_mask)
+        self.patch_optimizer(self.update_mask)
 
     def compress(self):
        self.update_mask()
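The None-guard around patching disappears because the optimizer is now mandatory. The mechanism itself is unchanged: the compressor rebinds optimizer.step so that every parameter update also runs the registered tasks, such as update_mask. A self-contained sketch of that rebinding pattern, independent of NNI:

    import types
    import torch
    import torch.optim as optim

    def patch_step(old_step, *tasks):
        # Wrap a bound step() so the extra tasks run after each update.
        def new_step(_, *args, **kwargs):
            output = old_step(*args, **kwargs)
            for task in tasks:
                task()
            return output
        return new_step

    param = torch.nn.Parameter(torch.zeros(2))
    optimizer = optim.SGD([param], lr=0.1)
    # Rebind step on this optimizer instance only; other optimizers stay untouched.
    optimizer.step = types.MethodType(
        patch_step(optimizer.step, lambda: print('mask update')), optimizer)

    param.grad = torch.ones(2)
    optimizer.step()  # runs the SGD update, then prints 'mask update'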
16 changes: 3 additions & 13 deletions src/sdk/pynni/nni/compression/torch/pruners.py
@@ -16,16 +16,14 @@ class LevelPruner(Pruner):
     Prune to an exact pruning level specification
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
         """
         Parameters
         ----------
         model : torch.nn.module
             Model to be pruned
         config_list : list
             List on pruning configs
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -80,13 +78,9 @@ def __init__(self, model, config_list, optimizer):
             Model to be pruned
         config_list : list
             List on pruning configs
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         """
 
         super().__init__(model, config_list, optimizer)
-        assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
-
         self.now_epoch = 0
         self.set_wrappers_attribute("if_calculated", False)
 
@@ -182,17 +176,13 @@ class SlimPruner(Pruner):
     https://arxiv.org/pdf/1708.06519.pdf
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
         """
         Parameters
         ----------
         model : torch.nn.module
             Model to be pruned
         config_list : list
-            support key for each list item:
-            - sparsity: percentage of convolutional filters to be pruned.
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -254,7 +244,7 @@ class LotteryTicketPruner(Pruner):
     5. Repeat step 2, 3, and 4.
     """
 
-    def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True):
+    def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
         """
         Parameters
         ----------
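The same contract now applies to every pruner in this file: the optimizer is positional, and mask updates ride on optimizer.step(). A minimal sketch with LevelPruner; the config values are illustrative:

    import torch.nn as nn
    import torch.optim as optim
    from nni.compression.torch import LevelPruner

    model = nn.Linear(10, 10)
    optimizer = optim.SGD(model.parameters(), lr=0.1)

    # 'default' op_type targets all supported layers; sparsity is the pruned fraction.
    config_list = [{'sparsity': 0.8, 'op_types': ['default']}]
    pruner = LevelPruner(model, config_list, optimizer)
    model = pruner.compress()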
15 changes: 3 additions & 12 deletions src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
@@ -15,7 +15,7 @@ class WeightRankFilterPruner(Pruner):
     importance criterion in convolution layers to achieve a preset level of network sparsity.
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
         """
         Parameters
         ----------
@@ -24,8 +24,6 @@ def __init__(self, model, config_list, optimizer=None):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -85,7 +83,7 @@ class L1FilterPruner(WeightRankFilterPruner):
     https://arxiv.org/abs/1608.08710
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
         """
         Parameters
         ----------
@@ -94,8 +92,6 @@ def __init__(self, model, config_list, optimizer=None):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -135,7 +131,7 @@ class L2FilterPruner(WeightRankFilterPruner):
     smallest L2 norm of the weights.
     """
 
-    def __init__(self, model, config_list, optimizer=None):
+    def __init__(self, model, config_list, optimizer):
         """
         Parameters
         ----------
@@ -144,8 +140,6 @@ def __init__(self, model, config_list, optimizer=None):
         config_list : list
             support key for each list item:
             - sparsity: percentage of convolutional filters to be pruned.
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         """
 
         super().__init__(model, config_list, optimizer)
@@ -193,11 +187,8 @@ def __init__(self, model, config_list, optimizer):
         config_list: list
             support key for each list item:
            - sparsity: percentage of convolutional filters to be pruned.
-        optimizer: torch.optim.Optimizer
-            Optimizer used to train model
         """
         super().__init__(model, config_list, optimizer)
-        assert isinstance(optimizer, torch.optim.Optimizer), "FPGM pruner is an iterative pruner, please pass optimizer of the model to it"
 
     def get_mask(self, base_mask, weight, num_prune):
         """
src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py
@@ -4,7 +4,6 @@
 import logging
 import datetime
 from nni.assessor import Assessor, AssessResult
-from nni.utils import extract_scalar_history
 from .model_factory import CurveModel
 
 logger = logging.getLogger('curvefitting_Assessor')
@@ -92,11 +91,10 @@ def assess_trial(self, trial_job_id, trial_history):
         Exception
             unrecognize exception in curvefitting_assessor
         """
-        scalar_trial_history = extract_scalar_history(trial_history)
-        self.trial_history = scalar_trial_history
+        self.trial_history = trial_history
         if not self.set_best_performance:
             return AssessResult.Good
-        curr_step = len(scalar_trial_history)
+        curr_step = len(trial_history)
         if curr_step < self.start_step:
             return AssessResult.Good
 
@@ -108,7 +106,7 @@ def assess_trial(self, trial_job_id, trial_history):
         start_time = datetime.datetime.now()
         # Predict the final result
         curvemodel = CurveModel(self.target_pos)
-        predict_y = curvemodel.predict(scalar_trial_history)
+        predict_y = curvemodel.predict(trial_history)
         logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
         if predict_y is None:
             logger.info('wait for more information to predict precisely')
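With extract_scalar_history gone, assess_trial relies on trial_history already being a flat list of numbers; the scalar extraction is assumed to happen before the assessor is called. A sketch of the contract this code now expects, with illustrative values:

    # What assess_trial now receives: plain scalars, ready for len() and predict().
    trial_history = [0.52, 0.61, 0.67, 0.71]
    curr_step = len(trial_history)

    # What the removed helper used to flatten, roughly:
    raw_history = [{'default': 0.52}, {'default': 0.61}]
    scalars = [ele['default'] for ele in raw_history]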
17 changes: 12 additions & 5 deletions src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
@@ -3,7 +3,6 @@
 
 import logging
 from nni.assessor import Assessor, AssessResult
-from nni.utils import extract_scalar_history
 
 logger = logging.getLogger('medianstop_Assessor')
 
@@ -92,12 +91,20 @@ def assess_trial(self, trial_job_id, trial_history):
         if curr_step < self._start_step:
             return AssessResult.Good
 
-        scalar_trial_history = extract_scalar_history(trial_history)
-        self._update_data(trial_job_id, scalar_trial_history)
+        try:
+            num_trial_history = [float(ele) for ele in trial_history]
+        except (TypeError, ValueError) as error:
+            logger.warning('incorrect data type or value:')
+            logger.exception(error)
+        except Exception as error:
+            logger.warning('unrecognized exception in medianstop_assessor:')
+            logger.exception(error)
+
+        self._update_data(trial_job_id, num_trial_history)
         if self._high_better:
-            best_history = max(scalar_trial_history)
+            best_history = max(trial_history)
         else:
-            best_history = min(scalar_trial_history)
+            best_history = min(trial_history)
 
         avg_array = []
         for id_ in self._completed_avg_history:
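Note that on a conversion failure the except branches only log: num_trial_history is then unbound when _update_data runs, and best_history is taken from the raw trial_history either way. A more defensive variant of the same conversion, a sketch rather than part of this PR, bails out early instead:

    import logging

    logger = logging.getLogger('medianstop_Assessor')

    def to_scalars(trial_history):
        # Convert intermediate results to floats; None signals bad data.
        try:
            return [float(ele) for ele in trial_history]
        except (TypeError, ValueError):
            logger.warning('incorrect data type or value in trial history')
            return None

    print(to_scalars(['0.5', 0.6]))      # [0.5, 0.6]
    print(to_scalars([{'default': 1}]))  # None: a dict is not float-convertible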
1 change: 0 additions & 1 deletion src/sdk/pynni/nni/msg_dispatcher.py
@@ -234,5 +234,4 @@ def _earlystop_notify_tuner(self, data):
     if multi_thread_enabled():
         self._handle_final_metric_data(data)
     else:
-        data['value'] = to_json(data['value'])
         self.enqueue_command(CommandType.ReportMetricData, data)
31 changes: 1 addition & 30 deletions src/sdk/pynni/nni/utils.py
@@ -62,13 +62,6 @@ def extract_scalar_reward(value, scalar_key='default'):
     """
     Extract scalar reward from trial result.
 
-    Parameters
-    ----------
-    value : int, float, dict
-        the reported final metric data
-    scalar_key : str
-        the key name that indicates the numeric number
-
     Raises
     ------
     RuntimeError
@@ -85,26 +78,6 @@ def extract_scalar_reward(value, scalar_key='default'):
     return reward
 
 
-def extract_scalar_history(trial_history, scalar_key='default'):
-    """
-    Extract scalar value from a list of intermediate results.
-
-    Parameters
-    ----------
-    trial_history : list
-        accumulated intermediate results of a trial
-    scalar_key : str
-        the key name that indicates the numeric number
-
-    Raises
-    ------
-    RuntimeError
-        Incorrect final result: the final result should be float/int,
-        or a dict which has a key named "default" whose value is float/int.
-    """
-    return [extract_scalar_reward(ele, scalar_key) for ele in trial_history]
-
-
 def convert_dict2tuple(value):
     """
     convert dict type to tuple to solve unhashable problem.
@@ -117,9 +90,7 @@ def convert_dict2tuple(value):
 
 
 def init_dispatcher_logger():
-    """
-    Initialize dispatcher logging configuration
-    """
+    """ Initialize dispatcher logging configuration"""
     logger_file_path = 'dispatcher.log'
     if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
         logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
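extract_scalar_reward survives the cleanup and remains available for final results. A rough re-statement of its documented behavior, illustrative and matching the docstring above:

    def extract_scalar_reward(value, scalar_key='default'):
        if isinstance(value, (float, int)):
            return value
        if isinstance(value, dict) and isinstance(value.get(scalar_key), (float, int)):
            return value[scalar_key]
        raise RuntimeError('Incorrect final result: the final result should be '
                           'float/int, or a dict which has a key named "default" '
                           'whose value is float/int.')

    assert extract_scalar_reward(0.93) == 0.93
    assert extract_scalar_reward({'default': 0.93, 'loss': 0.2}) == 0.93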