Skip to content

Commit

Permalink
Add support for TensorFlow 2.9
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss authored and tbekolay committed May 10, 2022
1 parent 0db0e77 commit ff77285
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 3 deletions.
3 changes: 3 additions & 0 deletions .nengobones.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ travis_yml:
env:
TF_VERSION: tensorflow==2.1.0
python: 3.6
- script: test
env:
TF_VERSION: tensorflow==2.9.0rc1 # TODO: change to 2.6 after 2.9 is released
- script: remote-docs
- script: remote-examples
pypi_user: __token__
Expand Down
4 changes: 4 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ jobs:
TF_VERSION="tensorflow==2.1.0"
SCRIPT="test"
python: 3.6
-
env:
TF_VERSION="tensorflow==2.9.0rc1"
SCRIPT="test"
-
env:
SCRIPT="remote-docs"
Expand Down
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ Release history
0.4.2 (unreleased)
==================

*Compatible with TensorFlow 2.1 - 2.9*

**Added**

- Added support for TensorFlow 2.9. (`#48`_)

.. _#48: https://github.com/nengo/keras-lmu/pull/48

0.4.1 (February 10, 2022)
=========================
Expand Down
8 changes: 5 additions & 3 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
# pylint: disable=ungrouped-imports
if version.parse(tf.__version__) < version.parse("2.6.0rc0"):
from tensorflow.python.keras.layers.recurrent import DropoutRNNCellMixin
else:
elif version.parse(tf.__version__) < version.parse("2.9.0rc0"):
from keras.layers.recurrent import DropoutRNNCellMixin
else:
from keras.layers.rnn.dropout_rnn_cell_mixin import DropoutRNNCellMixin

if version.parse(tf.__version__) < version.parse("2.8.0rc0"):
from tensorflow.keras.layers import Layer as BaseRandomLayer
Expand Down Expand Up @@ -272,7 +274,7 @@ def build(self, input_shape):
if self.use_bias:
self.bias = self.add_weight(
name="bias",
shape=(1, self.memory_d),
shape=(self.memory_d,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
)
Expand Down Expand Up @@ -903,7 +905,7 @@ def build(self, input_shape):
if self.use_bias:
self.bias = self.add_weight(
name="bias",
shape=(1, self.memory_d),
shape=(self.memory_d,),
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
)
Expand Down

0 comments on commit ff77285

Please sign in to comment.