Commit 8e95a38
Sync OSS keras to head.
PiperOrigin-RevId: 366714851
qlzh727 authored and tensorflower-gardener committed Apr 4, 2021
1 parent 8000a3e commit 8e95a38
Showing 11 changed files with 203 additions and 83 deletions.
1 change: 1 addition & 0 deletions keras/BUILD
@@ -480,6 +480,7 @@ tf_py_test(
     shard_count = 6,
     tags = [
         "no_oss",
+        "no_tfrt",  # TODO(b/179690526)
         "notsan",
     ],
     deps = [
@@ -4,6 +4,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'dataset_fn\', \'input_options\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
 }
3 changes: 1 addition & 2 deletions keras/distribute/BUILD
@@ -720,7 +720,7 @@ distribute_py_test(
     name = "parameter_server_training_test",
     srcs = ["parameter_server_training_test.py"],
     python_version = "PY3",
-    shard_count = 1,
+    shard_count = 4,  # TODO(b/184290570): Investigate why only 1 shard times out.
     tags = [
         "multi_and_single_gpu",
         "nomultivm",  # TODO(b/170502145)
@@ -746,7 +746,6 @@ distribute_py_test(
         "multi_gpu",
         "no_oss",  # TODO(b/183640564): Reenable
         "nomultivm",  # TODO(b/170502145)
-        "notsan",  # TODO(b/184374640)
     ],
     deps = [
         ":multi_worker_testing_utils",
93 changes: 34 additions & 59 deletions keras/distribute/multi_worker_test.py
@@ -28,8 +28,6 @@
 
 # pylint: disable=g-direct-tensorflow-import
 import keras
-from tensorflow.python.distribute import distribute_coordinator as dc
-from tensorflow.python.distribute import multi_worker_test_base as test_base
 from keras import backend
 from keras import callbacks
 from keras import metrics as metrics_module
@@ -120,7 +118,7 @@ def __init__(self, num_epoch, num_worker):
     self._num_worker = num_worker
     self._task_dict = {
         key: collections.defaultdict(lambda: collections.defaultdict(int))
-        for key in ['ps', 'worker']
+        for key in ['ps', 'worker', 'chief']
     }
     self._lock = threading.Lock()
     self._is_between_graph = None
@@ -168,75 +166,52 @@ def verify(self, test_case):
     else:
       # If in-graph, only the first worker calls callback methods.
       worker_call_count = {0: method_count_dict}
+    chief_call_count = {0: method_count_dict}
+    task_config = json.loads(os.environ['TF_CONFIG'])['task']['type']
     test_case.assertDictEqual(
         self._task_dict,
         {
             # PS' callback is not supposed to be called.
             'ps': {},
-            # Each of the Worker should be called num_epoch of times.
-            'worker': worker_call_count
+            # Worker or chief should only be called on worker/chief.
+            'worker': worker_call_count if task_config == 'worker' else {},
+            'chief': chief_call_count if task_config == 'chief' else {}
         })
 
 
-class KerasMultiWorkerTestIndependentWorker(test_base.IndependentWorkerTestBase,
+class KerasMultiWorkerTestIndependentWorker(tf.test.TestCase,
                                             parameterized.TestCase):
 
   @tf.__internal__.distribute.combinations.generate(
       tf.__internal__.test.combinations.combine(
-          mode=['graph'],
-          strategy_cls=[
-              tf.distribute.MultiWorkerMirroredStrategy,
-          ],
-          required_gpus=[0, 1]))
-  def testSimpleModelIndependentWorkerSync(self, strategy_cls):
-    num_workers = 2
-    num_epoch = 2
-
-    cluster_spec = tf.__internal__.distribute.multi_process_runner.create_cluster_spec(num_workers=num_workers)
-    self._barrier = dc._Barrier(2)
-
-    # The verification callback will be shared by multiple threads.
+          mode=['eager'],
+          strategy=[
+              tf.__internal__.distribute.combinations.multi_worker_mirrored_2x1_cpu,
+              tf.__internal__.distribute.combinations.multi_worker_mirrored_2x1_gpu,
+          ]))
+  def testSimpleModelIndependentWorkerSync(self, strategy):
     verification_callback = MultiWorkerVerificationCallback(
-        num_epoch=num_epoch, num_worker=num_workers)
-
-    def _independent_worker_fn(*args, **kwargs):  # pylint: disable=unused-argument
-      """Simulates an Independent Worker inside of a thread."""
-      with tf.compat.v1.test.mock.patch.object(dc, '_run_std_server',
-                                               self._make_mock_run_std_server()):
-        strategy = strategy_cls()
-        verification_callback.is_between_graph = \
-            strategy.extended.experimental_between_graph
-        batch_size = 64
-        steps = 2
-        train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
-            batch_size, steps)
-        with strategy.scope():
-          model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
-        orig_loss, _ = model.evaluate(train_ds, steps=steps)
-        callbacks_for_fit = tf.nest.flatten(
-            kwargs.get('verification_callback', []))
-        history = model.fit(
-            x=train_ds,
-            epochs=num_epoch,
-            steps_per_epoch=steps,
-            callbacks=callbacks_for_fit)
-        self.assertIsInstance(history, keras.callbacks.History)
-        trained_loss, _ = model.evaluate(train_ds, steps=steps)
-        self.assertLess(trained_loss, orig_loss)
-
-    threads = self.run_multiple_tasks_in_threads(
-        _independent_worker_fn,
-        cluster_spec,
-        verification_callback=verification_callback)
-
-    threads_to_join = []
-    strategy = strategy_cls()
-    if strategy.extended.experimental_between_graph:
-      for ts in threads.values():
-        threads_to_join.extend(ts)
-    else:
-      threads_to_join = [threads['worker'][0]]
-    self.join_independent_workers(threads_to_join)
+        num_epoch=2,
+        num_worker=len(
+            json.loads(os.environ['TF_CONFIG'])['cluster']['worker']))
+    verification_callback.is_between_graph = \
+        strategy.extended.experimental_between_graph
+    batch_size = 64
+    steps = 2
+    train_ds, _ = multi_worker_testing_utils.mnist_synthetic_dataset(
+        batch_size, steps)
+    with strategy.scope():
+      model = multi_worker_testing_utils.get_mnist_model((28, 28, 1))
+    orig_loss, _ = model.evaluate(train_ds, steps=steps)
+    history = model.fit(
+        x=train_ds,
+        epochs=2,
+        steps_per_epoch=steps,
+        callbacks=[verification_callback])
+    self.assertIsInstance(history, keras.callbacks.History)
+    trained_loss, _ = model.evaluate(train_ds, steps=steps)
+    self.assertLess(trained_loss, orig_loss)
 
     verification_callback.verify(self)
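Note: a minimal sketch (not part of the commit) of the TF_CONFIG environment the rewritten test reads. The multi_worker_mirrored_2x1_* combinations run each task in its own process with a TF_CONFIG shaped like this; the addresses below are hypothetical.

import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ['localhost:12345', 'localhost:23456'],  # two worker tasks
    },
    'task': {'type': 'worker', 'index': 0},  # this process's role
})

# The test now derives its worker count straight from the cluster spec,
# instead of hardcoding num_workers = 2:
num_worker = len(json.loads(os.environ['TF_CONFIG'])['cluster']['worker'])
assert num_worker == 2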
3 changes: 0 additions & 3 deletions keras/distribute/multi_worker_testing_utils.py
@@ -76,7 +76,4 @@ def get_mnist_model(input_shape):
 def make_parameter_server_cluster(num_workers, num_ps):
   cluster_def = multi_worker_test_base.create_in_process_cluster(
       num_workers=num_workers, num_ps=num_ps, rpc_layer="grpc")
-  cluster_def["chief"] = [
-      "localhost:%d" % multi_worker_test_base.pick_unused_port()
-  ]
   return SimpleClusterResolver(ClusterSpec(cluster_def), rpc_layer="grpc")
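Note: a hedged sketch (not part of the commit) of how the helper above is typically consumed. It assumes the dropped "chief" entry is now supplied by create_in_process_cluster itself; ParameterServerStrategy and ClusterCoordinator are the TF 2.4+ APIs this resolver feeds.

import tensorflow as tf
from keras.distribute import multi_worker_testing_utils

# Build an in-process test cluster and hand its resolver to the strategy.
resolver = multi_worker_testing_utils.make_parameter_server_cluster(
    num_workers=2, num_ps=1)
strategy = tf.distribute.experimental.ParameterServerStrategy(resolver)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
    strategy)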
6 changes: 4 additions & 2 deletions keras/engine/data_adapter.py
@@ -517,7 +517,8 @@ def get_size(self):
     return None  # To be inferred by `DataHandler`.
 
   def get_dataset(self):
-    return self.strategy.distribute_datasets_from_function(self.dataset_creator)
+    return self.strategy.distribute_datasets_from_function(
+        self.dataset_creator, options=self.dataset_creator.input_options)
 
   def batch_size(self):
     raise NotImplementedError()
@@ -1329,7 +1330,8 @@ def _configure_dataset_and_inferred_steps(self, strategy, x, steps_per_epoch,
           "`DatasetCreator`.")
 
     def per_worker_dataset_fn():
-      return strategy.distribute_datasets_from_function(x)
+      return strategy.distribute_datasets_from_function(
+          x, options=x.input_options)
 
     self._dataset = self._model._cluster_coordinator.create_per_worker_dataset(  # pylint: disable=protected-access
         per_worker_dataset_fn)
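Note: a minimal sketch (not part of the commit) of the call these two changes enable. distribute_datasets_from_function already accepts an `options` argument; the commit simply threads DatasetCreator.input_options through to it. The strategy and dataset below are stand-ins, and the InputOptions fields mirror the updated DatasetCreator docstring.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def dataset_fn(input_context):
  # Derive the per-replica batch size from a global batch size of 64.
  batch_size = input_context.get_per_replica_batch_size(64)
  return tf.data.Dataset.range(1024).batch(batch_size).prefetch(2)

options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)
dist_dataset = strategy.distribute_datasets_from_function(
    dataset_fn, options=options)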
59 changes: 54 additions & 5 deletions keras/losses.py
@@ -1196,13 +1196,15 @@ def mean_squared_error(y_true, y_pred):
   return backend.mean(tf.math.squared_difference(y_pred, y_true), axis=-1)
 
 
-def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred):
+def _ragged_tensor_apply_loss(loss_fn, y_true, y_pred, y_pred_extra_dim=False):
   """Apply a loss function on a per batch basis.
 
   Args:
     loss_fn: The loss function
     y_true: truth values (RaggedTensor)
     y_pred: predicted values (RaggedTensor)
+    y_pred_extra_dim: whether y_pred has an additional dimension compared to
+      y_true
 
   Returns:
     Loss-function result. A dense tensor if the output has a single dimension
@@ -1225,27 +1227,51 @@ def rt_is_equiv_dense(rt):
   ])
 
   def _convert_to_dense(inputs):
-    return tuple(rt.to_tensor() for rt in inputs)
+    return tuple(
+        rt.to_tensor() if isinstance(rt, tf.RaggedTensor) else rt
+        for rt in inputs)
 
-  def _wrapper(inputs):
+  def _call_loss(inputs, ragged_output):
+    """ Adapt the result to ragged or dense tensor according to the expected
+        output type. This is done so that all the return values of the map
+        operation have the same type.
+    """
+    r = loss_fn(*inputs)
+    if ragged_output and not isinstance(r, tf.RaggedTensor):
+      r = tf.RaggedTensor.from_tensor(r)
+    elif not ragged_output and isinstance(r, tf.RaggedTensor):
+      r = r.to_tensor()
+    return r
+
+  def _wrapper(inputs, ragged_output):
     _, y_pred = inputs
     if isinstance(y_pred, tf.RaggedTensor):
       return tf.compat.v1.cond(
           rt_is_equiv_dense(y_pred),
-          lambda: loss_fn(*_convert_to_dense(inputs)), lambda: loss_fn(*inputs))
+          lambda: _call_loss(_convert_to_dense(inputs), ragged_output),
+          lambda: _call_loss(inputs, ragged_output))
 
     return loss_fn(*inputs)
 
   if not isinstance(y_true, tf.RaggedTensor):
     return loss_fn(y_true, y_pred.to_tensor())
 
   lshape = y_pred.shape.as_list()[1:-1]
   if len(lshape) > 0:
     spec = tf.RaggedTensorSpec(shape=lshape, dtype=y_pred.dtype)
   else:
     spec = tf.TensorSpec(shape=[], dtype=y_pred.dtype)
 
   nested_splits_list = [rt.nested_row_splits for rt in (y_true, y_pred)]
+  if y_pred_extra_dim:
+    nested_splits_list[1] = nested_splits_list[1][:-1]
+
+  map_fn = functools.partial(_wrapper, ragged_output=len(lshape) > 1)
+
   assertion_list = ragged_util.assert_splits_match(nested_splits_list)
   with tf.control_dependencies(assertion_list):
-    return ragged_map_ops.map_fn(_wrapper, elems=(y_true, y_pred), dtype=spec)
+    return ragged_map_ops.map_fn(map_fn, elems=(y_true, y_pred), dtype=spec)
 
 
 @dispatch.dispatch_for_types(mean_squared_error, tf.RaggedTensor)
@@ -1694,6 +1720,29 @@ def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
       y_true, y_pred, from_logits=from_logits, axis=axis)
 
 
+@dispatch.dispatch_for_types(sparse_categorical_crossentropy,
+                             tf.RaggedTensor)
+def _ragged_tensor_sparse_categorical_crossentropy(y_true,
+                                                   y_pred,
+                                                   from_logits=False,
+                                                   axis=-1):
+  """ Implements support for handling RaggedTensors.
+
+  Expected y_pred shape: (batch, sequence_len, n_classes) with sequence_len
+  being variable per batch.
+  Return shape: (batch, sequence_len).
+
+  When used by SparseCategoricalCrossentropy() with the default reduction
+  (SUM_OVER_BATCH_SIZE), the reduction averages the loss over the
+  number of elements independent of the batch. E.g. if the RaggedTensor
+  has 2 batches with [2, 1] values respectively, the resulting loss is
+  the sum of the individual loss values divided by 3.
+  """
+  fn = functools.partial(
+      sparse_categorical_crossentropy, from_logits=from_logits, axis=axis)
+  return _ragged_tensor_apply_loss(fn, y_true, y_pred, y_pred_extra_dim=True)
+
+
 @keras_export('keras.metrics.binary_crossentropy',
               'keras.losses.binary_crossentropy')
 @tf.__internal__.dispatch.add_dispatch_support
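Note: a small worked example (not part of the commit) of the reduction the new docstring describes. Two ragged rows with [2, 1] labels yield three element losses; SUM_OVER_BATCH_SIZE divides their sum by 3, the total element count, not by the batch size of 2.

import tensorflow as tf

y_true = tf.ragged.constant([[0, 1], [2]])
y_pred = tf.ragged.constant(
    [[[.9, .05, .05], [.05, .9, .05]], [[.05, .05, .9]]], dtype=tf.float32)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
# Every element predicts its true class with probability 0.9, so each
# element loss is -log(0.9) ~= 0.1054 and the reduced loss is
# (3 * 0.1054) / 3 ~= 0.1054.
print(float(loss_fn(y_true, y_pred)))  # ~0.1054 (requires this commit)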
30 changes: 30 additions & 0 deletions keras/losses_test.py
@@ -1144,6 +1144,36 @@ def test_non_tensor(self):
     loss = cce_obj(y_true, y_pred, sample_weight=2.3)
     self.assertAlmostEqual(self.evaluate(loss), .7449, 3)
 
+  def test_ragged_tensors(self):
+    cce_obj = losses.SparseCategoricalCrossentropy()
+    y_true = tf.ragged.constant([[0, 1], [2]])
+    y_pred = tf.ragged.constant(
+        [[[.9, .05, .05], [.5, .89, .6]], [[.05, .01, .94]]],
+        dtype=tf.float32)
+    # batch losses [[0.1054, 0.8047], [0.0619]]
+    sample_weight = tf.constant([[1.2], [3.4]], shape=(2, 1))
+    loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)
+    # sum([0.1054, 0.8047, 0.0619]) / 3
+    self.assertAlmostEqual(self.evaluate(loss), 0.4341, 3)
+
+    # Test with logits.
+    logits = tf.ragged.constant([[[8., 1., 1.], [0., 9., 1.]],
+                                 [[2., 3., 5.]]])
+    cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True)
+    # batch losses [[0.0018, 0.0004], [0.1698]]
+    loss = cce_obj(y_true, logits, sample_weight=sample_weight)
+    self.assertAlmostEqual(self.evaluate(loss), 0.1934, 3)
+
+  def test_ragged_tensors_3d(self):
+    # shape [2, 1, None]
+    y_true = tf.ragged.constant([[[1, 1]], [[0]]])
+    # shape [2, 1, None, 2]
+    y_pred = tf.ragged.constant([[[[0.1, 0.9], [0.1, 0.9]]],
+                                 [[[0.9, 0.1]]]])
+    cce_obj = losses.SparseCategoricalCrossentropy()
+    loss = cce_obj(y_true, y_pred)
+    self.assertAlmostEqual(self.evaluate(loss), 0.1054, 3)
+
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class HingeTest(tf.test.TestCase):
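Note: checking the expected value in test_ragged_tensors by hand (not part of the commit). The per-element losses in the test's comment are scaled by their row's sample_weight before the SUM_OVER_BATCH_SIZE division by 3, the total element count.

# Rows carry weights 1.2 and 3.4; the first row contributes two elements.
weighted = [0.1054 * 1.2, 0.8047 * 1.2, 0.0619 * 3.4]
print(sum(weighted) / 3)  # ~0.4341, matching the assertion above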
1 change: 1 addition & 0 deletions keras/utils/BUILD
@@ -250,6 +250,7 @@ tf_py_test(
         ":dataset_creator",
         "//:expect_portpicker_installed",
         "//:expect_tensorflow_installed",
+        "//keras:combinations",
         "//keras/engine",
         "//keras/layers:core",
     ],
20 changes: 18 additions & 2 deletions keras/utils/dataset_creator.py
@@ -42,7 +42,11 @@ def dataset_fn(input_context):
     dataset = dataset.prefetch(2)
     return dataset
 
-  model.fit(DatasetCreator(dataset_fn), epochs=10, steps_per_epoch=10)
+  input_options = tf.distribute.InputOptions(
+      experimental_fetch_to_device=True,
+      experimental_per_replica_buffer_size=2)
+  model.fit(DatasetCreator(dataset_fn, input_options=input_options),
+            epochs=10, steps_per_epoch=10)
   ```
 
   `Model.fit` usage with `DatasetCreator` is intended to work across all
@@ -67,12 +71,24 @@ def dataset_fn(input_context):
       cross-worker input pipeline sharding (if neither is needed, the
       `InputContext` parameter can be ignored in the `dataset_fn`), and returns
       a `tf.data.Dataset`.
+    input_options: Optional `tf.distribute.InputOptions`, used for specific
+      options when used with distribution, for example, whether to prefetch
+      dataset elements to accelerator device memory or host device memory, and
+      prefetch buffer size in the replica device memory. No effect if not used
+      with distributed training. See `tf.distribute.InputOptions` for more
+      information.
   """
 
-  def __init__(self, dataset_fn):
+  def __init__(self, dataset_fn, input_options=None):
     if not callable(dataset_fn):
       raise TypeError('`dataset_fn` for `DatasetCreator` must be a `callable`.')
+    if input_options and (not isinstance(input_options,
+                                         tf.distribute.InputOptions)):
+      raise TypeError('`input_options` for `DatasetCreator` must be a '
+                      '`tf.distribute.InputOptions`.')
+
     self.dataset_fn = dataset_fn
+    self.input_options = input_options
 
   def __call__(self, *args, **kwargs):
     # When a `DatasetCreator` is invoked, it forwards args/kwargs straight to
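Note: an end-to-end sketch (not part of the commit) of the new parameter, following the updated docstring above. The model, data, and strategy are hypothetical stand-ins; `steps_per_epoch` is required when fitting from a `DatasetCreator`.

import tensorflow as tf
from keras.utils.dataset_creator import DatasetCreator

def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(64)
  ds = tf.data.Dataset.from_tensor_slices(
      (tf.random.normal([256, 10]), tf.random.normal([256, 1])))
  # Shard across input pipelines, then batch per replica.
  ds = ds.shard(input_context.num_input_pipelines,
                input_context.input_pipeline_id)
  return ds.batch(batch_size).repeat().prefetch(2)

input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
  model.compile(optimizer='sgd', loss='mse')

model.fit(DatasetCreator(dataset_fn, input_options=input_options),
          epochs=2, steps_per_epoch=4)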
