Skip to content
This repository has been archived by the owner on Dec 9, 2024. It is now read-only.

Commit

Permalink
Increase compatibility with Keras 3.x
Browse files Browse the repository at this point in the history
Signed-off-by: GitHub <noreply@github.com>
  • Loading branch information
samzong authored Oct 21, 2024
1 parent 8fb7907 commit 2fe9840
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions scripts/tf_cnn_benchmarks/models/experimental/deepspeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,24 @@ def decode_logits(self, logits):
class DeepSpeech2Model(model_lib.Model):
"""Define DeepSpeech2 model."""

# Supported rnn cells.
SUPPORTED_RNNS = {
'lstm': tf.nn.rnn_cell.BasicLSTMCell,
'rnn': tf.nn.rnn_cell.RNNCell,
'gru': tf.nn.rnn_cell.GRUCell,
}
# Check TensorFlow Keras version.
keras_version = tf.keras.__version__.split('.')
major_version = int(keras_version[0])

if major_version >= 3:
# Supported rnn cells for Keras 3.x
SUPPORTED_RNNS = {
'lstm': tf.keras.layers.LSTM,
'rnn': tf.keras.layers.SimpleRNN,
'gru': tf.keras.layers.GRU,
}
else:
# Supported rnn cells for Keras versions below 3.x
SUPPORTED_RNNS = {
'lstm': tf.nn.rnn_cell.BasicLSTMCell,
'rnn': tf.nn.rnn_cell.RNNCell,
'gru': tf.nn.rnn_cell.GRUCell,
}

# Parameters for batch normalization.
BATCH_NORM_EPSILON = 1e-5
Expand Down

0 comments on commit 2fe9840

Please sign in to comment.