Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Retiarii graph and code generation test #3231

Merged
merged 8 commits into from
Jan 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nni/retiarii/converter/op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class OpTypeName(str, Enum):
'aten::size': 'Size',
'aten::view': 'View',
'aten::eq': 'Eq',
'aten::Bool': 'Bool',
'aten::empty': 'Empty',
'aten::zeros': 'Zeros',
'aten::chunk': 'Chunk',
'aten::add_': 'Add_' # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}

Expand Down
7 changes: 5 additions & 2 deletions nni/retiarii/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
return f'{output} = {value}'
elif self.type == 'prim::ListConstruct':
return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'prim::GetAttr':
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
elif self.type == 'aten::mean':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
elif self.type == 'aten::__getitem__':
Expand All @@ -133,8 +135,7 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
elif self.type == 'aten::add':
assert len(inputs) == 2
return f'{output} = {inputs[0]} + {inputs[1]}'
return f'{output} = ' + ' + '.join(inputs)
elif self.type == OpTypeName.MergedSlice:
assert (len(inputs) - 1) % 4 == 0
slices = []
Expand All @@ -151,6 +152,8 @@ def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
return f'{output} = {inputs[0]}.view({inputs[1]})'
elif self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation')
elif self.type == 'aten::Bool':
return f'{output} = bool({inputs[0]})'
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')

Expand Down
7 changes: 6 additions & 1 deletion nni/retiarii/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def get_records():
return _records


def clear_records():
global _records
_records = {}


def add_record(key, value):
"""
"""
Expand Down Expand Up @@ -56,7 +61,7 @@ def __init__(self, *args, **kwargs):
# eject un-serializable arguments
for k in list(full_args.keys()):
# The list is not complete and does not support nested cases.
if not isinstance(full_args[k], (int, float, str, dict, list)):
if not isinstance(full_args[k], (int, float, str, dict, list, tuple)):
if not (register_format == 'full' and k == 'model'):
# no warning if it is base model in trainer
warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
Expand Down
Loading