Skip to content

Commit

Permalink
fixup! Allow memory_d > 1 for LMUFFT
Browse files Browse the repository at this point in the history
  • Loading branch information
hunse committed Jun 15, 2021
1 parent 0339cd4 commit edcdf20
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,18 +654,13 @@ def call(self, inputs, training=None):
# Perform the FFT
fft_input = tf.signal.rfft(u, fft_length=[2 * seq_len], name="input_pad")

# Expand dimensions
fft_input = tf.expand_dims(fft_input, axis=-2)

# Elementwise product of FFT (broadcasting done automatically)
result = fft_input * self.impulse_response
# Elementwise product of FFT (with broadcasting)
result = tf.expand_dims(fft_input, axis=-2) * self.impulse_response

# Inverse FFT
m = tf.signal.irfft(result, fft_length=[2 * seq_len])[..., :seq_len]

# Reshaping
m = tf.reshape(m, (-1, self.order * self.memory_d, seq_len))

m = tf.transpose(m, perm=[0, 2, 1])

# apply hidden cell
Expand Down

0 comments on commit edcdf20

Please sign in to comment.