
Resolve comments in nn-meter branch #3983

Merged
merged 1 commit into from Jul 26, 2021
6 changes: 3 additions & 3 deletions docs/en_US/NAS/HardwareAwareNAS.rst
@@ -32,10 +32,10 @@ To support latency-aware NAS, you first need a `Strategy` that supports filterin
``LatencyFilter`` will predict the models' latency using nn-Meter and filter out the models whose latency is larger than the threshold (i.e., ``100`` in this example).
You can also build your own strategies and filters to support more flexible NAS, such as sorting the models according to latency.
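
For example, a custom filter can be any callable that takes a candidate model and returns whether to keep it. A minimal sketch follows; ``load_latency_predictor`` and the predictor name are nn-Meter's public API, but the exact filter interface assumed here (a plain callable on the candidate) is an illustration, not this PR's definition:

.. code-block:: python

    from nn_meter import load_latency_predictor

    class MyLatencyFilter:
        def __init__(self, threshold, hardware='cortexA76cpu_tflite21'):
            self.threshold = threshold          # latency budget in ms
            self.predictor = load_latency_predictor(hardware)

        def __call__(self, model):
            # keep the candidate only if its predicted latency fits the budget
            return self.predictor.predict(model, model_type='nni-ir') <= self.threshold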

-Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, example_inputs=example_inputs``:
+Then, pass this strategy to ``RetiariiExperiment`` along with some additional arguments: ``parse_shape=True, dummy_input=dummy_input``:

.. code-block:: python

-    RetiariiExperiment(base_model, trainer, [], simple_strategy, True, example_inputs)
+    RetiariiExperiment(base_model, trainer, [], simple_strategy, True, dummy_input)

-Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``example_inputs`` is required for tracing shape info.
+Here, ``parse_shape=True`` means extracting shape info from the torch model as it is required by nn-Meter to predict latency. ``dummy_input`` is required for tracing shape info.
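
For reference, the experiment config only carries the input *shape* (full tensors are not supported in configs, per the FIXME in ``preprocess_model`` below); it is expanded into a random tensor before tracing. A simplified sketch of that flow, reusing the names that appear in this PR:

.. code-block:: python

    import torch

    shape = exp_config.dummy_input                      # e.g. [1, 3, 32, 32]
    dummy_input = torch.randn(*shape)                   # materialize a random input tensor
    traced = torch.jit.trace(base_model, dummy_input)   # shape info is read off this trace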
2 changes: 1 addition & 1 deletion examples/nas/oneshot/spos/multi_trial.py
@@ -185,7 +185,7 @@ def _main(port):
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = False
    exp_config.execution_engine = 'base'
-    exp_config.example_inputs = [1, 3, 32, 32]
+    exp_config.dummy_input = [1, 3, 32, 32]

    exp.run(exp_config, port)

12 changes: 6 additions & 6 deletions nni/retiarii/converter/graph_gen.py
@@ -683,13 +683,13 @@ class GraphConverterWithShape(GraphConverter):
    If the forward path of a candidate depends on the input data, the wrong path will be traced.
    This will result in incomplete shape info.
    """
-    def convert_module(self, script_module, module, module_name, ir_model, example_inputs):
+    def convert_module(self, script_module, module, module_name, ir_model, dummy_input):
        module.eval()

        ir_graph, attrs = self._convert_module(script_module, module, module_name, ir_model)
        self.remove_dummy_nodes(ir_model)
        self._initialize_parameters(ir_model)
-        self._trace_module(module, module_name, ir_model, example_inputs)
+        self._trace_module(module, module_name, ir_model, dummy_input)
        return ir_graph, attrs

    def _initialize_parameters(self, ir_model: 'Model'):
@@ -699,9 +699,9 @@ def _initialize_parameters(self, ir_model: 'Model'):
            ir_node.operation.parameters.setdefault('input_shape', [])
            ir_node.operation.parameters.setdefault('output_shape', [])

-    def _trace_module(self, module, module_name, ir_model: 'Model', example_inputs):
+    def _trace_module(self, module, module_name, ir_model: 'Model', dummy_input):
        # First, trace the whole graph
-        tm_graph = self._trace(module, example_inputs)
+        tm_graph = self._trace(module, dummy_input)

        for node in tm_graph.nodes():
            parameters = _extract_info_from_trace_node(node)
@@ -832,8 +832,8 @@ def _flatten(graph: 'Graph'):
        # remove subgraphs
        ir_model.graphs = {ir_model._root_graph_name: ir_model.root_graph}

-    def _trace(self, module, example_inputs):
-        traced_module = torch.jit.trace(module, example_inputs)
+    def _trace(self, module, dummy_input):
+        traced_module = torch.jit.trace(module, dummy_input)
        torch._C._jit_pass_inline(traced_module.graph)
        return traced_module.graph

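The docstring caveat above (data-dependent forward paths yield a wrong trace) is the standard ``torch.jit.trace`` limitation; a minimal, self-contained illustration with a hypothetical module, not part of this PR:

.. code-block:: python

    import torch

    class Branchy(torch.nn.Module):
        def forward(self, x):
            # Data-dependent control flow: tracing records only the branch
            # taken for the example input, so the other path's shape info is lost.
            if x.sum() > 0:
                return x * 2
            return x - 1

    traced = torch.jit.trace(Branchy(), torch.ones(2, 2))
    print(traced(-torch.ones(2, 2)))  # still multiplies by 2: the traced path is fixed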
12 changes: 6 additions & 6 deletions nni/retiarii/experiment/pytorch.py
@@ -60,7 +60,7 @@ class RetiariiExeConfig(ConfigBase):
    execution_engine: str = 'py'

    # input used in GraphConverterWithShape. Currently supports a shape tuple only.
-    example_inputs: Optional[List[int]] = None
+    dummy_input: Optional[List[int]] = None

    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
@@ -110,19 +110,19 @@ def _validation_rules(self):
            'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
        }

-def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, example_inputs=None):
+def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_input=None):
    # TODO: this logic might need to be refactored into execution engine
    if full_ir:
        try:
            script_module = torch.jit.script(base_model)
        except Exception as e:
            _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
            raise e
-        if example_inputs is not None:
+        if dummy_input is not None:
            # FIXME: this is a workaround as full tensor is not supported in configs
-            example_inputs = torch.randn(*example_inputs)
+            dummy_input = torch.randn(*dummy_input)
            converter = GraphConverterWithShape()
-            base_model_ir = convert_to_graph(script_module, base_model, converter, example_inputs=example_inputs)
+            base_model_ir = convert_to_graph(script_module, base_model, converter, dummy_input=dummy_input)
        else:
            base_model_ir = convert_to_graph(script_module, base_model)
    # handle inline mutations
@@ -182,7 +182,7 @@ def __init__(self, base_model: nn.Module, trainer: Union[Evaluator, BaseOneShotT
    def _start_strategy(self):
        base_model_ir, self.applied_mutators = preprocess_model(
            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py',
-            example_inputs=self.config.example_inputs)
+            dummy_input=self.config.dummy_input)

        _logger.info('Start strategy...')
        self.strategy.run(base_model_ir, self.applied_mutators)
2 changes: 1 addition & 1 deletion test/ut/retiarii/convert_mixin.py
@@ -15,5 +15,5 @@ class ConvertWithShapeMixin:
    @staticmethod
    def _convert_model(model, input):
        script_module = torch.jit.script(model)
-        model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), example_inputs=input)
+        model_ir = convert_to_graph(script_module, model, converter=GraphConverterWithShape(), dummy_input=input)
        return model_ir
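
For context, a test built on this mixin would pass a real tensor as ``input``, since it is handed to ``torch.jit.trace`` downstream. A hypothetical usage sketch (the ``Net`` module is invented for illustration):

.. code-block:: python

    import torch
    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)

        def forward(self, x):
            return self.conv(x)

    # shape-annotated IR graph for the model, traced with a dummy tensor
    model_ir = ConvertWithShapeMixin._convert_model(Net(), torch.randn(1, 3, 32, 32))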