feature: Adding support for Multi Worker Mirrored Strategy in TF estimator (aws#3192)

Co-authored-by: Miyoung <myoung8739@gmail.com>
2 people authored and Namrata Madan committed Apr 18, 2023
1 parent 28d47eb commit ea652b7
Showing 10 changed files with 364 additions and 298 deletions.
66 changes: 43 additions & 23 deletions src/sagemaker/estimator.py
@@ -20,6 +20,8 @@
import uuid
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Union, Optional, List
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from six import string_types, with_metaclass
from six.moves.urllib.parse import urlparse
@@ -83,10 +85,7 @@
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.pipeline_context import (
PipelineSession,
runnable_by_pipeline,
)
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline

logger = logging.getLogger(__name__)

@@ -106,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
LAUNCH_MWMS_ENV_NAME = "sagemaker_multi_worker_mirrored_strategy_enabled"
INSTANCE_TYPE = "sagemaker_instance_type"
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
@@ -557,9 +557,7 @@ def __init__(
self.dependencies = dependencies or []
self.uploaded_code = None
self.tags = add_jumpstart_tags(
tags=tags,
training_model_uri=self.model_uri,
training_script_uri=self.source_dir,
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
)
if self.instance_type in ("local", "local_gpu"):
if self.instance_type == "local_gpu" and self.instance_count > 1:
@@ -680,8 +678,7 @@ def _ensure_base_job_name(self):
self.base_job_name
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
or base_name_from_image(
self.training_image_uri(),
default_base_name=EstimatorBase.JOB_CLASS_NAME,
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
)
)

@@ -744,7 +741,6 @@ def _prepare_for_training(self, job_name=None):
self.dependencies = updated_paths["dependencies"]

if self.source_dir or self.entry_point or self.dependencies:

# validate source dir will raise a ValueError if there is something wrong with
# the source directory. We are intentionally not handling it because this is a
# critical error.
@@ -1023,10 +1019,7 @@ def _set_source_s3_uri(self, rule):
parse_result = urlparse(rule.rule_parameters["source_s3_uri"])
if parse_result.scheme != "s3":
desired_s3_uri = os.path.join(
"s3://",
self.sagemaker_session.default_bucket(),
rule.name,
str(uuid.uuid4()),
"s3://", self.sagemaker_session.default_bucket(), rule.name, str(uuid.uuid4())
)
s3_uri = S3Uploader.upload(
local_path=rule.rule_parameters["source_s3_uri"],
@@ -1439,10 +1432,7 @@ def deploy(
self._ensure_base_job_name()

jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
kwargs.get("source_dir"),
self.source_dir,
kwargs.get("model_data"),
self.model_uri,
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
)
default_name = (
name_from_base(jumpstart_base_name)
@@ -2240,11 +2230,7 @@ def _is_local_channel(cls, input_uri):

@classmethod
def update(
cls,
estimator,
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
):
"""Update a running Amazon SageMaker training job.
@@ -3165,6 +3151,34 @@ def _validate_and_set_debugger_configs(self):
)
self.debugger_hook_config = False

def _validate_mwms_config(self, distribution):
"""Validate Multi Worker Mirrored Strategy configuration."""
minimum_supported_framework_version = {"tensorflow": {"framework_version": "2.9"}}
if self._framework_name in minimum_supported_framework_version:
for version_argument in minimum_supported_framework_version[self._framework_name]:
current = getattr(self, version_argument)
threshold = minimum_supported_framework_version[self._framework_name][
version_argument
]
if Version(current) in SpecifierSet(f"< {threshold}"):
raise ValueError(
"Multi Worker Mirrored Strategy is only supported "
"from {} {} but received {}".format(version_argument, threshold, current)
)
else:
raise ValueError(
"Multi Worker Mirrored Strategy is currently only supported "
"with {} frameworks but received {}".format(
minimum_supported_framework_version.keys(), self._framework_name
)
)
unsupported_distributions = ["smdistributed", "parameter_server"]
if any(i in distribution for i in unsupported_distributions):
raise ValueError(
"Multi Worker Mirrored Strategy is currently not supported with the"
" following distribution strategies: {}".format(unsupported_distributions)
)

def _model_source_dir(self):
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.
@@ -3528,6 +3542,12 @@ def _distribution_configuration(self, distribution):
"dataparallel"
].get("custom_mpi_options", "")

if "multi_worker_mirrored_strategy" in distribution:
mwms_enabled = distribution.get("multi_worker_mirrored_strategy").get("enabled", False)
if mwms_enabled:
self._validate_mwms_config(distribution)
distribution_config[self.LAUNCH_MWMS_ENV_NAME] = mwms_enabled

if not (mpi_enabled or smdataparallel_enabled) and distribution_config.get(
"sagemaker_distribution_instance_groups"
) not in [None, []]:
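Taken together, the two estimator.py hunks first gate Multi Worker Mirrored Strategy on framework support and then surface it to the training container as the sagemaker_multi_worker_mirrored_strategy_enabled hyperparameter. A minimal standalone sketch of that behavior, using the same packaging-based version check as the diff; the helper function below is an illustrative reduction, not the SDK's API:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

MIN_MWMS_VERSIONS = {"tensorflow": "2.9"}  # mirrors the minimum in this PR


def build_mwms_config(framework_name, framework_version, distribution):
    """Reduced illustration of _validate_mwms_config + _distribution_configuration."""
    if framework_name not in MIN_MWMS_VERSIONS:
        raise ValueError(f"MWMS is only supported with {list(MIN_MWMS_VERSIONS)} frameworks")
    threshold = MIN_MWMS_VERSIONS[framework_name]
    if Version(framework_version) in SpecifierSet(f"< {threshold}"):
        raise ValueError(f"MWMS requires {framework_name} >= {threshold}, got {framework_version}")
    if any(key in distribution for key in ("smdistributed", "parameter_server")):
        raise ValueError("MWMS cannot be combined with smdistributed or parameter_server")
    # The real code stores this flag under EstimatorBase.LAUNCH_MWMS_ENV_NAME.
    return {"sagemaker_multi_worker_mirrored_strategy_enabled": True}


# Accepted: TF 2.9 with only MWMS in the distribution dict.
print(build_mwms_config("tensorflow", "2.9", {"multi_worker_mirrored_strategy": {"enabled": True}}))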
17 changes: 17 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
@@ -137,6 +137,23 @@ def __init__(
To find a complete list of parameters for SageMaker model parallelism,
see :ref:`sm-sdk-modelparallel-general`.
**To enable Multi Worker Mirrored Strategy:**
.. code:: python
{
"multi_worker_mirrored_strategy": {
"enabled": True
}
}
This distribution strategy option is available for TensorFlow 2.9 and later in
the SageMaker Python SDK v2.xx.yy and later.
To learn more about the mirrored strategy for TensorFlow,
see `TensorFlow Distributed Training
<https://www.tensorflow.org/guide/distributed_training>`_
in the *TensorFlow documentation*.
**To enable MPI:**
.. code:: python
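As a companion to the new docstring text, a hedged end-to-end sketch of enabling the option; the script name, role ARN, data location, and instance settings are placeholders rather than values from this PR:

from sagemaker.tensorflow import TensorFlow

# Illustrative two-node MWMS job; all names below are placeholders.
estimator = TensorFlow(
    entry_point="train.py",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    instance_count=2,
    instance_type="ml.p3.16xlarge",
    framework_version="2.9",  # MWMS requires TF 2.9 or later per this change
    py_version="py39",
    distribution={"multi_worker_mirrored_strategy": {"enabled": True}},
)
estimator.fit("s3://my-bucket/mnist-data")  # hypothetical input channel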
12 changes: 11 additions & 1 deletion src/sagemaker/tensorflow/training_compiler/config.py
@@ -79,7 +79,7 @@ def validate(cls, estimator):
"""Checks if SageMaker Training Compiler is configured correctly.
Args:
estimator (str): A estimator object
estimator (:class:`sagemaker.tensorflow.estimator.TensorFlow`): An estimator object
If SageMaker Training Compiler is enabled, it will validate whether
the estimator is configured to be compatible with Training Compiler.
@@ -102,3 +102,13 @@ def validate(cls, estimator):
cls.MIN_SUPPORTED_VERSION, estimator.framework_version
)
raise ValueError(error_helper_string)

if estimator.distribution and "multi_worker_mirrored_strategy" in estimator.distribution:
mwms_enabled = estimator.distribution.get("multi_worker_mirrored_strategy").get(
"enabled", False
)
if mwms_enabled:
raise ValueError(
"Multi Worker Mirrored Strategy distributed training configuration "
"is currently not compatible with SageMaker Training Compiler."
)
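The new guard makes SageMaker Training Compiler and MWMS mutually exclusive. A sketch of the failure mode with placeholder values; the import path for TrainingCompilerConfig matches the file modified here, and the ValueError above surfaces when the estimator configuration is validated:

from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig

# Combining the two features now fails validation; every other argument
# below is a placeholder.
estimator = TensorFlow(
    entry_point="train.py",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    instance_count=2,
    instance_type="ml.p3.16xlarge",
    framework_version="2.9",
    py_version="py39",
    compiler_config=TrainingCompilerConfig(),
    distribution={"multi_worker_mirrored_strategy": {"enabled": True}},
)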
11 changes: 3 additions & 8 deletions tests/conftest.py
@@ -281,9 +281,7 @@ def huggingface_training_compiler_pytorch_version(
huggingface_training_compiler_version,
):
versions = _huggingface_base_fm_version(
huggingface_training_compiler_version,
"pytorch",
"huggingface_training_compiler",
huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler"
)
if not versions:
pytest.skip(
@@ -298,9 +296,7 @@ def huggingface_training_compiler_tensorflow_version(
huggingface_training_compiler_version,
):
versions = _huggingface_base_fm_version(
huggingface_training_compiler_version,
"tensorflow",
"huggingface_training_compiler",
huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler"
)
if not versions:
pytest.skip(
@@ -526,8 +522,7 @@ def pytorch_ddp_py_version():


@pytest.fixture(
scope="module",
params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"],
scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
)
def pytorch_ddp_framework_version(request):
return request.param
57 changes: 57 additions & 0 deletions tests/data/tensorflow_mnist/mnist_mwms.py
@@ -0,0 +1,57 @@
# https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras

import json
import os
import tensorflow as tf
import numpy as np


def mnist_dataset(batch_size):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the [0, 255] range.
# You need to convert them to float32 with values in the [0, 1] range.
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(60000)
.repeat()
.batch(batch_size)
)
return train_dataset


def build_and_compile_cnn_model():
model = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10),
]
)
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
metrics=["accuracy"],
)
return model


per_worker_batch_size = 64
tf_config = json.loads(os.environ["TF_CONFIG"])
num_workers = len(tf_config["cluster"]["worker"])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
multi_worker_model = build_and_compile_cnn_model()

multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

print(f"strategy.num_replicas_in_sync={strategy.num_replicas_in_sync}")
65 changes: 48 additions & 17 deletions tests/integ/test_tf.py
@@ -38,6 +38,7 @@
SCRIPT = "mnist.py"
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
MWMS_DISTRIBUTION = {"multi_worker_mirrored_strategy": {"enabled": True}}
TAGS = [{"Key": "some-key", "Value": "some-value"}]
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}

@@ -68,12 +69,7 @@ def test_framework_processing_job_with_deps(
sagemaker_session=sagemaker_session,
base_job_name="test-tensorflow",
)
processor.run(
code=entry_point,
source_dir=code_path,
inputs=[],
wait=True,
)
processor.run(code=entry_point, source_dir=code_path, inputs=[], wait=True)


def test_mnist_with_checkpoint_config(
Expand Down Expand Up @@ -110,9 +106,7 @@ def test_mnist_with_checkpoint_config(
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(inputs=inputs, job_name=training_job_name)
assert_s3_file_patterns_exist(
sagemaker_session,
estimator.model_dir,
[r"model\.ckpt-\d+\.index", r"checkpoint"],
sagemaker_session, estimator.model_dir, [r"model\.ckpt-\d+\.index", r"checkpoint"]
)
# remove dataframe assertion to unblock PR build
# TODO: add independent integration test for `training_job_analytics`
@@ -130,9 +124,7 @@ def test_mnist_with_checkpoint_config(
]
)

expected_retry_strategy = {
"MaximumRetryAttempts": 2,
}
expected_retry_strategy = {"MaximumRetryAttempts": 2}
actual_retry_strategy = sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=training_job_name
)["RetryStrategy"]
@@ -181,6 +173,48 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
)


@pytest.mark.release
@pytest.mark.skipif(
tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
reason="no ml.p2 or ml.p3 instances in this region",
)
@retry_with_instance_list(gpu_list(tests.integ.test_region()))
def test_mwms_gpu(
sagemaker_session,
tensorflow_training_latest_version,
tensorflow_training_latest_py_version,
capsys,
**kwargs,
):
instance_count = 2
estimator = TensorFlow(
source_dir=os.path.join(RESOURCE_PATH, "tensorflow_mnist"),
entry_point="mnist_mwms.py",
model_dir=False,
instance_type=kwargs["instance_type"],
instance_count=instance_count,
framework_version=tensorflow_training_latest_version,
py_version=tensorflow_training_latest_py_version,
distribution=MWMS_DISTRIBUTION,
environment={"NCCL_DEBUG": "INFO"},
max_run=60 * 60 * 1, # 1 hour
role=ROLE,
volume_size=400,
sagemaker_session=sagemaker_session,
disable_profiler=True,
)

with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))

captured = capsys.readouterr()
logs = captured.out + captured.err
print(logs)
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
assert f"strategy.num_replicas_in_sync={instance_count}" in logs


@pytest.mark.release
def test_mnist_distributed_cpu(
sagemaker_session,
@@ -237,9 +271,7 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed"))
assert_s3_file_patterns_exist(
sagemaker_session,
estimator.model_dir,
[r"model\.ckpt-\d+\.index", r"checkpoint"],
sagemaker_session, estimator.model_dir, [r"model\.ckpt-\d+\.index", r"checkpoint"]
)


@@ -346,8 +378,7 @@ def test_model_deploy_with_serverless_inference_config(
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(
serverless_inference_config=ServerlessInferenceConfig(),
endpoint_name=endpoint_name,
serverless_inference_config=ServerlessInferenceConfig(), endpoint_name=endpoint_name
)

input_data = {"instances": [1.0, 2.0, 5.0]}