Skip to content

Commit

Permalink
squash! Allow memory_d > 1 for LMUFFT
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Hunsberger <eric.hunsberger@appliedbrainresearch.com>
  • Loading branch information
hunse committed Jun 15, 2021
1 parent bf14e97 commit 9d61854
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions keras_lmu/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,26 @@ def test_save_load_serialization(mode, tmp_path):

@pytest.mark.parametrize("return_sequences", (True, False))
@pytest.mark.parametrize(
"hidden_cell", (None, tf.keras.layers.Dense(4), tf.keras.layers.SimpleRNNCell(4))
"hidden_cell",
(
lambda: None,
lambda: tf.keras.layers.Dense(4),
lambda: tf.keras.layers.SimpleRNNCell(4),
),
)
def test_fft(return_sequences, hidden_cell, rng):
@pytest.mark.parametrize("memory_d", [1, 4])
def test_fft(return_sequences, hidden_cell, memory_d, rng):
kwargs = dict(memory_d=memory_d, order=2, theta=3, hidden_cell=hidden_cell())

x = rng.uniform(-1, 1, size=(2, 10, 32))

rnn_layer = tf.keras.layers.RNN(
layers.LMUCell(1, 2, 3, hidden_cell),
layers.LMUCell(**kwargs),
return_sequences=return_sequences,
)
rnn_out = rnn_layer(x)

fft_layer = layers.LMUFFT(1, 2, 3, hidden_cell, return_sequences=return_sequences)
fft_layer = layers.LMUFFT(return_sequences=return_sequences, **kwargs)
fft_layer.build(x.shape)
fft_layer.kernel.assign(rnn_layer.cell.kernel)
fft_out = fft_layer(x)
Expand Down Expand Up @@ -423,30 +431,3 @@ def test_fit(fft):
assert isinstance(lmu_layer.layer, tf.keras.layers.RNN)

assert acc == 1.0


def test_multidim_fft(rng):
# Test if LMUFFT with memory_d > 1 works the same way as multi-dimensional LMUCell
memory_d = 4
order = 16
n_steps = 10
input_d = 8

inp = tf.keras.Input(shape=(n_steps, input_d))

kwargs = dict(
memory_d=memory_d,
order=order,
theta=n_steps,
hidden_cell=None,
kernel_initializer="ones",
trainable=False,
)
lmu_cell = tf.keras.layers.RNN(layers.LMUCell(**kwargs), return_sequences=True)(inp)
lmu_fft = layers.LMUFFT(return_sequences=True, **kwargs)(inp)

model = tf.keras.Model(inp, [lmu_cell, lmu_fft])

results = model.predict(rng.uniform(-1, 1, size=(1, n_steps, input_d)))

assert np.allclose(results[0], results[1], atol=2e-6)

0 comments on commit 9d61854

Please sign in to comment.