Skip to content

Commit

Permalink
Switch to raw convolution along H dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
hunse committed Aug 11, 2021
1 parent 13b0439 commit e1fabb0
Showing 1 changed file with 8 additions and 14 deletions.
22 changes: 8 additions & 14 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,9 +777,9 @@ def build(self, input_shape):

self.impulse_response = tf.reshape(
self.impulse_response,
(1, self.impulse_response.shape[0], 1, self.order),
(self.impulse_response.shape[0], 1, 1, self.order),
)
self.impulse_response = self.impulse_response[:, ::-1, :, :]
self.impulse_response = self.impulse_response[::-1, :, :, :]

if self.kernel_initializer is not None:
self.kernel = self.add_weight(
Expand Down Expand Up @@ -877,29 +877,23 @@ def _fft_convolution(self, u):

def _raw_convolution(self, u):
seq_len = tf.shape(u)[1]
ir_len = self.impulse_response.shape[1]

# it's more efficient to do convolution along the W dimension, so move
# signal dimension to W (`u` shape will now be `(batch, memory_d, timesteps)`)
u = tf.transpose(u, perm=[0, 2, 1])
ir_len = self.impulse_response.shape[0]

if self.conv_mode == "raw_nchw": # pragma: no cover
u = tf.reshape(u, (-1, 1, 1, seq_len)) # combine batch and memory_d axes
padding = [[0, 0], [0, 0], [0, 0], [ir_len - 1, 0]]
u = tf.expand_dims(u, 1)
padding = [[0, 0], [0, 0], [ir_len - 1, 0], [0, 0]]
m = tf.nn.conv2d(
u, self.impulse_response, strides=1, data_format="NCHW", padding=padding
)
m = tf.reshape(m, (-1, self.memory_d * self.order, seq_len))
m = tf.transpose(m, perm=[0, 2, 1])
m = tf.transpose(m, perm=[0, 2, 3, 1])
else:
u = tf.expand_dims(u, -1)
padding = [[0, 0], [0, 0], [ir_len - 1, 0], [0, 0]]
padding = [[0, 0], [ir_len - 1, 0], [0, 0], [0, 0]]
m = tf.nn.conv2d(
u, self.impulse_response, strides=1, data_format="NHWC", padding=padding
)
m = tf.transpose(m, perm=[0, 2, 1, 3])
m = tf.reshape(m, (-1, seq_len, self.memory_d * self.order))

m = tf.reshape(m, (-1, seq_len, self.memory_d * self.order))
return m

def get_config(self):
Expand Down

0 comments on commit e1fabb0

Please sign in to comment.