
Commit

Sync OSS keras to head.
PiperOrigin-RevId: 362969794
qlzh727 authored and tensorflower-gardener committed Mar 15, 2021
1 parent 33f7c97 commit 7f8c62b
Showing 20 changed files with 135 additions and 32 deletions.
2 changes: 0 additions & 2 deletions keras/callbacks.py
@@ -2123,7 +2123,6 @@ def __init__(self,
self.embeddings_freq = embeddings_freq
self.embeddings_metadata = embeddings_metadata
self._init_profile_batch(profile_batch)
self._epoch = 0
self._global_train_batch = 0
self._previous_epoch_iterations = 0
self._train_accumulated_time = 0
@@ -2397,7 +2396,6 @@ def on_train_batch_end(self, batch, logs=None):

def on_epoch_begin(self, epoch, logs=None):
# Keeps track of epoch for profiling.
self._epoch = epoch
if self.write_steps_per_second:
self._previous_epoch_iterations = self.model.optimizer.iterations.numpy()
self._train_accumulated_time = 0
1 change: 0 additions & 1 deletion keras/callbacks_v1.py
@@ -377,7 +377,6 @@ def on_epoch_begin(self, epoch, logs=None):

# check if histogram summary should be run for this epoch
if self.histogram_freq and epoch % self.histogram_freq == 0:
self._epoch = epoch
# pylint: disable=protected-access
# add the histogram summary op if it should run this epoch
self.model._make_test_function()
2 changes: 1 addition & 1 deletion keras/engine/base_layer.py
@@ -2836,7 +2836,7 @@ def _flatten_modules(self, recursive=True, include_self=True):
None)
if subtrackables:
deque.extendleft(reversed(subtrackables))
elif isinstance(trackable_obj, data_structures.TrackableDataStructure):
elif isinstance(trackable_obj, tf.__internal__.tracking.TrackableDataStructure):
# Data structures are introspected even with `recursive=False`.
tracked_values = trackable_obj._values
if tracked_values:
15 changes: 6 additions & 9 deletions keras/engine/training.py
@@ -56,7 +56,6 @@
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
@@ -318,7 +317,7 @@ def __setattr__(self, name, value):

if all(
isinstance(v, (base_layer.Layer,
data_structures.TrackableDataStructure)) or
tf.__internal__.tracking.TrackableDataStructure)) or
base_layer_utils.has_weights(v) for v in tf.nest.flatten(value)):
try:
self._base_model_initialized
@@ -935,13 +934,11 @@ def fit(self,
noise and dropout.
`validation_data` will override `validation_split`.
`validation_data` could be:
- tuple `(x_val, y_val)` of Numpy arrays or tensors
- tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
- dataset
For the first two cases, `batch_size` must be provided.
For the last case, `validation_steps` could be provided.
Note that `validation_data` does not support all the data types that
are supported in `x`, eg, dict, generator or `keras.utils.Sequence`.
- A tuple `(x_val, y_val)` of Numpy arrays or tensors.
- A tuple `(x_val, y_val, val_sample_weights)` of NumPy arrays.
- A `tf.data.Dataset`.
- A Python generator or `keras.utils.Sequence` returning
`(inputs, targets)` or `(inputs, targets, sample_weights)`.
shuffle: Boolean (whether to shuffle the training data
before each epoch) or str (for 'batch'). This argument is ignored
when `x` is a generator or an object of tf.data.Dataset.
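For illustration, a minimal sketch of passing `validation_data` in the tuple and `tf.data.Dataset` forms listed above (the model, shapes, and random data are made up for this example):

import numpy as np
import tensorflow as tf

# Illustrative model; shapes are arbitrary.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')

x_train, y_train = np.random.rand(64, 4), np.random.rand(64, 1)
x_val, y_val = np.random.rand(16, 4), np.random.rand(16, 1)

# Tuple of NumPy arrays.
model.fit(x_train, y_train, batch_size=8, epochs=1,
          validation_data=(x_val, y_val))

# tf.data.Dataset (validation_steps may be provided instead of batch_size).
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(8)
model.fit(x_train, y_train, batch_size=8, epochs=1, validation_data=val_ds)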
1 change: 1 addition & 0 deletions keras/integration_test/BUILD
@@ -132,6 +132,7 @@ tf_py_test(
tags = [
"no_tfrt", # TODO(b/171765113)
"noasan", # TODO(b/156029134)
"nomac", # TODO(b/182567880)
"nomsan", # TODO(b/156029134)
"notsan", # TODO(b/156029134)
],
6 changes: 5 additions & 1 deletion keras/layers/core.py
@@ -1073,7 +1073,8 @@ class Dense(Layer):
where `activation` is the element-wise activation function
passed as the `activation` argument, `kernel` is a weights matrix
created by the layer, and `bias` is a bias vector created by the layer
(only applicable if `use_bias` is `True`).
(only applicable if `use_bias` is `True`). These are all attributes of
`Dense`.
Note: If the input to the layer has a rank greater than 2, then `Dense`
computes the dot product between the `inputs` and the `kernel` along the
@@ -1086,6 +1087,9 @@ class Dense(Layer):
Besides, layer attributes cannot be modified after the layer has been called
once (except the `trainable` attribute).
When the commonly used kwarg `input_shape` is passed, Keras will create
an input layer to insert before the current layer. This is equivalent to
explicitly defining an `InputLayer`.
Example:
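A minimal sketch of the `input_shape` behavior described above (layer sizes are arbitrary); the two models below are equivalent:

import tensorflow as tf

# Passing input_shape to the first layer makes Keras insert an input layer.
model_a = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(16,)),
])

# Equivalent to defining the InputLayer explicitly.
model_b = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(16,)),
    tf.keras.layers.Dense(32, activation='relu'),
])

assert model_a.output_shape == model_b.output_shape == (None, 32)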
4 changes: 2 additions & 2 deletions keras/layers/normalization_v2.py
@@ -198,7 +198,7 @@ class BatchNormalization(normalization.BatchNormalizationBase):
with the argument `training=True`), the layer normalizes its output using
the mean and standard deviation of the current batch of inputs. That is to
say, for each channel being normalized, the layer returns
`(batch - mean(batch)) / (var(batch) + epsilon) * gamma + beta`, where:
`gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:
- `epsilon` is a small constant (configurable as part of the constructor
arguments)
@@ -212,7 +212,7 @@ class BatchNormalization(normalization.BatchNormalizationBase):
default), the layer normalizes its output using a moving average of the
mean and standard deviation of the batches it has seen during training. That
is to say, it returns
`(batch - self.moving_mean) / (self.moving_var + epsilon) * gamma + beta`.
`gamma * (batch - self.moving_mean) / sqrt(self.moving_var + epsilon) + beta`.
`self.moving_mean` and `self.moving_var` are non-trainable variables that
are updated each time the layer is called in training mode, as such:
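A small NumPy-only sketch of the corrected training-time formula above (the batch values, `gamma`, `beta`, and `epsilon` are made up):

import numpy as np

batch = np.array([1.0, 2.0, 3.0, 4.0])
gamma, beta, epsilon = 1.0, 0.0, 1e-3

out = gamma * (batch - batch.mean()) / np.sqrt(batch.var() + epsilon) + beta
# out is approximately [-1.34, -0.45, 0.45, 1.34]: zero mean and unit variance
# (up to epsilon), then scaled by gamma and shifted by beta.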
@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

import itertools
import time
@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

import time

@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

import time

2 changes: 1 addition & 1 deletion keras/layers/preprocessing/benchmarks/hashing_benchmark.py
@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

import itertools
import random
@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

import functools
import time
@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

import collections
import itertools
@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

import os
import random
@@ -17,7 +17,7 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v2 as tf
import tensorflow as tf

import time

25 changes: 23 additions & 2 deletions keras/mixed_precision/loss_scale_optimizer.py
@@ -914,6 +914,27 @@ def __setattr__(self, name, value):
else:
super(LossScaleOptimizer, self).__setattr__(name, value)

# Explicitly delegate learning_rate. Normally hyperparameters are delegated in
# __getattribute__, but if a hyperparameter is not in self._optimizer._hyper
# (e.g. because self._optimizer itself wraps another optimizer), then it won't
# be delegated. Since learning_rate is a very commonly accessed
# hyperparameter, we delegate it here.
@property
def learning_rate(self):
return self._optimizer.learning_rate

@learning_rate.setter
def learning_rate(self, value):
self._optimizer.learning_rate = value

@property
def lr(self):
return self._optimizer.learning_rate

@lr.setter
def lr(self, value):
self._optimizer.lr = value

# We do not override some OptimizerV2 methods. For each, we describe why we do
# not delegate them to self._optimizer:
# * get_updates: get_updates() calls get_gradients(). Since we override
@@ -933,8 +954,8 @@ def __setattr__(self, name, value):
class LossScaleOptimizerV1(LossScaleOptimizer):
"""An deprecated optimizer that applies loss scaling.
Warning: This class is deprecated and will be removed in TensorFlow 2.5.
Please use the non-experimental class
Warning: This class is deprecated and will be removed in a future version of
TensorFlow. Please use the non-experimental class
`tf.keras.mixed_precision.LossScaleOptimizer` instead.
This class is identical to the non-experimental
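A short usage sketch of the `learning_rate` delegation added above (the SGD optimizer and the values are chosen only for illustration); reads and writes pass through to the wrapped optimizer:

import tensorflow as tf

inner = tf.keras.optimizers.SGD(learning_rate=0.1)
opt = tf.keras.mixed_precision.LossScaleOptimizer(inner)

print(float(opt.learning_rate))    # 0.1, read from the wrapped optimizer
opt.learning_rate = 0.05           # setter forwards to inner.learning_rate
print(float(inner.learning_rate))  # 0.05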
38 changes: 38 additions & 0 deletions keras/mixed_precision/loss_scale_optimizer_test.py
@@ -31,6 +31,7 @@
from keras.mixed_precision import test_util as mp_test_util
from keras.optimizer_v2 import adam
from keras.optimizer_v2 import gradient_descent
from keras.optimizer_v2 import optimizer_v2

# Disable not-callable lint error, as the linter is unable to detect that
# LossScale instances are callable.
@@ -641,6 +642,43 @@ def get_config(self):
'DynamicLossScale is no longer supported. Got:'):
loss_scale_optimizer.LossScaleOptimizerV1(opt, MyLossScale())

def testLossScaleDelegationWithWrapper(self):
# Test learning_rate is exposed when LossScaleOptimizer wraps another
# wrapper.

class MyOptimizer(optimizer_v2.OptimizerV2):

def __init__(self):
super().__init__('MyOptimizer')
self.inner_optimizer = adam.Adam(learning_rate=1.0)

@property
def learning_rate(self):
return self.inner_optimizer.learning_rate

@learning_rate.setter
def learning_rate(self, value):
self.inner_optimizer.learning_rate = value

def get_config(self):
return {}

with self.cached_session():
opt = MyOptimizer()
opt = loss_scale_optimizer.LossScaleOptimizer(opt)

# Force hyperparameters to be created
opt.learning_rate # pylint: disable=pointless-statement
self.evaluate(tf.compat.v1.global_variables_initializer())

self.assertEqual(self.evaluate(opt.learning_rate), 1.0)
self.assertEqual(
self.evaluate(opt.inner_optimizer.inner_optimizer.learning_rate), 1.0)
opt.learning_rate = 2.0
self.assertEqual(self.evaluate(opt.learning_rate), 2.0)
self.assertEqual(self.evaluate(
opt.inner_optimizer.inner_optimizer.learning_rate), 2.0)

@parameterized.named_parameters({
'testcase_name': 'SaveAndRestoreBase',
'strategy_fn': default_strategy_fn,
4 changes: 4 additions & 0 deletions keras/optimizer_v2/optimizer_v2.py
@@ -24,6 +24,7 @@
import abc
import contextlib
import functools
import warnings

import six
from tensorflow.python.distribute import values as ds_values
@@ -360,6 +361,9 @@ def my_gradient_transformer(grads_and_vars):
# checks that all keyword arguments are non-negative.
if kwargs[k] is not None and kwargs[k] < 0:
raise ValueError("Expected {} >= 0, received: {}".format(k, kwargs[k]))
if k == "lr":
warnings.warn(
"The `lr` argument is deprecated, use `learning_rate` instead.")

self._use_locking = True
self._init_set_name(name)
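A hedged sketch of the user-visible effect of the check added above: constructing an optimizer with the legacy `lr` keyword still works, but should now emit a deprecation warning (Adam and the value 0.001 are used only as an example):

import warnings
import tensorflow as tf

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    opt = tf.keras.optimizers.Adam(lr=0.001)  # legacy spelling of learning_rate

print(float(opt.learning_rate))            # 0.001, lr was mapped to learning_rate
print([str(w.message) for w in caught])    # expect the deprecation warning here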
3 changes: 1 addition & 2 deletions keras/saving/saved_model/load.py
@@ -43,7 +43,6 @@
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures

# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.
@@ -350,7 +349,7 @@ def _add_children_recreated_from_config(self, obj, proto, node_id):
child_proto.variable.name):
obj_child._handle_name = child_proto.variable.name + ':0' # pylint: disable=protected-access

if isinstance(obj_child, data_structures.TrackableDataStructure):
if isinstance(obj_child, tf.__internal__.tracking.TrackableDataStructure):
setter = lambda *args: None

child_path = '{}.{}'.format(parent_path, child_name)
50 changes: 46 additions & 4 deletions keras/saving/saved_model_experimental.py
@@ -34,7 +34,6 @@
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.util.tf_export import keras_export

@@ -143,16 +142,16 @@ def _export_model_json(model, saved_model_path):
"""Saves model configuration as a json string under assets folder."""
model_json = model.to_json()
model_json_filepath = os.path.join(
saved_model_utils.get_or_create_assets_dir(saved_model_path),
_get_or_create_assets_dir(saved_model_path),
tf.compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
with tf.io.gfile.GFile(model_json_filepath, 'w') as f:
f.write(model_json)


def _export_model_variables(model, saved_model_path):
"""Saves model weights in checkpoint format under variables folder."""
saved_model_utils.get_or_create_variables_dir(saved_model_path)
checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
_get_or_create_variables_dir(saved_model_path)
checkpoint_prefix = _get_variables_path(saved_model_path)
model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
return checkpoint_prefix

@@ -422,3 +421,46 @@ def load_from_saved_model(saved_model_path, custom_objects=None):
tf.compat.as_text(tf.saved_model.VARIABLES_FILENAME))
model.load_weights(checkpoint_prefix)
return model


#### Directory / path helpers


def _get_or_create_variables_dir(export_dir):
"""Return variables sub-directory, or create one if it doesn't exist."""
variables_dir = _get_variables_dir(export_dir)
if not tf.compat.v1.gfile.Exists(variables_dir):
tf.compat.v1.gfile.MakeDirs(variables_dir)
return variables_dir


def _get_variables_dir(export_dir):
"""Return variables sub-directory in the SavedModel."""
return os.path.join(
tf.compat.as_text(export_dir),
tf.compat.as_text(tf.saved_model.VARIABLES_DIRECTORY))


def _get_variables_path(export_dir):
"""Return the variables path, used as the prefix for checkpoint files."""
return os.path.join(
tf.compat.as_text(_get_variables_dir(export_dir)),
tf.compat.as_text(tf.saved_model.VARIABLES_FILENAME))


def _get_or_create_assets_dir(export_dir):
"""Return assets sub-directory, or create one if it doesn't exist."""
assets_destination_dir = _get_assets_dir(export_dir)

if not tf.compat.v1.gfile.Exists(assets_destination_dir):
tf.compat.v1.gfile.MakeDirs(assets_destination_dir)

return assets_destination_dir


def _get_assets_dir(export_dir):
"""Return path to asset directory in the SavedModel."""
return os.path.join(
tf.compat.as_text(export_dir),
tf.compat.as_text(tf.saved_model.ASSETS_DIRECTORY))
