From 7fa9e420b282faae124635a2d470ba263534010a Mon Sep 17 00:00:00 2001 From: Daniel Rasmussen Date: Tue, 27 Oct 2020 10:16:58 -0300 Subject: [PATCH] fixup! Rework LMU API --- lmu/tests/test_layers.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/lmu/tests/test_layers.py b/lmu/tests/test_layers.py index b5a8b8ef..c820f482 100644 --- a/lmu/tests/test_layers.py +++ b/lmu/tests/test_layers.py @@ -182,10 +182,12 @@ def test_save_load_serialization(mode, tmp_path): @pytest.mark.parametrize("return_sequences", (True, False)) -def test_fft(return_sequences, rng): +@pytest.mark.parametrize( + "hidden_cell", (None, tf.keras.layers.Dense(4), tf.keras.layers.SimpleRNNCell(4)) +) +def test_fft(return_sequences, hidden_cell, rng): x = rng.uniform(-1, 1, size=(2, 10, 32)) - hidden_cell = tf.keras.layers.SimpleRNNCell(4) rnn_layer = tf.keras.layers.RNN( layers.LMUCell(1, 2, 3, hidden_cell), return_sequences=return_sequences, @@ -200,6 +202,13 @@ def test_fft(return_sequences, rng): assert np.allclose(rnn_out, fft_out, atol=2e-6) +def test_fft_errors(): + fft_layer = layers.LMUFFT(1, 2, 3, None) + + with pytest.raises(ValueError, match="temporal axis be fully specified"): + fft_layer(tf.keras.Input((None, 32))) + + @pytest.mark.parametrize( "hidden_to_memory, memory_to_memory, memory_d", [(False, False, 1), (True, False, 1), (False, True, 1), (False, False, 2)],