Skip to content

Commit

Permalink
fixup! Rework LMU API
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Oct 27, 2020
1 parent 8d8d9ad commit 7fa9e42
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions lmu/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,12 @@ def test_save_load_serialization(mode, tmp_path):


@pytest.mark.parametrize("return_sequences", (True, False))
def test_fft(return_sequences, rng):
@pytest.mark.parametrize(
"hidden_cell", (None, tf.keras.layers.Dense(4), tf.keras.layers.SimpleRNNCell(4))
)
def test_fft(return_sequences, hidden_cell, rng):
x = rng.uniform(-1, 1, size=(2, 10, 32))

hidden_cell = tf.keras.layers.SimpleRNNCell(4)
rnn_layer = tf.keras.layers.RNN(
layers.LMUCell(1, 2, 3, hidden_cell),
return_sequences=return_sequences,
Expand All @@ -200,6 +202,13 @@ def test_fft(return_sequences, rng):
assert np.allclose(rnn_out, fft_out, atol=2e-6)


def test_fft_errors():
fft_layer = layers.LMUFFT(1, 2, 3, None)

with pytest.raises(ValueError, match="temporal axis be fully specified"):
fft_layer(tf.keras.Input((None, 32)))


@pytest.mark.parametrize(
"hidden_to_memory, memory_to_memory, memory_d",
[(False, False, 1), (True, False, 1), (False, True, 1), (False, False, 2)],
Expand Down

0 comments on commit 7fa9e42

Please sign in to comment.