Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Udpated concat axis to match image_data_format in keras (#2946)
Browse files Browse the repository at this point in the history
  • Loading branch information
code-fury authored Oct 19, 2020
1 parent 58873c4 commit 3ffd105
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/sdk/pynni/nni/nas/tensorflow/mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == 'mean':
return sum(tensor_list) / len(tensor_list)
if reduction_type == 'concat':
return tf.concat(tensor_list, axis=0)
image_data_format = tf.keras.backend.image_data_format()
if image_data_format == "channels_first":
axis = 0
else:
axis = -1
return tf.concat(tensor_list, axis=axis)
raise ValueError('Unrecognized reduction policy: "{}'.format(reduction_type))

def _get_decision(self, mutable):
Expand Down

0 comments on commit 3ffd105

Please sign in to comment.