[Retiarii] refactor based on the new launch approach #3185
Conversation
@@ -52,7 +52,7 @@ def __init__(self, experiment_id: str):

     def connect(self) -> BufferedIOBase:
         conn, _ = self._socket.accept()
-        self.file = conn.makefile('w+b')
+        self.file = conn.makefile('rwb')
Why is this?
See #3183; @liuzhe-lz can answer this question.
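For context, a likely reason (my reading, not confirmed in this thread): unlike the builtin open(), socket.makefile() only accepts combinations of 'r', 'w' and 'b' in its mode string, so 'w+b' raises ValueError, while 'rwb' yields a buffered read/write binary stream. A minimal sketch:

    import socket

    a, b = socket.socketpair()
    try:
        a.makefile('w+b')    # rejected: socket.makefile() allows only 'r', 'w', 'b'
    except ValueError as e:
        print(e)             # invalid mode 'w+b' (only r, w, b allowed)
    f = a.makefile('rwb')    # buffered read/write binary stream over the socket
    f.write(b'ping')
    f.flush()
    print(b.recv(4))         # b'ping'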
while True:
    time.sleep(10)
    status = self.get_status()
    if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
What if the experiment is done?
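A sketch of the point being raised, assuming get_status() reports 'DONE' when the experiment finishes (the exact status string is an assumption, not code from this PR):

    while True:
        time.sleep(10)
        status = self.get_status()
        # hypothetical: also treat a finished experiment as terminal,
        # not only error/stop states
        if status in ['DONE', 'ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
            break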
if _records is not None:
    assert name not in _records, '{} already in _records'.format(name)
    _records[name] = value

__all__ = [
Is this nn.__all__ plus a few others?
Yes, it is nn.__all__ plus a few of ours, for example Placeholder.
BatchNorm2d = wrap_module(nn.BatchNorm2d)
ReLU = wrap_module(nn.ReLU)
Dropout = wrap_module(nn.Dropout)
Linear = wrap_module(nn.Linear)
MaxPool2d = wrap_module(nn.MaxPool2d)
AvgPool2d = wrap_module(nn.AvgPool2d)
Identity = wrap_module(nn.Identity)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
Can we refactor it into a for loop?
I think this way is clearer and easier to maintain.
Also, not all the modules are included, for example the loss modules, and different versions of PyTorch have different members in torch.nn. I will handle different versions later.
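For reference, the loop the reviewer has in mind might look like the sketch below; the name list is illustrative, and the hasattr guard addresses the version-dependence concern (both are assumptions, not code from this PR):

    _WRAPPED_MODULES = ['BatchNorm2d', 'ReLU', 'Dropout', 'Linear',
                        'MaxPool2d', 'AvgPool2d', 'Identity', 'AdaptiveAvgPool2d']
    for _name in _WRAPPED_MODULES:
        if hasattr(nn, _name):  # skip members missing in older torch versions
            globals()[_name] = wrap_module(getattr(nn, _name))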
nni/retiarii/converter/graph_gen.py (outdated diff)
@@ -463,6 +466,9 @@ def convert_module(script_module, module, module_name, ir_model):

     ir_graph._register()
+
+    if id(module) not in modules_arg:
+        raise RuntimeError(f'{original_type_name} arguments are not recorded, \
+you may forget to decorate this class with @register_module()')
Nit: "you may forget" should be "you might have forgotten".
Fixed.
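To illustrate the check above: a hypothetical user-defined module decorated so that its constructor arguments get recorded and convert_module() can find them in modules_arg. The decorator name comes from the error message in the diff; its exact signature and this example class are assumptions:

    @register_module()
    class MyBlock(nn.Module):  # hypothetical user module
        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3)

        def forward(self, x):
            return self.conv(x)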