Allow memory_d > 1 for LMUFFT
Co-authored-by: Eric Hunsberger <eric.hunsberger@appliedbrainresearch.com>
2 people authored and drasmuss committed Jun 16, 2021
1 parent d7601b7 commit 15d5ba9
Showing 3 changed files with 32 additions and 25 deletions.
4 changes: 4 additions & 0 deletions CHANGES.rst
@@ -25,6 +25,10 @@ Release history
 **Added**
 
 - Setting ``kernel_initializer=None`` now removes the dense input kernel. (`#40`_)
+- The ``keras_lmu.LMUFFT`` layer now supports ``memory_d > 1``. ``keras_lmu.LMU`` now
+  uses this implementation for all values of ``memory_d`` when feedforward conditions
+  are satisfied (no hidden-to-memory or memory-to-memory connections,
+  and the sequence length is not ``None``). (`#40`_)
 
 .. _#40: https://github.com/nengo/keras-lmu/pull/40
 
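For orientation, a minimal usage sketch (not part of the commit; shapes and hyperparameter values are illustrative): with the feedforward conditions from the changelog entry above satisfied, LMU.build() now selects the FFT implementation even when memory_d > 1.

    import tensorflow as tf
    from keras_lmu import layers

    # Feedforward configuration: no hidden-to-memory or memory-to-memory connections.
    lmu = layers.LMU(
        memory_d=4,                                    # > 1 is now allowed on the FFT path
        order=12,
        theta=10,
        hidden_cell=tf.keras.layers.SimpleRNNCell(32),
        hidden_to_memory=False,
        memory_to_memory=False,
    )
    lmu.build((32, 10, 8))                             # sequence length (10) must not be None

    # build() swaps in the FFT implementation rather than a tf.keras.layers.RNN wrapper
    assert isinstance(lmu.layer, layers.LMUFFT)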
18 changes: 6 additions & 12 deletions keras_lmu/layers.py
@@ -408,7 +408,6 @@ def build(self, input_shapes):
         if (
             not self.hidden_to_memory
             and not self.memory_to_memory
-            and self.memory_d == 1
             and input_shapes[1] is not None
         ):
             self.layer = LMUFFT(
@@ -540,13 +539,6 @@ def __init__(
     ):
         super().__init__(**kwargs)
 
-        if memory_d != 1:
-            # TODO: we can support this by reusing the same impulse response
-            # for each dimension
-            raise NotImplementedError(
-                "Multi-dimensional memory not supported in LMUFFT"
-            )
-
         if input_to_hidden and hidden_cell is None:
             raise ValueError("input_to_hidden must be False if hidden_cell is None")
 
@@ -559,9 +551,10 @@ def __init__(
         self.dropout = dropout
         self.return_sequences = return_sequences
 
+        # create a standard LMUCell to generate the impulse response during `build`
        self.delay_layer = tf.keras.layers.RNN(
             LMUCell(
-                memory_d=1,
+                memory_d=1,
                 order=order,
                 theta=theta,
                 hidden_cell=None,
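The removed TODO is resolved by exactly that reuse: the internal delay layer is built with memory_d=1, and its single impulse response is shared across every memory dimension. A rough sketch of how such a response can be read out of a memory_d=1 delay layer (illustrative only; the commit's actual build() logic is not shown in this hunk, and the parameter values here are made up):

    import tensorflow as tf
    from keras_lmu.layers import LMUCell

    seq_len, order, theta = 10, 6, 5

    # memory_d=1 and kernel_initializer=None: the input feeds the memory directly,
    # and the cell is untrained, so it acts as a fixed delay system
    delay_layer = tf.keras.layers.RNN(
        LMUCell(
            memory_d=1,
            order=order,
            theta=theta,
            hidden_cell=None,
            kernel_initializer=None,
            trainable=False,
        ),
        return_sequences=True,
    )

    # drive the delay system with a unit impulse; the resulting response only needs
    # to be computed once and can then be reused for every memory dimension
    impulse = tf.reshape(tf.one_hot(0, seq_len), (1, seq_len, 1))
    impulse_response = delay_layer(impulse)  # shape (1, seq_len, order)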
@@ -654,19 +647,20 @@ def call(self, inputs, training=None):
             else tf.matmul(inputs, self.kernel, name="input_encoder_mult")
         )
 
-        # FFT requires shape (batch, 1, timesteps)
+        # FFT requires shape (batch, memory_d, timesteps)
         u = tf.transpose(u, perm=[0, 2, 1])
 
         # Pad sequences to avoid circular convolution
         # Perform the FFT
         fft_input = tf.signal.rfft(u, fft_length=[2 * seq_len], name="input_pad")
 
-        # Elementwise product of FFT (broadcasting done automatically)
-        result = fft_input * self.impulse_response
+        # Elementwise product of FFT (with broadcasting)
+        result = tf.expand_dims(fft_input, axis=-2) * self.impulse_response
 
         # Inverse FFT
         m = tf.signal.irfft(result, fft_length=[2 * seq_len])[..., :seq_len]
 
+        m = tf.reshape(m, (-1, self.order * self.memory_d, seq_len))
         m = tf.transpose(m, perm=[0, 2, 1])
 
         # apply hidden cell
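The shape bookkeeping in the new call() is worth spelling out. Below is a self-contained sketch of the same trick with a random stand-in for the impulse response (names and shapes are illustrative, not the layer's internals): padding the FFT to 2 * seq_len keeps the frequency-domain product equivalent to a linear rather than circular convolution, and the expand_dims lets one (order, bins) impulse response broadcast against every memory dimension.

    import tensorflow as tf

    batch, seq_len, memory_d, order = 2, 10, 4, 6

    u = tf.random.normal((batch, memory_d, seq_len))    # encoded inputs, one row per memory dim
    h = tf.random.normal((order, seq_len))              # stand-in for the shared impulse response

    # pad to 2 * seq_len so the product below implements a linear convolution
    fft_u = tf.signal.rfft(u, fft_length=[2 * seq_len])  # (batch, memory_d, seq_len + 1)
    fft_h = tf.signal.rfft(h, fft_length=[2 * seq_len])  # (order, seq_len + 1)

    # (batch, memory_d, 1, bins) * (order, bins) broadcasts to (batch, memory_d, order, bins)
    result = tf.expand_dims(fft_u, axis=-2) * fft_h

    m = tf.signal.irfft(result, fft_length=[2 * seq_len])[..., :seq_len]
    m = tf.reshape(m, (-1, memory_d * order, seq_len))   # (batch, memory_d * order, seq_len)
    m = tf.transpose(m, perm=[0, 2, 1])                  # (batch, seq_len, memory_d * order)
    assert m.shape == (batch, seq_len, memory_d * order)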
35 changes: 22 additions & 13 deletions keras_lmu/tests/test_layers.py
@@ -63,14 +63,16 @@ def test_multivariate_lmu(rng):
 
 
 @pytest.mark.parametrize("has_input_kernel", (True, False))
-def test_layer_vs_cell(has_input_kernel, rng):
+@pytest.mark.parametrize("fft", (True, False))
+def test_layer_vs_cell(has_input_kernel, fft, rng):
     n_steps = 10
     input_d = 32
     kwargs = dict(
         memory_d=4 if has_input_kernel else input_d,
         order=12,
         theta=n_steps,
         kernel_initializer="glorot_uniform" if has_input_kernel else None,
+        memory_to_memory=not fft,
     )
     hidden_cell = lambda: tf.keras.layers.SimpleRNNCell(units=64)
 
@@ -87,14 +89,17 @@ def test_layer_vs_cell(has_input_kernel, rng):
     lmu_layer.layer.set_weights(lmu_cell.get_weights())
     layer_out = lmu_layer(inp)
 
+    assert isinstance(lmu_layer.layer, layers.LMUFFT if fft else tf.keras.layers.RNN)
+
     for w0, w1 in zip(
         sorted(lmu_cell.weights, key=lambda w: w.shape.as_list()),
         sorted(lmu_layer.weights, key=lambda w: w.shape.as_list()),
     ):
         assert np.allclose(w0.numpy(), w1.numpy())
 
-    assert np.allclose(cell_out, lmu_cell(inp))
-    assert np.allclose(cell_out, layer_out)
+    atol = 2e-6 if fft else 1e-8
+    assert np.allclose(cell_out, lmu_cell(inp), atol=atol)
+    assert np.allclose(cell_out, layer_out, atol=atol)
 
 
 def test_save_load_weights(rng, tmp_path):
@@ -181,18 +186,26 @@ def test_save_load_serialization(mode, tmp_path):
 
 @pytest.mark.parametrize("return_sequences", (True, False))
 @pytest.mark.parametrize(
-    "hidden_cell", (None, tf.keras.layers.Dense(4), tf.keras.layers.SimpleRNNCell(4))
+    "hidden_cell",
+    (
+        lambda: None,
+        lambda: tf.keras.layers.Dense(4),
+        lambda: tf.keras.layers.SimpleRNNCell(4),
+    ),
 )
-def test_fft(return_sequences, hidden_cell, rng):
+@pytest.mark.parametrize("memory_d", [1, 4])
+def test_fft(return_sequences, hidden_cell, memory_d, rng):
+    kwargs = dict(memory_d=memory_d, order=2, theta=3, hidden_cell=hidden_cell())
+
     x = rng.uniform(-1, 1, size=(2, 10, 32))
 
     rnn_layer = tf.keras.layers.RNN(
-        layers.LMUCell(1, 2, 3, hidden_cell),
+        layers.LMUCell(**kwargs),
         return_sequences=return_sequences,
     )
     rnn_out = rnn_layer(x)
 
-    fft_layer = layers.LMUFFT(1, 2, 3, hidden_cell, return_sequences=return_sequences)
+    fft_layer = layers.LMUFFT(return_sequences=return_sequences, **kwargs)
     fft_layer.build(x.shape)
     fft_layer.kernel.assign(rnn_layer.cell.kernel)
     fft_out = fft_layer(x)
@@ -237,7 +250,7 @@ def test_fft_auto_swap(hidden_to_memory, memory_to_memory, memory_d, steps):
     lmu.build((32, steps, 8))
 
     assert isinstance(lmu.layer, tf.keras.layers.RNN) == (
-        hidden_to_memory or memory_to_memory or memory_d != 1 or steps is None
+        hidden_to_memory or memory_to_memory or steps is None
     )
 
 
@@ -415,11 +428,7 @@ def test_fit(fft):
 
     _, acc = model.evaluate(x_test, y_test, verbose=0)
 
-    if fft:
-        assert isinstance(lmu_layer.layer, layers.LMUFFT)
-    else:
-        assert isinstance(lmu_layer.layer, tf.keras.layers.RNN)
-
+    assert isinstance(lmu_layer.layer, layers.LMUFFT if fft else tf.keras.layers.RNN)
     assert acc == 1.0
 
 
