Skip to content

Commit

Permalink
Fix static checks
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Oct 31, 2022
1 parent 49fc752 commit 71aae6d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
12 changes: 5 additions & 7 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""
Core classes for the KerasLMU package.
"""
"""Core classes for the KerasLMU package."""

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -180,8 +178,8 @@ def theta(self):
"""
Value of the ``theta`` parameter.
If ``trainable_theta=True`` this returns the trained value, not the initial
value passed in to the constructor.
If ``trainable_theta=True`` this returns the trained value, not the
initial value passed in to the constructor.
"""
if self.built:
return 1 / tf.keras.backend.get_value(self.theta_inv)
Expand Down Expand Up @@ -583,8 +581,8 @@ def theta(self):
"""
Value of the ``theta`` parameter.
If ``trainable_theta=True`` this returns the trained value, not the initial
value passed in to the constructor.
If ``trainable_theta=True`` this returns the trained value, not the
initial value passed in to the constructor.
"""

if self.built:
Expand Down
2 changes: 1 addition & 1 deletion keras_lmu/tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def on_predict_batch_end(self, batch, logs=None):
@pytest.mark.skipif(not tf_gpu_installed, reason="Very slow on CPU")
@pytest.mark.parametrize(
"mode, min_time, max_time",
[("rnn", 0.1, 0.2), ("fft", 0.1, 0.2), ("raw", 0.05, 0.15)],
[("rnn", 0.1, 0.2), ("fft", 0.05, 0.15), ("raw", 0.05, 0.15)],
)
def test_performance(mode, min_time, max_time):
# performance is based on Azure NC6 VM
Expand Down
4 changes: 3 additions & 1 deletion keras_lmu/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def test_layer_vs_cell(rng, has_input_kernel, feedforward, discretizer):
kernel_initializer="glorot_uniform" if has_input_kernel else None,
memory_to_memory=not feedforward,
)
hidden_cell = lambda: tf.keras.layers.SimpleRNNCell(units=64)

def hidden_cell():
return tf.keras.layers.SimpleRNNCell(units=64)

inp = rng.uniform(-1, 1, size=(2, n_steps, input_d))

Expand Down

0 comments on commit 71aae6d

Please sign in to comment.