Skip to content

Commit

Permalink
Unify raw conv implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
hunse committed Aug 11, 2021
1 parent e1fabb0 commit 424f197
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,20 +879,16 @@ def _raw_convolution(self, u):
seq_len = tf.shape(u)[1]
ir_len = self.impulse_response.shape[0]

if self.conv_mode == "raw_nchw": # pragma: no cover
u = tf.expand_dims(u, 1)
padding = [[0, 0], [0, 0], [ir_len - 1, 0], [0, 0]]
m = tf.nn.conv2d(
u, self.impulse_response, strides=1, data_format="NCHW", padding=padding
)
channels_last = self.conv_mode != "raw_nchw"
u = tf.expand_dims(u, -1 if channels_last else 1)
padding = [[0, 0], [0, 0], [0, 0], [0, 0]]
padding[1 if channels_last else 2] = [ir_len - 1, 0]
fmt = "NHWC" if channels_last else "NCHW"
m = tf.nn.conv2d(
u, self.impulse_response, strides=1, data_format=fmt, padding=padding
)
if not channels_last:
m = tf.transpose(m, perm=[0, 2, 3, 1])
else:
u = tf.expand_dims(u, -1)
padding = [[0, 0], [ir_len - 1, 0], [0, 0], [0, 0]]
m = tf.nn.conv2d(
u, self.impulse_response, strides=1, data_format="NHWC", padding=padding
)

m = tf.reshape(m, (-1, seq_len, self.memory_d * self.order))
return m

Expand Down

0 comments on commit 424f197

Please sign in to comment.