Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Dev nas compile -- code generator #1067

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion src/sdk/pynni/nni/smartparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
'qnormal',
'lognormal',
'qlognormal',
'function_choice'
'function_choice',
'mutable_layer'
]


Expand Down Expand Up @@ -115,6 +116,52 @@ def qlognormal(mu, sigma, q, name=None, key=None):
def function_choice(funcs, name=None, key=None):
return funcs[_get_param(key)]()

def mutable_layer(
mutable_id,
mutable_layer_id,
funcs,
funcs_args,
fixed_inputs,
optional_inputs,
optional_input_size=0,
debug_mode=False,
name=None,
key=None):
'''execute the chosen function and inputs.

Below is an example of chosen function and inputs:
{
"mutable_id": {
"mutable_layer_id": {
"chosen_layer": "pool",
"chosen_inputs": ["out1", "out3"]
}
}
}

Parameters:
---------------
mutable_id: the name of this mutable_layer block (which could have multiple mutable layers)
mutable_layer_id: the name of a mutable layer in this block
funcs: dict of function alls
funcs_args:
fixed_inputs:
optional_inputs: dict of optional inputs
optional_input_size: number of candidate inputs to be chosen
debug_mode: if True, every choice will use the first one(s)
'''
if debug_mode:
# TODO:
return

mutable_block = _get_param(mutable_id)
chosen_layer = mutable_block[mutable_layer_id]["chosen_layer"]
chosen_inputs = mutable_block[mutable_layer_id]["chosen_inputs"]
real_chosen_inputs = [optional_inputs[input_name] for input_name in chosen_inputs]
layer_out = funcs[chosen_layer]([fixed_inputs, real_chosen_inputs], *funcs_args[chosen_layer])

return layer_out

def _get_param(key):
if trial._params is None:
trial.get_next_parameter()
Expand Down
90 changes: 90 additions & 0 deletions tools/nni_annotation/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,93 @@

# pylint: disable=unidiomatic-typecheck

def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
code: annotation string (excluding '@')
"""
module = ast.parse(code)
assert type(module) is ast.Module, 'internal error #1'
assert len(module.body) == 1, 'Annotation mutable_layers contains more than one expression'
assert type(module.body[0]) is ast.Expr, 'Annotation is not expression'
call = module.body[0].value
nodes = []
mutable_id = 'mutable_block_' + str(lineno)
mutable_layer_cnt = 0
for arg in call.args:
fields = {'layer_choice': False,
'fixed_inputs': False,
'optional_inputs': False,
'optional_input_size': False,
'layer_output': False}
for k, value in zip(arg.keys, arg.values):
if k.id == 'layer_choice':
assert not fields['layer_choice'], 'Duplicated field: layer_choice'
assert type(value) is ast.List, 'Value of layer_choice should be a list'
call_funcs_keys = []
call_funcs_values = []
call_kwargs_values = []
for call in value.elts:
assert type(call) is ast.Call, 'Element in layer_choice should be function call'
call_name = astor.to_source(call).strip()
call_funcs_keys.append(ast.Str(s=call_name))
call_funcs_values.append(call.func)
assert not call.args, 'Number of args without keyword should be zero'
kw_args = []
kw_values = []
for kw in call.keywords:
kw_args.append(kw.arg)
kw_values.append(kw.value)
call_kwargs_values.append(ast.Dict(keys=kw_args, values=kw_values))
call_funcs = ast.Dict(keys=call_funcs_keys, values=call_funcs_values)
call_kwargs = ast.Dict(keys=call_funcs_keys, values=call_kwargs_values)
fields['layer_choice'] = True
elif k.id == 'fixed_inputs':
assert not fields['fixed_inputs'], 'Duplicated field: fixed_inputs'
assert type(value) is ast.List, 'Value of fixed_inputs should be a list'
fixed_inputs = value
fields['fixed_inputs'] = True
elif k.id == 'optional_inputs':
assert not fields['optional_inputs'], 'Duplicated field: optional_inputs'
assert type(value) is ast.List, 'Value of optional_inputs should be a list'
optional_inputs = value
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num, 'Value of optional_input_size should be a number'
optional_input_size = value
fields['optional_input_size'] = True
elif k.id == 'layer_output':
assert not fields['layer_output'], 'Duplicated field: layer_output'
assert type(value) is ast.Name, 'Value of layer_output should be ast.Name type'
layer_output = value
fields['layer_output'] = True
else:
raise AssertionError('Unexpected field in mutable layer')
# make call for this mutable layer
assert fields['layer_choice'], 'layer_choice must exist'
assert fields['layer_output'], 'layer_output must exist'
mutable_layer_id = 'mutable_layer_' + str(mutable_layer_cnt)
mutable_layer_cnt += 1
target_call_attr = ast.Attribute(value=ast.Name(id='nni', ctx=ast.Load()), attr='mutable_layer', ctx=ast.Load())
target_call_args = [ast.Str(s=mutable_id),
ast.Str(s=mutable_layer_id),
call_funcs,
call_kwargs]
if fields['fixed_inputs']:
target_call_args.append(fixed_inputs)
else:
target_call_args.append(ast.NameConstant(value=None))
if fields['optional_inputs']:
target_call_args.append(optional_inputs)
assert fields['optional_input_size'], 'optional_input_size must exist when optional_inputs exists'
target_call_args.append(optional_input_size)
else:
target_call_args.append(ast.NameConstant(value=None))
target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
node = ast.Assign(targets=[layer_output], value=target_call)
nodes.append(node)
return nodes

def parse_annotation(code):
"""Parse an annotation string.
Expand Down Expand Up @@ -235,6 +322,9 @@ def _visit_string(self, node):
or string.startswith('@nni.get_next_parameter('):
return parse_annotation(string[1:]) # expand annotation string to code

if string.startswith('@nni.mutable_layers('):
return parse_annotation_mutable_layers(string[1:], node.lineno)

if string.startswith('@nni.variable(') \
or string.startswith('@nni.function_choice('):
self.stack[-1] = string[1:] # mark that the next expression is annotated
Expand Down
57 changes: 57 additions & 0 deletions tools/nni_annotation/testcase/mutable_layer_usercode/simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import time

def add_one(inputs):
return inputs + 1

def add_two(inputs):
return inputs + 2

def add_three(inputs):
return inputs + 3

def add_four(inputs):
return inputs + 4


def main():

images = 5

"""@nni.mutable_layers(
{
layer_choice: [add_one(), add_two(), add_three(), add_four()],
optional_inputs: [images],
optional_input_size: 1,
layer_output: layer_1_out
},
{
layer_choice: [add_one(), add_two(), add_three(), add_four()],
optional_inputs: [layer_1_out],
optional_input_size: 1,
layer_output: layer_2_out
},
{
layer_choice: [add_one(), add_two(), add_three(), add_four()],
optional_inputs: [layer_1_out, layer_2_out],
optional_input_size: 1,
layer_output: layer_3_out
}
)"""

"""@nni.report_intermediate_result(layer_1_out)"""
time.sleep(2)
"""@nni.report_intermediate_result(layer_2_out)"""
time.sleep(2)
"""@nni.report_intermediate_result(layer_3_out)"""
time.sleep(2)

layer_3_out = layer_3_out + 10

"""@nni.report_final_result(layer_3_out)"""

if __name__ == '__main__':
try:
main()
except Exception as exception:
print(exception)
raise