Skip to content

Commit

Permalink
test: adding integration tests targetting MWMS in TF
Browse files Browse the repository at this point in the history
  • Loading branch information
Lokiiiiii committed Jun 22, 2022
1 parent c77f874 commit 0c15bc7
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,24 @@ def inf_instance_family(inf_instance_type):
return "_".join(inf_instance_type.split(".")[0:2])


@pytest.fixture(scope="session")
def imagenet_train_subset(request, sagemaker_session, tmpdir_factory):
"""
Copies the Imagenet dataset from the bucket it's hosted in to the local bucket in the test region
"""
local_path = tmpdir_factory.mktemp("imagenet_tfrecords_train_subset")
sagemaker_session.download_data(
path=local_path,
bucket="collection-of-ml-datasets",
key_prefix="Imagenet/TFRecords/train_1_of_10",
)
train_input = sagemaker_session.upload_data(
path=local_path,
key_prefix="integ-test-data/imagenet/TFRecords/train",
)
return train_input


def pytest_generate_tests(metafunc):
if "instance_type" in metafunc.fixturenames:
boto_config = metafunc.config.getoption("--boto-config")
Expand Down
74 changes: 74 additions & 0 deletions tests/integ/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SCRIPT = "mnist.py"
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
MWMS_DISTRIBUTION = {"multi_worker_mirrored_strategy": {"enabled": True}}
TAGS = [{"Key": "some-key", "Value": "some-value"}]
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}

Expand Down Expand Up @@ -181,6 +182,79 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v
)


@pytest.mark.release
@pytest.mark.skipif(
tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
reason="no ml.p2 or ml.p3 instances in this region",
)
@retry_with_instance_list(gpu_list(tests.integ.test_region()))
def test_mwms_gpu(
sagemaker_session,
tensorflow_training_latest_version,
tensorflow_training_latest_py_version,
capsys,
imagenet_train_subset,
**kwargs,
):
epochs = 1
global_batch_size = 64
train_steps = int(10**4 * epochs / global_batch_size)
steps_per_loop = train_steps // 10
overrides = (
f"runtime.enable_xla=False,"
f"runtime.num_gpus=1,"
f"runtime.distribution_strategy=multi_worker_mirrored,"
f"runtime.mixed_precision_dtype=float16,"
f"task.train_data.global_batch_size={global_batch_size},"
f"task.train_data.input_path=/opt/ml/input/data/training/train-000*,"
f"task.train_data.cache=True,"
f"trainer.train_steps={train_steps},"
f"trainer.steps_per_loop={steps_per_loop},"
f"trainer.summary_interval={steps_per_loop},"
f"trainer.checkpoint_interval={train_steps},"
f"task.model.backbone.type=resnet,"
f"task.model.backbone.resnet.model_id=50"
)
estimator = TensorFlow(
git_config={
"repo": "https://github.com/tensorflow/models.git",
"branch": "v2.9.2",
},
source_dir=".",
entry_point="official/vision/train.py",
model_dir=False,
instance_type=kwargs["instance_type"],
instance_count=2,
framework_version=tensorflow_training_latest_version,
py_version=tensorflow_training_latest_py_version,
distribution=MWMS_DISTRIBUTION,
hyperparameters={
"experiment": "resnet_imagenet",
"config_file": "official/vision/configs/experiments/image_classification/imagenet_resnet50_gpu.yaml",
"mode": "train",
"model_dir": "/opt/ml/model",
"params_override": overrides,
},
environment={
"NCCL_DEBUG": "INFO",
},
max_run=60 * 60 * 1, # 1 hour
role=ROLE,
volume_size=400,
sagemaker_session=sagemaker_session,
disable_profiler=True,
)

with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
estimator.fit(inputs=imagenet_train_subset, job_name=unique_name_from_base("test-tf-mwms"))

captured = capsys.readouterr()
logs = captured.out + captured.err
assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs
raise NotImplementedError("Check model saving")


@pytest.mark.release
def test_mnist_distributed_cpu(
sagemaker_session,
Expand Down

0 comments on commit 0c15bc7

Please sign in to comment.