diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py
index 2c0d5c72..9e4c3918 100644
--- a/keras_lmu/layers.py
+++ b/keras_lmu/layers.py
@@ -777,9 +777,9 @@ def build(self, input_shape):
 
         self.impulse_response = tf.reshape(
             self.impulse_response,
-            (1, self.impulse_response.shape[0], 1, self.order),
+            (self.impulse_response.shape[0], 1, 1, self.order),
         )
-        self.impulse_response = self.impulse_response[:, ::-1, :, :]
+        self.impulse_response = self.impulse_response[::-1, :, :, :]
 
         if self.kernel_initializer is not None:
             self.kernel = self.add_weight(
@@ -877,29 +877,23 @@ def _fft_convolution(self, u):
 
     def _raw_convolution(self, u):
         seq_len = tf.shape(u)[1]
-        ir_len = self.impulse_response.shape[1]
-
-        # it's more efficient to do convolution along the W dimension, so move
-        # signal dimension to W (`u` shape will now be `(batch, memory_d, timesteps)`)
-        u = tf.transpose(u, perm=[0, 2, 1])
+        ir_len = self.impulse_response.shape[0]
 
         if self.conv_mode == "raw_nchw":  # pragma: no cover
-            u = tf.reshape(u, (-1, 1, 1, seq_len))  # combine batch and memory_d axes
-            padding = [[0, 0], [0, 0], [0, 0], [ir_len - 1, 0]]
+            u = tf.expand_dims(u, 1)
+            padding = [[0, 0], [0, 0], [ir_len - 1, 0], [0, 0]]
             m = tf.nn.conv2d(
                 u, self.impulse_response, strides=1, data_format="NCHW", padding=padding
             )
-            m = tf.reshape(m, (-1, self.memory_d * self.order, seq_len))
-            m = tf.transpose(m, perm=[0, 2, 1])
+            m = tf.transpose(m, perm=[0, 2, 3, 1])
         else:
             u = tf.expand_dims(u, -1)
-            padding = [[0, 0], [0, 0], [ir_len - 1, 0], [0, 0]]
+            padding = [[0, 0], [ir_len - 1, 0], [0, 0], [0, 0]]
             m = tf.nn.conv2d(
                 u, self.impulse_response, strides=1, data_format="NHWC", padding=padding
            )
-            m = tf.transpose(m, perm=[0, 2, 1, 3])
-            m = tf.reshape(m, (-1, seq_len, self.memory_d * self.order))
+        m = tf.reshape(m, (-1, seq_len, self.memory_d * self.order))
 
         return m
 
     def get_config(self):
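
Reviewer note: the diff should be behavior-preserving. It moves the time axis of the convolution from the W dimension (old layout, filter `(1, ir_len, 1, order)` reversed along axis 1) to the H dimension (new layout, filter `(ir_len, 1, 1, order)` reversed along axis 0), which drops the pre/post transposes in the NHWC path and lets both branches share the final reshape. Below is a minimal equivalence sketch, not part of the PR: the shapes, the `u`/`h` names, and the tolerance are illustrative assumptions.

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes for illustration only.
batch, timesteps, memory_d, order, ir_len = 2, 16, 3, 4, 16

rng = np.random.default_rng(0)
u = tf.constant(rng.standard_normal((batch, timesteps, memory_d)), tf.float32)
h = tf.constant(rng.standard_normal((ir_len, order)), tf.float32)

# New layout: filter (ir_len, 1, 1, order), time-reversed along axis 0,
# convolved along the NHWC height axis with causal (left) padding.
filt_new = tf.reshape(h, (ir_len, 1, 1, order))[::-1, :, :, :]
m_new = tf.nn.conv2d(
    tf.expand_dims(u, -1),  # (batch, timesteps, memory_d, 1)
    filt_new,
    strides=1,
    data_format="NHWC",
    padding=[[0, 0], [ir_len - 1, 0], [0, 0], [0, 0]],
)
m_new = tf.reshape(m_new, (-1, timesteps, memory_d * order))

# Old layout: move time to the width axis, convolve there, then undo
# the transpose, as the pre-PR NHWC branch did.
filt_old = tf.reshape(h, (1, ir_len, 1, order))[:, ::-1, :, :]
u_old = tf.transpose(u, perm=[0, 2, 1])  # (batch, memory_d, timesteps)
m_old = tf.nn.conv2d(
    tf.expand_dims(u_old, -1),  # (batch, memory_d, timesteps, 1)
    filt_old,
    strides=1,
    data_format="NHWC",
    padding=[[0, 0], [0, 0], [ir_len - 1, 0], [0, 0]],
)
m_old = tf.transpose(m_old, perm=[0, 2, 1, 3])  # (batch, timesteps, memory_d, order)
m_old = tf.reshape(m_old, (-1, timesteps, memory_d * order))

# Both layouts compute the same causal convolution along time.
np.testing.assert_allclose(m_new.numpy(), m_old.numpy(), atol=1e-5)
```

The removed comment claimed W-axis convolution was more efficient; the new H-axis layout presumably trades that for avoiding the transposes of `u` and `m` entirely, and it also fixes the NCHW branch to use `tf.expand_dims` rather than collapsing batch and `memory_d` into one axis.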