
Commit

Sync OSS keras to head.
PiperOrigin-RevId: 366341827
qlzh727 authored and tensorflower-gardener committed Apr 1, 2021
1 parent e2cd6f4 commit 661cbd3
Showing 13 changed files with 178 additions and 185 deletions.
24 changes: 24 additions & 0 deletions keras/callbacks.py
@@ -630,6 +630,30 @@ class Callback:
... callbacks=[MyCallback()])
>>> assert training_finished == True
If you want to use `Callback` objects in a custom training loop:
1. You should pack all your callbacks into a single `callbacks.CallbackList`
so they can all be called together.
2. You will need to manually call all the `on_*` methods at the appropriate
locations in your loop. Like this:
```
callbacks = tf.keras.callbacks.CallbackList([...])
callbacks.append(...)
callbacks.on_train_begin(...)
for epoch in range(EPOCHS):
  callbacks.on_epoch_begin(epoch)
  for i, data in dataset.enumerate():
    callbacks.on_train_batch_begin(i)
    batch_logs = model.train_step(data)
    callbacks.on_train_batch_end(i, batch_logs)
  epoch_logs = ...
  callbacks.on_epoch_end(epoch, epoch_logs)
final_logs = ...
callbacks.on_train_end(final_logs)
```
Attributes:
params: Dict. Training parameters
(e.g. verbosity, batch size, number of epochs...).
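
The docstring above gives only the hook-call skeleton. Below is a minimal, self-contained sketch of the same loop that actually runs; the tiny `Sequential` model, the synthetic dataset, `EPOCHS`, and the choice of `History` as the packed callback are illustrative assumptions, not part of this commit.

```
import tensorflow as tf

# Illustrative model and data; only the CallbackList usage mirrors the docstring.
EPOCHS = 2
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([8, 4]), tf.random.normal([8, 1]))).batch(4)

# Pack all callbacks into one CallbackList so every on_* hook fans out to all of them.
callbacks = tf.keras.callbacks.CallbackList(
    [tf.keras.callbacks.History()], model=model)

callbacks.on_train_begin()
for epoch in range(EPOCHS):
  callbacks.on_epoch_begin(epoch)
  for i, data in dataset.enumerate():
    callbacks.on_train_batch_begin(i)
    batch_logs = model.train_step(data)  # dict of metric results for this batch
    callbacks.on_train_batch_end(i, batch_logs)
  callbacks.on_epoch_end(epoch, batch_logs)  # last batch's logs stand in for epoch logs
callbacks.on_train_end(batch_logs)
```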
10 changes: 9 additions & 1 deletion keras/distribute/custom_training_loop_optimizer_test.py
@@ -44,13 +44,21 @@ def test_custom_aggregation(self, distribution,
    v = tf.Variable([0., 0.])
    optimizer = gradient_descent.SGD(0.1)

    class PerReplica(values.DistributedValues):
      """Holds a map from replica to unsynchronized values."""

      @property
      def values(self):
        """Returns the per replica values."""
        return self._values

    @tf.function
    def optimize():
      with tf.compat.v1.device(distribution.extended.worker_devices[0]):
        v1 = tf.convert_to_tensor([1., 1.])
      with tf.compat.v1.device(distribution.extended.worker_devices[1]):
        v2 = tf.convert_to_tensor([2., 2.])
      grads = values.PerReplica([v1, v2])
      grads = PerReplica([v1, v2])
      def step_fn(grads):
        optimizer.apply_gradients(
            [(grads, v)],
10 changes: 7 additions & 3 deletions keras/engine/training.py
@@ -21,7 +21,6 @@
import json
import os
import warnings
from tensorflow.python.distribute import values as ds_values
from tensorflow.python.eager import context
from keras import backend
from keras import callbacks as callbacks_module
@@ -2769,7 +2768,7 @@ def _reduce(v):
    """Reduce a single `PerReplica` object."""
    if reduction == 'concat' and _collective_all_reduce_multi_worker(strategy):
      return _multi_worker_concat(v, strategy)
    if not isinstance(v, ds_values.PerReplica):
    if not _is_per_replica_instance(v):
      return v
    elif reduction == 'first':
      return strategy.unwrap(v)[0]
@@ -2823,7 +2822,7 @@ def _multi_worker_concat(v, strategy):
  """Order PerReplica objects for CollectiveAllReduceStrategy and concat."""
  replicas = strategy.gather(v, axis=0)
  # v might not have the same shape on different replicas
  if isinstance(v, ds_values.PerReplica):
  if _is_per_replica_instance(v):
    shapes = tf.concat([
        tf.expand_dims(tf.compat.v1.shape(single_value)[0], axis=0)
        for single_value in v.values
@@ -2930,3 +2929,8 @@ def flatten_metrics_in_order(logs, metrics_names):
  if len(results) == 1:
    return results[0]
  return results


def _is_per_replica_instance(obj):
  return (isinstance(obj, tf.distribute.DistributedValues) and
          isinstance(obj, tf.__internal__.CompositeTensor))
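
The helper added here replaces `isinstance(v, ds_values.PerReplica)` with two public-API checks, which is why the private `values` import could be dropped above. A minimal sketch of how the check behaves; the two-logical-CPU setup is an illustrative assumption so that `MirroredStrategy` really produces a per-replica value on a single machine.

```
import tensorflow as tf

def _is_per_replica_instance(obj):
  # Same duck-typed check as the commit: PerReplica values are
  # DistributedValues that are also composite tensors.
  return (isinstance(obj, tf.distribute.DistributedValues) and
          isinstance(obj, tf.__internal__.CompositeTensor))

# Split the one physical CPU into two logical devices -> two replicas.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)
strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])

per_replica = strategy.run(
    lambda: tf.identity(tf.distribute.get_replica_context().replica_id_in_sync_group))
print(_is_per_replica_instance(per_replica))             # True
print(_is_per_replica_instance(tf.constant([1., 2.])))   # False: plain tensor
```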
1 change: 0 additions & 1 deletion keras/layers/preprocessing/BUILD
@@ -152,7 +152,6 @@ py_library(
    deps = [
        "//:expect_numpy_installed",
        "//:expect_tensorflow_installed",
        "//keras:backend",
        "//keras/utils:tf_utils",
    ],
)
14 changes: 3 additions & 11 deletions keras/layers/preprocessing/index_lookup.py
@@ -273,14 +273,13 @@ def __init__(self,
value_index=value_index,
value_index_offset=self._token_start_index())

self._table = self._static_table_class()(
self._table = tf.lookup.StaticHashTable(
initializer, default_value=default_value)
self._table_handler = table_utils.TableHandler(
table=self._table,
mask_token=self._mask_key,
mask_value=self._mask_value,
oov_tokens=oov_indices,
use_v1_apis=self._use_v1_apis())
oov_tokens=oov_indices)

tracked_table = self._add_trackable(self._table, trainable=False)

@@ -293,8 +292,7 @@ def __init__(self,
name=(self._name + "_index_table"))
self._table_handler = table_utils.TableHandler(
table=self._table,
oov_tokens=oov_indices,
use_v1_apis=self._use_v1_apis())
oov_tokens=oov_indices)
if vocabulary is not None:
self.set_vocabulary(vocabulary)
tracked_table = self._add_trackable(self._table, trainable=False)
@@ -621,12 +619,6 @@ def call(self, inputs):
  def _convert_to_ndarray(self, x):
    return np.array(x) if isinstance(x, (list, tuple)) else x

  def _use_v1_apis(self):
    return False

  def _static_table_class(self):
    return tf.lookup.StaticHashTable

  def _oov_start_index(self):
    return 1 if self.mask_token is not None and self.output_mode == INT else 0
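
With `_use_v1_apis()` and `_static_table_class()` removed, the layer constructs its table through the public `tf.lookup` API directly (visible in the constructor hunk above). A minimal sketch of that API with a made-up three-token vocabulary; the keys, values, and default below are illustrative, not the layer's real initializer or OOV handling.

```
import tensorflow as tf

# Illustrative vocabulary; IndexLookup builds its initializer from the adapted
# or user-supplied vocabulary and layers TableHandler's OOV logic on top.
keys = tf.constant(["[MASK]", "the", "cat"])
values = tf.constant([0, 2, 3], dtype=tf.int64)

initializer = tf.lookup.KeyValueTensorInitializer(keys, values)
table = tf.lookup.StaticHashTable(initializer, default_value=-1)

# Tokens missing from the table fall back to default_value.
print(table.lookup(tf.constant(["cat", "dog"])).numpy())  # [ 3 -1]
```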
