diff --git a/.nengobones.yml b/.nengobones.yml index b0fd880c..aa02e850 100644 --- a/.nengobones.yml +++ b/.nengobones.yml @@ -73,6 +73,9 @@ travis_yml: env: TF_VERSION: tensorflow==2.1.0 python: 3.6 + - script: test + env: + TF_VERSION: tensorflow==2.9.0rc1 # TODO: change to 2.6 after 2.9 is released - script: remote-docs - script: remote-examples pypi_user: __token__ diff --git a/.travis.yml b/.travis.yml index a856d271..11e06e6a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -37,6 +37,10 @@ jobs: TF_VERSION="tensorflow==2.1.0" SCRIPT="test" python: 3.6 + - + env: + TF_VERSION="tensorflow==2.9.0rc1" + SCRIPT="test" - env: SCRIPT="remote-docs" diff --git a/CHANGES.rst b/CHANGES.rst index 9e8eb602..e0239e71 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -22,6 +22,13 @@ Release history 0.4.2 (unreleased) ================== +*Compatible with TensorFlow 2.1 - 2.9* + +**Added** + +- Added support for TensorFlow 2.9. (`#48`_) + +.. _#48: https://github.com/nengo/keras-lmu/pull/48 0.4.1 (February 10, 2022) ========================= diff --git a/keras_lmu/layers.py b/keras_lmu/layers.py index 729e7add..bafce9f4 100644 --- a/keras_lmu/layers.py +++ b/keras_lmu/layers.py @@ -9,8 +9,10 @@ # pylint: disable=ungrouped-imports if version.parse(tf.__version__) < version.parse("2.6.0rc0"): from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin -else: +elif version.parse(tf.__version__) < version.parse("2.9.0rc0"): from keras.layers.recurrent import DropoutRNNCellMixin +else: + from keras.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin if version.parse(tf.__version__) < version.parse("2.8.0rc0"): from tensorflow.keras.layers import Layer as BaseRandomLayer @@ -272,7 +274,7 @@ def build(self, input_shape): if self.use_bias: self.bias = self.add_weight( name="bias", - shape=(1, self.memory_d), + shape=(self.memory_d,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, ) @@ -903,7 +905,7 @@ def build(self, input_shape): if self.use_bias: self.bias = self.add_weight( name="bias", - shape=(1, self.memory_d), + shape=(self.memory_d,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, )