Feature: Cluster setup for MultiWorkerMirroredStrategy (#415)
* Feature: Cluster setup for MultiWorkerMirroredStrategy

* Configuring tests to use the new hyperparameter for MWMS

* Black formatted files

* fixing failing tests

* Removing references to py versions older than py37

* Converting py36 tests to py37

* fix: linting and changed variable name to sagemaker_multi_worker_mirrored_strategy_enabled

* fix: freezing protobuf version

* fix: renaming MWMS variable name

* fix: rename functions for _mwm to _mwms

* Revert "fix: freezing protobuf version"

This reverts commit c3e6819.

* Revert "Converting py36 tests to py37"

This reverts commit 86701b4.

* Revert "Removing references to py versions older than py37"

This reverts commit 718e5c7.

* fix: variable name changes for MWMS

* fix: renaming training script to train_dummy.py

* fix: freezing latest sagemaker toolkit version

* trigger ci

* fix: adding epochs and steps to failing MWMS test

* fix: changing MWMS testcase

* fix: logic error in MWMS

* fix: logic error in MWMS

* fix: Updating MWMS tests to check for log lines

* fix: linting

* trigger ci

* trigger ci

Co-authored-by: Nishanth Hegde <hegdn@amazon.com>
Lokiiiiii and nish21 authored Jun 4, 2022
1 parent a58d124 commit 1d4b916
Showing 6 changed files with 200 additions and 22 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -74,7 +74,7 @@ def read_version():
"Programming Language :: Python :: 3.9",
],
install_requires=[
"sagemaker-training>=4.1.0",
"sagemaker-training>=4.1.3",
"numpy",
"scipy",
"sklearn",
61 changes: 53 additions & 8 deletions src/sagemaker_tensorflow_container/training.py
@@ -28,14 +28,17 @@

SAGEMAKER_PARAMETER_SERVER_ENABLED = "sagemaker_parameter_server_enabled"
SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED = "sagemaker_distributed_dataparallel_enabled"
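# Users opt in to MultiWorkerMirroredStrategy by setting this hyperparameter to True
# (see test_multi_worker_mirrored.py in this commit for an example).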
SAGEMAKER_MULTI_WORKER_MIRRORED_STRATEGY_ENABLED = (
"sagemaker_multi_worker_mirrored_strategy_enabled"
)
MODEL_DIR = "/opt/ml/model"


def _is_host_master(hosts, current_host):
return current_host == hosts[0]


def _build_tf_config(hosts, current_host, ps_task=False):
def _build_tf_config_for_ps(hosts, current_host, ps_task=False):
"""Builds a dictionary containing cluster information based on number of hosts and number of
parameter servers.
@@ -85,6 +88,31 @@ def host_addresses(hosts, port=2222):
return tf_config


def _build_tf_config_for_mwms(hosts, current_host):
"""Builds a dictionary containing cluster information based on number of workers
for Multi Worker Mirrored distribution strategy.
Args:
hosts (list[str]): List of host names in the cluster
current_host (str): Current host name
Returns:
dict[str: dict]: A dictionary describing the cluster setup for distributed training.
For more information regarding TF_CONFIG:
https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details
"""
workers = hosts

def host_addresses(hosts, port=8890):
return ["{}:{}".format(host, port) for host in hosts]

tf_config = {"cluster": {}, "environment": "cloud"}
tf_config["cluster"]["worker"] = host_addresses(workers)
tf_config["task"] = {"index": workers.index(current_host), "type": "worker"}

return tf_config
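# Example (illustrative; host names are hypothetical): for hosts ["host1", "host2"]
# and current_host "host1", _build_tf_config_for_mwms returns
#
#     {
#         "cluster": {"worker": ["host1:8890", "host2:8890"]},
#         "environment": "cloud",
#         "task": {"index": 0, "type": "worker"},
#     }
#
# train() serializes this dict with json.dumps() into the TF_CONFIG environment
# variable so that tf.distribute.MultiWorkerMirroredStrategy can discover its peers;
# port 8890 is the default from host_addresses above.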


def _run_ps(env, cluster):
logger.info("Running distributed training job with parameter servers")

@@ -134,17 +162,35 @@ def train(env, cmd_args):
Args:
env (sagemaker_training.environment.Environment): Instance of Environment class
"""
parameter_server_enabled = env.additional_framework_parameters.get(
SAGEMAKER_PARAMETER_SERVER_ENABLED, False
parameter_server_enabled = (
env.additional_framework_parameters.get(SAGEMAKER_PARAMETER_SERVER_ENABLED, False)
and len(env.hosts) > 1
)
multi_worker_mirrored_strategy_enabled = env.additional_framework_parameters.get(
SAGEMAKER_MULTI_WORKER_MIRRORED_STRATEGY_ENABLED, False
)
sagemaker_distributed_dataparallel_enabled = env.additional_framework_parameters.get(
SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED, False
)
if len(env.hosts) > 1 and parameter_server_enabled:

tf_config = _build_tf_config(hosts=env.hosts, current_host=env.current_host)
env_vars = env.to_env_vars()

# Setup
if parameter_server_enabled:

tf_config = _build_tf_config_for_ps(hosts=env.hosts, current_host=env.current_host)
logger.info("Running distributed training job with parameter servers")

elif multi_worker_mirrored_strategy_enabled:

env_vars["TF_CONFIG"] = json.dumps(
_build_tf_config_for_mwms(hosts=env.hosts, current_host=env.current_host)
)
logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")

# Run
if parameter_server_enabled:

logger.info("Launching parameter server process")
_run_ps(env, tf_config["cluster"])
logger.info("Launching worker process")
@@ -168,7 +214,7 @@ def train(env, cmd_args):
uri=env.module_dir,
user_entry_point=env.user_entry_point,
args=cmd_args,
env_vars=env.to_env_vars(),
env_vars=env_vars,
capture_error=True,
runner_type=runner_type,
)
@@ -217,8 +263,7 @@ def _model_dir_with_training_job(model_dir, job_name):


def main():
"""Training entry point
"""
"""Training entry point"""
hyperparameters = environment.read_hyperparameters()
env = environment.Environment(hyperparameters=hyperparameters)

42 changes: 42 additions & 0 deletions test/integration/sagemaker/test_multi_worker_mirrored.py
@@ -0,0 +1,42 @@
# Copyright 2017-2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

from sagemaker.tensorflow import TensorFlow
from sagemaker.utils import unique_name_from_base


RESOURCE_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "resources")


def test_multi_node(sagemaker_session, instance_type, image_uri, tmpdir, framework_version, capsys):
estimator = TensorFlow(
entry_point=os.path.join(RESOURCE_PATH, "multi_worker_mirrored", "train_dummy.py"),
role="SageMakerRole",
instance_type=instance_type,
instance_count=2,
image_name=image_uri,
framework_version=framework_version,
py_version="py3",
hyperparameters={
"sagemaker_multi_worker_mirrored_strategy_enabled": True,
},
sagemaker_session=sagemaker_session,
)
estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))
captured = capsys.readouterr()
logs = captured.out + captured.err
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
assert "TF_CONFIG=" in logs
13 changes: 13 additions & 0 deletions test/resources/multi_worker_mirrored/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import
48 changes: 48 additions & 0 deletions test/resources/multi_worker_mirrored/train_dummy.py
@@ -0,0 +1,48 @@
# Please refer to https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_keras.ipynb

import tensorflow as tf
import numpy as np
import os
import json


def mnist_dataset(batch_size):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the [0, 255] range.
# You need to convert them to float32 with values in the [0, 1] range.
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
return train_dataset

def build_and_compile_cnn_model():
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
metrics=['accuracy'])
return model


per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])
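# TF_CONFIG is exported by the container entry point (training.py builds it with
# _build_tf_config_for_mwms), and MultiWorkerMirroredStrategy below reads the same
# variable from the environment to resolve the worker cluster.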

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
# Model building/compiling need to be within `strategy.scope()`.
multi_worker_model = build_and_compile_cnn_model()

multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)
56 changes: 43 additions & 13 deletions test/unit/test_training.py
@@ -35,6 +35,8 @@
"worker": ["{}:2222".format(HOST2)],
"ps": ["{}:2223".format(HOST1), "{}:2223".format(HOST2)],
}
CLUSTER_WITH_MWMS = {"worker": ["{}:8890".format(HOST) for HOST in HOST_LIST]}

MASTER_TASK = {"index": 0, "type": "master"}
WORKER_TASK = {"index": 0, "type": "worker"}
PS_TASK_1 = {"index": 0, "type": "ps"}
@@ -109,7 +111,9 @@ def test_train_horovod(run_module, single_machine_training_env):

@patch("sagemaker_training.entry_point.run")
def test_train_smdataparallel(run_module, single_machine_training_env):
single_machine_training_env.additional_framework_parameters["sagemaker_distributed_dataparallel_enabled"] = True
single_machine_training_env.additional_framework_parameters[
"sagemaker_distributed_dataparallel_enabled"
] = True

training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
run_module.assert_called_with(
@@ -124,7 +128,8 @@ def test_train_smdataparallel(run_module, single_machine_training_env):

@pytest.mark.skip_on_pipeline
@pytest.mark.skipif(
sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
sys.version_info.major != 3,
reason="Skip this for python 2 because of dict key order mismatch",
)
@patch("tensorflow.train.ClusterSpec")
@patch("tensorflow.distribute.Server")
@@ -135,7 +140,11 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

cluster_spec.assert_called_with(
{"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
{
"worker": ["host2:2222"],
"master": ["host1:2222"],
"ps": ["host1:2223", "host2:2223"],
}
)

tf_server.assert_called_with(
@@ -166,7 +175,8 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai

@pytest.mark.skip_on_pipeline
@pytest.mark.skipif(
sys.version_info.major != 3, reason="Skip this for python 2 because of dict key order mismatch"
sys.version_info.major != 3,
reason="Skip this for python 2 because of dict key order mismatch",
)
@patch("tensorflow.train.ClusterSpec")
@patch("tensorflow.distribute.Server")
@@ -179,7 +189,11 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)

cluster_spec.assert_called_with(
{"worker": ["host2:2222"], "master": ["host1:2222"], "ps": ["host1:2223", "host2:2223"]}
{
"worker": ["host2:2222"],
"master": ["host1:2222"],
"ps": ["host1:2223", "host2:2223"],
}
)

tf_server.assert_called_with(
Expand Down Expand Up @@ -226,32 +240,45 @@ def test_train_distributed_no_ps(run, distributed_training_env):
)


def test_build_tf_config():
assert training._build_tf_config(HOST_LIST, HOST1) == {
def test_build_tf_config_for_mwms():
assert training._build_tf_config_for_mwms(HOST_LIST, HOST1) == {
"cluster": CLUSTER_WITH_MWMS,
"environment": "cloud",
"task": {"index": HOST_LIST.index(HOST1), "type": "worker"},
}
assert training._build_tf_config_for_mwms(HOST_LIST, HOST2) == {
"cluster": CLUSTER_WITH_MWMS,
"environment": "cloud",
"task": {"index": HOST_LIST.index(HOST2), "type": "worker"},
}


def test_build_tf_config_for_ps():
assert training._build_tf_config_for_ps(HOST_LIST, HOST1) == {
"cluster": CLUSTER_WITH_PS,
"environment": "cloud",
"task": MASTER_TASK,
}
assert training._build_tf_config(HOST_LIST, HOST1, ps_task=True) == {
assert training._build_tf_config_for_ps(HOST_LIST, HOST1, ps_task=True) == {
"cluster": CLUSTER_WITH_PS,
"environment": "cloud",
"task": PS_TASK_1,
}
assert training._build_tf_config(HOST_LIST, HOST2) == {
assert training._build_tf_config_for_ps(HOST_LIST, HOST2) == {
"cluster": CLUSTER_WITH_PS,
"environment": "cloud",
"task": WORKER_TASK,
}
assert training._build_tf_config(HOST_LIST, HOST2, ps_task=True) == {
assert training._build_tf_config_for_ps(HOST_LIST, HOST2, ps_task=True) == {
"cluster": CLUSTER_WITH_PS,
"environment": "cloud",
"task": PS_TASK_2,
}


def test_build_tf_config_error():
def test_build_tf_config_for_ps_error():
with pytest.raises(ValueError) as error:
training._build_tf_config([HOST1], HOST1, ps_task=True)
training._build_tf_config_for_ps([HOST1], HOST1, ps_task=True)
assert "Cannot have a ps task if there are no parameter servers in the cluster" in str(
error.value
)
@@ -327,7 +354,10 @@ def test_main(
@patch("sagemaker_tensorflow_container.training.train")
@patch("logging.Logger.setLevel")
@patch("sagemaker_training.environment.Environment")
@patch("sagemaker_training.environment.read_hyperparameters", return_value={"model_dir": MODEL_DIR})
@patch(
"sagemaker_training.environment.read_hyperparameters",
return_value={"model_dir": MODEL_DIR},
)
@patch("sagemaker_tensorflow_container.s3_utils.configure")
def test_main_simple_training_model_dir(
configure_s3_env,