feat: Make DistributedConfig Extensible #5039

Draft · wants to merge 6 commits into base: master
83 changes: 70 additions & 13 deletions src/sagemaker/modules/distributed.py
@@ -13,9 +13,13 @@
"""Distributed module."""
from __future__ import absolute_import

import os

from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List
from pydantic import BaseModel, PrivateAttr
from pydantic import BaseModel
from sagemaker.modules.utils import safe_serialize
from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH


class SMP(BaseModel):
@@ -72,16 +76,37 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
return hyperparameters


class DistributedConfig(BaseModel):
"""Base class for distributed training configurations."""
class DistributedConfig(BaseModel, ABC):
"""Abstract base class for distributed training configurations.

This class defines the interface that all distributed training configurations
must implement. It provides a standardized way to specify driver scripts and
their locations for distributed training jobs.
"""

@property
@abstractmethod
def driver_dir(self) -> str:
"""Directory containing the driver script.

This property should return the path to the directory containing
the driver script, relative to the container's working directory.

_type: str = PrivateAttr()
Returns:
str: Path to directory containing the driver script
"""

def model_dump(self, *args, **kwargs):
"""Dump the model to a dictionary."""
result = super().model_dump(*args, **kwargs)
result["_type"] = self._type
return result
@property
@abstractmethod
def driver_script(self) -> str:
"""Name of the driver script.

This property should return the name of the Python script that implements
the distributed training driver logic.

Returns:
str: Name of the driver script file
"""


class Torchrun(DistributedConfig):
@@ -98,11 +123,27 @@ class Torchrun(DistributedConfig):
The SageMaker Model Parallelism v2 parameters.
"""

_type: str = PrivateAttr(default="torchrun")

process_count_per_node: Optional[int] = None
smp: Optional["SMP"] = None

@property
def driver_dir(self) -> str:
"""Directory containing the driver script.

Returns:
str: Path to directory containing the driver script
"""
return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers")

@property
def driver_script(self) -> str:
"""Name of the driver script.

Returns:
str: Name of the driver script file
"""
return "torchrun_driver.py"


class MPI(DistributedConfig):
"""MPI.
@@ -118,7 +159,23 @@ class MPI(DistributedConfig):
The custom MPI options to use for the training job.
"""

_type: str = PrivateAttr(default="mpi")

process_count_per_node: Optional[int] = None
mpi_additional_options: Optional[List[str]] = None

@property
def driver_dir(self) -> str:
"""Directory containing the driver script.

Returns:
str: Path to directory containing the driver script
"""
return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers")

@property
def driver_script(self) -> str:
"""Name of the driver script.

Returns:
str: Name of the driver script
"""
return "mpi_driver.py"
13 changes: 4 additions & 9 deletions src/sagemaker/modules/templates.py
@@ -21,17 +21,12 @@

EXECUTE_BASIC_SCRIPT_DRIVER = """
echo "Running Basic Script driver"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/basic_script_driver.py
"""

EXEUCTE_TORCHRUN_DRIVER = """
echo "Running Torchrun driver"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py
"""

EXECUTE_MPI_DRIVER = """
echo "Running MPI driver"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py
EXEUCTE_DISTRIBUTED_DRIVER = """
echo "Running {driver_name} Driver"
$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/{driver_script}
"""

TRAIN_SCRIPT_TEMPLATE = """
@@ -10,5 +10,5 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Sagemaker modules container_drivers directory."""
"""Sagemaker modules container drivers directory."""
from __future__ import absolute_import
@@ -0,0 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Sagemaker modules container drivers - common directory."""
from __future__ import absolute_import
@@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME
return hyperparameters_dict


def get_process_count(distributed_dict: Dict[str, Any]) -> int:
def get_process_count(process_count: Optional[int] = None) -> int:
"""Get the number of processes to run on each node in the training job."""
return (
int(distributed_dict.get("process_count_per_node", 0))
process_count
or int(os.environ.get("SM_NUM_GPUS", 0))
or int(os.environ.get("SM_NUM_NEURONS", 0))
or 1
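`get_process_count` now takes the per-node process count directly instead of the raw distributed dict, keeping the same fallback chain. A quick sketch of the behavior (environment values are illustrative):

```python
import os

from common.utils import get_process_count  # as imported by the drivers inside the container

get_process_count(8)     # explicit value wins -> 8

os.environ["SM_NUM_GPUS"] = "4"
get_process_count(None)  # falls back to GPU count -> 4

os.environ["SM_NUM_GPUS"] = "0"
os.environ["SM_NUM_NEURONS"] = "0"
get_process_count(None)  # no accelerators detected -> 1
```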
@@ -0,0 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Sagemaker modules container drivers - drivers directory."""
from __future__ import absolute_import
@@ -13,16 +13,19 @@
"""This module is the entry point for the Basic Script Driver."""
from __future__ import absolute_import

import os
import sys
import json
import shlex

from pathlib import Path
from typing import List

from utils import (
sys.path.insert(0, str(Path(__file__).parent.parent))

from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
logger,
get_python_executable,
read_source_code_json,
read_hyperparameters_json,
execute_commands,
write_failure_file,
hyperparameters_to_cli_args,
@@ -31,11 +34,10 @@

def create_commands() -> List[str]:
"""Create the commands to execute."""
source_code = read_source_code_json()
hyperparameters = read_hyperparameters_json()
entry_script = os.environ["SM_ENTRY_SCRIPT"]
hyperparameters = json.loads(os.environ["SM_HPS"])
python_executable = get_python_executable()

entry_script = source_code["entry_script"]
args = hyperparameters_to_cli_args(hyperparameters)
if entry_script.endswith(".py"):
commands = [python_executable, entry_script]
@@ -16,18 +16,8 @@
import os
import sys
import json
from pathlib import Path

from utils import (
logger,
read_source_code_json,
read_distributed_json,
read_hyperparameters_json,
hyperparameters_to_cli_args,
get_process_count,
execute_commands,
write_failure_file,
USER_CODE_PATH,
)
from mpi_utils import (
start_sshd_daemon,
bootstrap_master_node,
@@ -38,6 +28,16 @@
)


sys.path.insert(0, str(Path(__file__).parent.parent))
from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
logger,
hyperparameters_to_cli_args,
get_process_count,
execute_commands,
write_failure_file,
)


def main():
"""Main function for the MPI driver script.

@@ -58,9 +58,9 @@ def main():
5. Exit

"""
source_code = read_source_code_json()
distribution = read_distributed_json()
hyperparameters = read_hyperparameters_json()
entry_script = os.environ["SM_ENTRY_SCRIPT"]
distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
hyperparameters = json.loads(os.environ["SM_HPS"])

sm_current_host = os.environ["SM_CURRENT_HOST"]
sm_hosts = json.loads(os.environ["SM_HOSTS"])
@@ -77,7 +77,8 @@

host_list = json.loads(os.environ["SM_HOSTS"])
host_count = int(os.environ["SM_HOST_COUNT"])
process_count = get_process_count(distribution)
process_count = int(distributed_config["process_count_per_node"] or 0)
process_count = get_process_count(process_count)

if process_count > 1:
host_list = ["{}:{}".format(host, process_count) for host in host_list]
@@ -86,8 +87,8 @@
host_count=host_count,
host_list=host_list,
num_processes=process_count,
additional_options=distribution.get("mpi_additional_options", []),
entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]),
additional_options=distributed_config["mpi_additional_options"] or [],
entry_script_path=entry_script,
)

args = hyperparameters_to_cli_args(hyperparameters)
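The MPI driver no longer reads the source-code, distributed, and hyperparameters JSON files through the old helpers; it pulls everything from environment variables set before the driver starts. A sketch of the values it expects (paths and numbers are illustrative):

```python
import json
import os

# Environment prepared for the driver (example values only)
os.environ["SM_ENTRY_SCRIPT"] = "/opt/ml/input/data/code/train.py"  # illustrative path
os.environ["SM_DISTRIBUTED_CONFIG"] = json.dumps(
    {"process_count_per_node": 8, "mpi_additional_options": ["-x", "NCCL_DEBUG"]}
)
os.environ["SM_HPS"] = json.dumps({"epochs": 10, "learning_rate": 0.001})

distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
# None/0 falls through to get_process_count's SM_NUM_GPUS / SM_NUM_NEURONS defaults
process_count = int(distributed_config["process_count_per_node"] or 0)
```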
@@ -14,12 +14,23 @@
from __future__ import absolute_import

import os
import sys
import subprocess
import time

from pathlib import Path
from typing import List

import paramiko
from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger

sys.path.insert(0, str(Path(__file__).parent.parent))

from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
SM_EFA_NCCL_INSTANCES,
SM_EFA_RDMA_INSTANCES,
get_python_executable,
logger,
)

FINISHED_STATUS_FILE = "/tmp/done.algo-1"
READY_FILE = "/tmp/ready.%s"
@@ -15,20 +15,20 @@

import os
import sys
import json

from pathlib import Path
from typing import List, Tuple

from utils import (
sys.path.insert(0, str(Path(__file__).parent.parent))

from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611
logger,
read_source_code_json,
read_distributed_json,
read_hyperparameters_json,
hyperparameters_to_cli_args,
get_process_count,
get_python_executable,
execute_commands,
write_failure_file,
USER_CODE_PATH,
SM_EFA_NCCL_INSTANCES,
SM_EFA_RDMA_INSTANCES,
)
@@ -65,11 +65,12 @@ def setup_env():

def create_commands():
"""Create the Torch Distributed command to execute"""
source_code = read_source_code_json()
distribution = read_distributed_json()
hyperparameters = read_hyperparameters_json()
entry_script = os.environ["SM_ENTRY_SCRIPT"]
distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"])
hyperparameters = json.loads(os.environ["SM_HPS"])

process_count = get_process_count(distribution)
process_count = int(distributed_config["process_count_per_node"] or 0)
process_count = get_process_count(process_count)
host_count = int(os.environ["SM_HOST_COUNT"])

torch_cmd = []
@@ -94,7 +95,7 @@ def create_commands():
]
)

torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])])
torch_cmd.extend([entry_script])

args = hyperparameters_to_cli_args(hyperparameters)
torch_cmd += args
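After `setup_env()` and the rendezvous arguments, the driver appends the entry script (now taken from `SM_ENTRY_SCRIPT`) and the CLI-converted hyperparameters. A purely illustrative single-node command list, assuming typical torchrun flags (the exact flags emitted earlier in this function are not shown in this hunk):

```python
# Illustrative only -- the driver builds torch_cmd as a list and hands it to
# execute_commands(); a single-node invocation could resemble:
torch_cmd = [
    "torchrun",
    "--nnodes", "1",
    "--nproc_per_node", "8",             # from get_process_count()
    "/opt/ml/input/data/code/train.py",  # os.environ["SM_ENTRY_SCRIPT"] (path illustrative)
    "--epochs", "10",                    # from hyperparameters_to_cli_args() (format assumed)
]
```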
@@ -10,5 +10,5 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Sagemaker modules scripts directory."""
"""Sagemaker modules container drivers - scripts directory."""
from __future__ import absolute_import