Skip to content

Commit

Permalink
Fix prioritiezd board update and rand_parameters for cream nas (micro…
Browse files Browse the repository at this point in the history
  • Loading branch information
alibaba-yiwuyao authored and Hao Ni committed Apr 7, 2021
1 parent fed9b58 commit aa42ddd
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
10 changes: 5 additions & 5 deletions examples/nas/cream/lib/models/structures/supernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@ def rand_parameters(self, architecture, meta=False):
yield param

if not meta:
for layer, layer_arch in zip(self.blocks, architecture):
for blocks, arch in zip(layer, layer_arch):
if arch == -1:
for choice_blocks, choice_name in zip(self.blocks, architecture):
choice_sample = architecture[choice_name]
for block, arch in zip(choice_blocks, choice_sample):
if not arch:
continue
for name, param in blocks[arch].named_parameters(
recurse=True):
for name, param in block.named_parameters(recurse=True):
yield param


Expand Down
4 changes: 1 addition & 3 deletions nni/algorithms/nas/pytorch/cream/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flop
(val_prec1,
prec1,
flops,
self.current_teacher_arch,
self.current_student_arch,
training_data,
torch.nn.functional.softmax(
features,
Expand All @@ -174,8 +174,6 @@ def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flop
self.prioritized_board, reverse=True)

if len(self.prioritized_board) > self.pool_size:
self.prioritized_board = sorted(
self.prioritized_board, reverse=True)
del self.prioritized_board[-1]

# only update student network weights
Expand Down

0 comments on commit aa42ddd

Please sign in to comment.