feature: Adding support for Multi Worker Mirrored Strategy in TF estimator #3192

Merged · 39 commits · Feb 18, 2023
Commits
16d7058 feature: adding support for Multi Worker Mirrored Strategy in TF esti… (Lokiiiiii, Jun 22, 2022)
f187cd0 tests: adding unit tests targeting MWMS in TF (Lokiiiiii, Jun 22, 2022)
dd6fd51 test: adding integration tests targeting MWMS in TF (Lokiiiiii, Jun 22, 2022)
c954c79 fix: linting and removing accidental file addition (Lokiiiiii, Jun 22, 2022)
cf0740f doc: Adding doc strings to tests (Lokiiiiii, Jun 22, 2022)
6e3c4e2 Fixing MWMS unit test for TF2 (Lokiiiiii, Jun 22, 2022)
6e3fa48 Fixing MWMS tests for TF2 (Lokiiiiii, Jun 23, 2022)
4cb17fe Fixing MWMS tests for TF2 (Lokiiiiii, Jun 23, 2022)
aecd12c Fixing MWMS tests for TF2 (Lokiiiiii, Jun 23, 2022)
ab2ad0d Fixing MWMS tests for TF2 (Lokiiiiii, Jun 23, 2022)
f7d9f6c Fixing MWMS tests for TF2 (Lokiiiiii, Jun 23, 2022)
3ba2c39 Finishing up MWMS tests (Lokiiiiii, Jun 24, 2022)
55af827 Finishing up MWMS tests (Lokiiiiii, Jun 24, 2022)
b87df13 Using entire imagenet dataset instead of a subset (Lokiiiiii, Jun 30, 2022)
2f99064 update: pruning unused fixtures in MWMS test (Lokiiiiii, Jul 1, 2022)
6599d8d fix: stop saving artifacts from MWMS test (Lokiiiiii, Jul 5, 2022)
96a3224 fix: save artifacts from MWMS test to tmp directory to discard (Lokiiiiii, Jul 5, 2022)
838f66b Adding a new test for HF transformers (Lokiiiiii, Dec 1, 2022)
a36fa7a Removing stale tests and fixtures (Lokiiiiii, Dec 1, 2022)
ddbb4ea fix: fixing syntax error in MWMS test (Lokiiiiii, Dec 2, 2022)
30cb055 Update src/sagemaker/tensorflow/estimator.py (Lokiiiiii, Dec 3, 2022)
fbed9cb Update src/sagemaker/tensorflow/estimator.py (Lokiiiiii, Dec 3, 2022)
7f7ad85 Fixing docstring syntax and auto-formatting (Lokiiiiii, Dec 5, 2022)
b988b8f Adding more validation checks when using MWMS (Lokiiiiii, Dec 5, 2022)
6185edc fixing new test targeting mwms-smdist (Lokiiiiii, Dec 5, 2022)
dd679c8 Adding training script that leverages MultiWorkerMirroredStrategy (Lokiiiiii, Dec 8, 2022)
a37842b Fixing merge conflict (Lokiiiiii, Feb 8, 2023)
70e9ff3 Auto reformat with black (Lokiiiiii, Feb 8, 2023)
3e7e9f8 Updating HF TF example to latest from repo (Lokiiiiii, Feb 9, 2023)
bc9edde Switching to a simpler test for keras examples for MWMS (Lokiiiiii, Feb 9, 2023)
0507792 Switching to a simpler test for keras examples for MWMS (Lokiiiiii, Feb 10, 2023)
32da5f1 Removing stale test training scripts (Lokiiiiii, Feb 10, 2023)
aac0532 black -l 100 (Lokiiiiii, Feb 10, 2023)
cf92c0d python3.10 -m black -l 100 (Lokiiiiii, Feb 10, 2023)
5f4959e Fixing unit tests for MWMS (Lokiiiiii, Feb 10, 2023)
2c48152 Fixing unit tests for MWMS (Lokiiiiii, Feb 10, 2023)
c6ebb5d Fixing unit tests for MWMS (Lokiiiiii, Feb 10, 2023)
9a527b7 Removing unused fixtures (Lokiiiiii, Feb 14, 2023)
a838741 Merge remote-tracking branch 'aws/master' into mwms-2 (Lokiiiiii, Feb 17, 2023)
66 changes: 43 additions & 23 deletions src/sagemaker/estimator.py
@@ -20,6 +20,8 @@
import uuid
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Union, Optional, List
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from six import string_types, with_metaclass
from six.moves.urllib.parse import urlparse
@@ -83,10 +85,7 @@
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.pipeline_context import (
PipelineSession,
runnable_by_pipeline,
)
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline

logger = logging.getLogger(__name__)

@@ -106,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled"
LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled"
LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled"
LAUNCH_MWMS_ENV_NAME = "sagemaker_multi_worker_mirrored_strategy_enabled"
INSTANCE_TYPE = "sagemaker_instance_type"
MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host"
MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options"
@@ -557,9 +557,7 @@ def __init__(
self.dependencies = dependencies or []
self.uploaded_code = None
self.tags = add_jumpstart_tags(
tags=tags,
training_model_uri=self.model_uri,
training_script_uri=self.source_dir,
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
)
if self.instance_type in ("local", "local_gpu"):
if self.instance_type == "local_gpu" and self.instance_count > 1:
@@ -680,8 +678,7 @@ def _ensure_base_job_name(self):
self.base_job_name
or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri)
or base_name_from_image(
self.training_image_uri(),
default_base_name=EstimatorBase.JOB_CLASS_NAME,
self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME
)
)

@@ -744,7 +741,6 @@ def _prepare_for_training(self, job_name=None):
self.dependencies = updated_paths["dependencies"]

if self.source_dir or self.entry_point or self.dependencies:

# validate source dir will raise a ValueError if there is something wrong with
# the source directory. We are intentionally not handling it because this is a
# critical error.
@@ -1023,10 +1019,7 @@ def _set_source_s3_uri(self, rule):
parse_result = urlparse(rule.rule_parameters["source_s3_uri"])
if parse_result.scheme != "s3":
desired_s3_uri = os.path.join(
"s3://",
self.sagemaker_session.default_bucket(),
rule.name,
str(uuid.uuid4()),
"s3://", self.sagemaker_session.default_bucket(), rule.name, str(uuid.uuid4())
)
s3_uri = S3Uploader.upload(
local_path=rule.rule_parameters["source_s3_uri"],
@@ -1439,10 +1432,7 @@ def deploy(
self._ensure_base_job_name()

jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
kwargs.get("source_dir"),
self.source_dir,
kwargs.get("model_data"),
self.model_uri,
kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri
)
default_name = (
name_from_base(jumpstart_base_name)
@@ -2240,11 +2230,7 @@ def _is_local_channel(cls, input_uri):

@classmethod
def update(
cls,
estimator,
profiler_rule_configs=None,
profiler_config=None,
resource_config=None,
cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None
):
"""Update a running Amazon SageMaker training job.

@@ -3165,6 +3151,34 @@ def _validate_and_set_debugger_configs(self):
)
self.debugger_hook_config = False

def _validate_mwms_config(self, distribution):
"""Validate Multi Worker Mirrored Strategy configuration."""
minimum_supported_framework_version = {"tensorflow": {"framework_version": "2.9"}}
if self._framework_name in minimum_supported_framework_version:
for version_argument in minimum_supported_framework_version[self._framework_name]:
current = getattr(self, version_argument)
threshold = minimum_supported_framework_version[self._framework_name][
version_argument
]
if Version(current) in SpecifierSet(f"< {threshold}"):
raise ValueError(
"Multi Worker Mirrored Strategy is only supported "
"from {} {} but received {}".format(version_argument, threshold, current)
)
else:
raise ValueError(
"Multi Worker Mirrored Strategy is currently only supported "
"with {} frameworks but received {}".format(
minimum_supported_framework_version.keys(), self._framework_name
)
)
unsupported_distributions = ["smdistributed", "parameter_server"]
if any(i in distribution for i in unsupported_distributions):
raise ValueError(
"Multi Worker Mirrored Strategy is currently not supported with the"
" following distribution strategies: {}".format(unsupported_distributions)
)
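
For context, the version gate above reduces to a `packaging` specifier check; a minimal sketch of the comparison, using illustrative version strings:

    from packaging.specifiers import SpecifierSet
    from packaging.version import Version

    # Mirrors the new check: a framework_version that falls inside the
    # "< 2.9" specifier set is rejected with a ValueError; anything at or
    # above the threshold passes.
    print(Version("2.8") in SpecifierSet("< 2.9"))    # True  -> would raise
    print(Version("2.9.2") in SpecifierSet("< 2.9"))  # False -> accepted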

def _model_source_dir(self):
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.

@@ -3528,6 +3542,12 @@ def _distribution_configuration(self, distribution):
"dataparallel"
].get("custom_mpi_options", "")

if "multi_worker_mirrored_strategy" in distribution:
mwms_enabled = distribution.get("multi_worker_mirrored_strategy").get("enabled", False)
if mwms_enabled:
self._validate_mwms_config(distribution)
distribution_config[self.LAUNCH_MWMS_ENV_NAME] = mwms_enabled

if not (mpi_enabled or smdataparallel_enabled) and distribution_config.get(
"sagemaker_distribution_instance_groups"
) not in [None, []]:
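
Taken together, the hunks above validate the user-facing `distribution` dict and then surface it as an internal flag. A self-contained sketch of that mapping, with the key name copied from this diff and everything else illustrative rather than the SDK's actual code path:

    LAUNCH_MWMS_ENV_NAME = "sagemaker_multi_worker_mirrored_strategy_enabled"

    def mwms_flag(distribution):
        """Mimic the new hunk: copy the 'enabled' flag into the distribution config."""
        config = {}
        mwms = distribution.get("multi_worker_mirrored_strategy", {})
        if mwms.get("enabled", False):
            # The real code first calls self._validate_mwms_config(distribution).
            config[LAUNCH_MWMS_ENV_NAME] = True
        return config

    print(mwms_flag({"multi_worker_mirrored_strategy": {"enabled": True}}))
    # -> {'sagemaker_multi_worker_mirrored_strategy_enabled': True}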
17 changes: 17 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
@@ -137,6 +137,23 @@ def __init__(
To find a complete list of parameters for SageMaker model parallelism,
see :ref:`sm-sdk-modelparallel-general`.

**To enable Multi Worker Mirrored Strategy:**

.. code:: python

{
"multi_worker_mirrored_strategy": {
"enabled": True
}
}

This distribution strategy option is available for TensorFlow 2.9 and later in
the SageMaker Python SDK v2.xx.yy and later.
To learn more about the mirrored strategy for TensorFlow,
see `TensorFlow Distributed Training
<https://www.tensorflow.org/guide/distributed_training>`_
in the *TensorFlow documentation*.
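
The integration test added later in this PR exercises this option end to end; a condensed sketch follows, in which the entry point path, role ARN, and instance values are illustrative, not prescriptive:

    from sagemaker.tensorflow import TensorFlow

    estimator = TensorFlow(
        entry_point="mnist_mwms.py",    # training script from this PR's test data
        source_dir="tensorflow_mnist",  # illustrative local path
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical role
        instance_count=2,
        instance_type="ml.p3.16xlarge",
        framework_version="2.9",        # MWMS requires TF >= 2.9 per this PR
        py_version="py39",
        distribution={"multi_worker_mirrored_strategy": {"enabled": True}},
    )
    estimator.fit()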

**To enable MPI:**

.. code:: python
12 changes: 11 additions & 1 deletion src/sagemaker/tensorflow/training_compiler/config.py
@@ -79,7 +79,7 @@ def validate(cls, estimator):
"""Checks if SageMaker Training Compiler is configured correctly.

Args:
estimator (str): A estimator object
estimator (:class:`sagemaker.tensorflow.estimator.TensorFlow`): An estimator object
If SageMaker Training Compiler is enabled, it will validate whether
the estimator is configured to be compatible with Training Compiler.

@@ -102,3 +102,13 @@ def validate(cls, estimator):
cls.MIN_SUPPORTED_VERSION, estimator.framework_version
)
raise ValueError(error_helper_string)

if estimator.distribution and "multi_worker_mirrored_strategy" in estimator.distribution:
mwms_enabled = estimator.distribution.get("multi_worker_mirrored_strategy").get(
"enabled", False
)
if mwms_enabled:
raise ValueError(
"Multi Worker Mirrored Strategy distributed training configuration "
"is currently not compatible with SageMaker Training Compiler."
)
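
A hedged sketch of a configuration this new guard rejects; `TrainingCompilerConfig` and the `compiler_config` parameter are believed to match the SDK's TensorFlow API but should be treated as assumptions here:

    from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig

    # Per the hunk above, pairing Training Compiler with MWMS fails validation.
    TensorFlow(
        entry_point="train.py",  # hypothetical script
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical role
        instance_count=2,
        instance_type="ml.p3.2xlarge",
        framework_version="2.9",
        py_version="py39",
        compiler_config=TrainingCompilerConfig(),
        distribution={"multi_worker_mirrored_strategy": {"enabled": True}},
    )
    # -> ValueError: "Multi Worker Mirrored Strategy distributed training
    #    configuration is currently not compatible with SageMaker Training Compiler."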
11 changes: 3 additions & 8 deletions tests/conftest.py
@@ -281,9 +281,7 @@ def huggingface_training_compiler_pytorch_version(
huggingface_training_compiler_version,
):
versions = _huggingface_base_fm_version(
huggingface_training_compiler_version,
"pytorch",
"huggingface_training_compiler",
huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler"
)
if not versions:
pytest.skip(
@@ -298,9 +296,7 @@ def huggingface_training_compiler_tensorflow_version(
huggingface_training_compiler_version,
):
versions = _huggingface_base_fm_version(
huggingface_training_compiler_version,
"tensorflow",
"huggingface_training_compiler",
huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler"
)
if not versions:
pytest.skip(
@@ -516,8 +512,7 @@ def pytorch_ddp_py_version():


@pytest.fixture(
scope="module",
params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"],
scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"]
)
def pytorch_ddp_framework_version(request):
return request.param
57 changes: 57 additions & 0 deletions tests/data/tensorflow_mnist/mnist_mwms.py
@@ -0,0 +1,57 @@
# https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras

import json
import os
import tensorflow as tf
import numpy as np


def mnist_dataset(batch_size):
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# The `x` arrays are in uint8 and have values in the [0, 255] range.
# You need to convert them to float32 with values in the [0, 1] range.
x_train = x_train / np.float32(255)
y_train = y_train.astype(np.int64)
train_dataset = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(60000)
.repeat()
.batch(batch_size)
)
return train_dataset


def build_and_compile_cnn_model():
model = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(28, 28)),
tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
tf.keras.layers.Conv2D(32, 3, activation="relu"),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10),
]
)
model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
metrics=["accuracy"],
)
return model


per_worker_batch_size = 64
tf_config = json.loads(os.environ["TF_CONFIG"])
num_workers = len(tf_config["cluster"]["worker"])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

with strategy.scope():
multi_worker_model = build_and_compile_cnn_model()

multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

print(f"strategy.num_replicas_in_sync={strategy.num_replicas_in_sync}")
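
The new script derives its worker count from the standard TensorFlow `TF_CONFIG` variable, which the training platform injects for each worker. A sketch of the shape the script expects for a two-worker job, with hostnames and port chosen for illustration:

    import json
    import os

    # Worker 0 of a two-worker cluster; SageMaker sets the real value, so this
    # is only useful for experimenting with the script above locally.
    os.environ["TF_CONFIG"] = json.dumps(
        {
            "cluster": {"worker": ["algo-1:2222", "algo-2:2222"]},
            "task": {"type": "worker", "index": 0},
        }
    )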
65 changes: 48 additions & 17 deletions tests/integ/test_tf.py
@@ -38,6 +38,7 @@
SCRIPT = "mnist.py"
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
MWMS_DISTRIBUTION = {"multi_worker_mirrored_strategy": {"enabled": True}}
TAGS = [{"Key": "some-key", "Value": "some-value"}]
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}

@@ -68,12 +69,7 @@ def test_framework_processing_job_with_deps(
sagemaker_session=sagemaker_session,
base_job_name="test-tensorflow",
)
processor.run(
code=entry_point,
source_dir=code_path,
inputs=[],
wait=True,
)
processor.run(code=entry_point, source_dir=code_path, inputs=[], wait=True)


def test_mnist_with_checkpoint_config(
Expand Down Expand Up @@ -110,9 +106,7 @@ def test_mnist_with_checkpoint_config(
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(inputs=inputs, job_name=training_job_name)
assert_s3_file_patterns_exist(
sagemaker_session,
estimator.model_dir,
[r"model\.ckpt-\d+\.index", r"checkpoint"],
sagemaker_session, estimator.model_dir, [r"model\.ckpt-\d+\.index", r"checkpoint"]
)
# remove dataframe assertion to unblock PR build
# TODO: add independent integration test for `training_job_analytics`
@@ -130,9 +124,7 @@ def test_mnist_with_checkpoint_config(
]
)

expected_retry_strategy = {
"MaximumRetryAttempts": 2,
}
expected_retry_strategy = {"MaximumRetryAttempts": 2}
actual_retry_strategy = sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=training_job_name
)["RetryStrategy"]
@@ -181,6 +173,48 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
)


@pytest.mark.release
@pytest.mark.skipif(
tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
reason="no ml.p2 or ml.p3 instances in this region",
)
@retry_with_instance_list(gpu_list(tests.integ.test_region()))
def test_mwms_gpu(
sagemaker_session,
tensorflow_training_latest_version,
tensorflow_training_latest_py_version,
capsys,
**kwargs,
):
instance_count = 2
estimator = TensorFlow(
source_dir=os.path.join(RESOURCE_PATH, "tensorflow_mnist"),
entry_point="mnist_mwms.py",
model_dir=False,
instance_type=kwargs["instance_type"],
instance_count=instance_count,
framework_version=tensorflow_training_latest_version,
py_version=tensorflow_training_latest_py_version,
distribution=MWMS_DISTRIBUTION,
environment={"NCCL_DEBUG": "INFO"},
max_run=60 * 60 * 1, # 1 hour
role=ROLE,
volume_size=400,
sagemaker_session=sagemaker_session,
disable_profiler=True,
)

with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(job_name=unique_name_from_base("test-tf-mwms"))

captured = capsys.readouterr()
logs = captured.out + captured.err
print(logs)
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
assert f"strategy.num_replicas_in_sync={instance_count}" in logs


@pytest.mark.release
def test_mnist_distributed_cpu(
sagemaker_session,
@@ -237,9 +271,7 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed"))
assert_s3_file_patterns_exist(
sagemaker_session,
estimator.model_dir,
[r"model\.ckpt-\d+\.index", r"checkpoint"],
sagemaker_session, estimator.model_dir, [r"model\.ckpt-\d+\.index", r"checkpoint"]
)


@@ -346,8 +378,7 @@ def test_model_deploy_with_serverless_inference_config(
sagemaker_session=sagemaker_session,
)
predictor = model.deploy(
serverless_inference_config=ServerlessInferenceConfig(),
endpoint_name=endpoint_name,
serverless_inference_config=ServerlessInferenceConfig(), endpoint_name=endpoint_name
)

input_data = {"instances": [1.0, 2.0, 5.0]}