-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature: Cluster setup for MultiWorkerMirroredStrategy (#415)
* Feature: Cluster setup for MultiWorkerMirroredStrategy * Configuring tests to use the new hyperparameter for MWMS * Black formatted files * fixing failing tests * Removing references to py versions older than py37 * Converting py36 tests to py37 * fix: linting and changed variable name to sagemaker_multi_worker_mirrored_strategy_enabled * fix: feezing protobuf version * fix: renaming MWMS variable name * fix: rename functions for _mwm to _mwms * Revert "fix: feezing protobuf version" This reverts commit c3e6819. * Revert "Converting py36 tests to py37" This reverts commit 86701b4. * Revert "Removing references to py versions older than py37" This reverts commit 718e5c7. * fix: variable name changes for MWMS * fix: renaming training script to train_dummy.py * fix: freezing latest sagemaker toolkit version * trigger ci * fix: adding epochs and steps to failing MWMS test * fix: changing MWMS testcase * fix: logic error in MWMS * fix: logic error in MWMS * fix: Updating MWMS tests to check for log lines * fix: linting * trigger ci * trigger ci Co-authored-by: Nishanth Hegde <hegdn@amazon.com>
- Loading branch information
Showing
6 changed files
with
200 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Copyright 2017-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
from __future__ import absolute_import | ||
|
||
import os | ||
|
||
from sagemaker.tensorflow import TensorFlow | ||
from sagemaker.utils import unique_name_from_base | ||
|
||
|
||
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources") | ||
|
||
|
||
def test_multi_node(sagemaker_session, instance_type, image_uri, tmpdir, framework_version, capsys):
    """Run a two-instance training job with MultiWorkerMirroredStrategy enabled.

    Launches ``multi_worker_mirrored/train_dummy.py`` on two instances with the
    ``sagemaker_multi_worker_mirrored_strategy_enabled`` hyperparameter set, then
    checks that the output captured while ``fit()`` ran contains the expected
    cluster-setup log lines.

    All parameters are pytest fixtures. ``tmpdir`` is requested but not used by
    the body; it is kept so the test signature stays unchanged.
    """
    entry_point = os.path.join(RESOURCE_PATH, "multi_worker_mirrored", "train_dummy.py")
    estimator = TensorFlow(
        entry_point=entry_point,
        role="SageMakerRole",
        instance_type=instance_type,
        instance_count=2,
        # ``image_uri`` is the SageMaker Python SDK v2 keyword; ``image_name`` is
        # the v1 spelling and does not belong next to the v2-style
        # ``instance_type`` / ``instance_count`` arguments used here.
        image_uri=image_uri,
        framework_version=framework_version,
        py_version="py3",
        hyperparameters={
            "sagemaker_multi_worker_mirrored_strategy_enabled": True,
        },
        sagemaker_session=sagemaker_session,
    )
    estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))

    # Search stdout and stderr together: the assertions only care that the
    # lines were emitted, not which stream carried them.
    captured = capsys.readouterr()
    logs = captured.out + captured.err
    assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
    assert "TF_CONFIG=" in logs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). | ||
# You may not use this file except in compliance with the License. | ||
# A copy of the License is located at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# or in the "license" file accompanying this file. This file is distributed | ||
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
# express or implied. See the License for the specific language governing | ||
# permissions and limitations under the License. | ||
from __future__ import absolute_import |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Please refer to https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_keras.ipynb | ||
|
||
import tensorflow as tf | ||
import numpy as np | ||
import os | ||
import json | ||
|
||
|
||
def mnist_dataset(batch_size):
    """Return the MNIST training split as a shuffled, repeating, batched dataset.

    Args:
        batch_size: Number of examples per batch passed to ``Dataset.batch``.

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` batches built from the
        training split only; the test split is discarded.
    """
    (images, labels), _ = tf.keras.datasets.mnist.load_data()

    # The image arrays are uint8 with values in [0, 255]; dividing by a
    # float32 scalar rescales them into the [0, 1] range.
    images = images / np.float32(255)
    labels = labels.astype(np.int64)

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.shuffle(60000)
    dataset = dataset.repeat()
    return dataset.batch(batch_size)
|
||
def build_and_compile_cnn_model():
    """Build and compile the small Keras CNN used by this training script.

    The network takes 28x28 inputs, reshapes them to 28x28x1, and applies one
    convolution followed by two dense layers ending in 10 output units. It is
    compiled with sparse categorical cross-entropy on logits, SGD with a
    learning rate of 0.001, and an accuracy metric.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    keras = tf.keras
    layer_stack = [
        keras.layers.InputLayer(input_shape=(28, 28)),
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(32, 3, activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        # No activation on the final layer: the loss is configured with
        # from_logits=True and so consumes raw logits.
        keras.layers.Dense(10),
    ]
    cnn = keras.Sequential(layer_stack)

    loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    sgd = keras.optimizers.SGD(learning_rate=0.001)
    cnn.compile(loss=loss_fn, optimizer=sgd, metrics=['accuracy'])
    return cnn
|
||
|
||
# Batch size handled by each individual worker.
per_worker_batch_size = 64

# The cluster layout is read from the TF_CONFIG environment variable; the
# script fails with KeyError if it has not been set before launch.
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

# Create the strategy before the dataset and the model are built.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# The dataset is batched at the global size: per-worker size times worker count.
global_batch_size = num_workers * per_worker_batch_size
multi_worker_dataset = mnist_dataset(global_batch_size)

# Model building/compiling need to be within `strategy.scope()`.
with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

# The dataset repeats indefinitely, so steps_per_epoch bounds each epoch.
multi_worker_model.fit(x=multi_worker_dataset, epochs=3, steps_per_epoch=70)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters