diff --git a/src/sagemaker/modules/distributed.py b/src/sagemaker/modules/distributed.py index 6cdc136dcf..471aaa6e12 100644 --- a/src/sagemaker/modules/distributed.py +++ b/src/sagemaker/modules/distributed.py @@ -13,9 +13,13 @@ """Distributed module.""" from __future__ import absolute_import +import os + +from abc import ABC, abstractmethod from typing import Optional, Dict, Any, List -from pydantic import BaseModel, PrivateAttr +from pydantic import BaseModel from sagemaker.modules.utils import safe_serialize +from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH class SMP(BaseModel): @@ -72,16 +76,37 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]: return hyperparameters -class DistributedConfig(BaseModel): - """Base class for distributed training configurations.""" +class DistributedConfig(BaseModel, ABC): + """Abstract base class for distributed training configurations. + + This class defines the interface that all distributed training configurations + must implement. It provides a standardized way to specify driver scripts and + their locations for distributed training jobs. + """ + + @property + @abstractmethod + def driver_dir(self) -> str: + """Directory containing the driver script. + + This property should return the path to the directory containing + the driver script, relative to the container's working directory. - _type: str = PrivateAttr() + Returns: + str: Path to directory containing the driver script + """ - def model_dump(self, *args, **kwargs): - """Dump the model to a dictionary.""" - result = super().model_dump(*args, **kwargs) - result["_type"] = self._type - return result + @property + @abstractmethod + def driver_script(self) -> str: + """Name of the driver script. + + This property should return the name of the Python script that implements + the distributed training driver logic. + + Returns: + str: Name of the driver script file + """ class Torchrun(DistributedConfig): @@ -98,11 +123,27 @@ class Torchrun(DistributedConfig): The SageMaker Model Parallelism v2 parameters. """ - _type: str = PrivateAttr(default="torchrun") - process_count_per_node: Optional[int] = None smp: Optional["SMP"] = None + @property + def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ + return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers") + + @property + def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script file + """ + return "torchrun_driver.py" + class MPI(DistributedConfig): """MPI. @@ -118,7 +159,23 @@ class MPI(DistributedConfig): The custom MPI options to use for the training job. """ - _type: str = PrivateAttr(default="mpi") - process_count_per_node: Optional[int] = None mpi_additional_options: Optional[List[str]] = None + + @property + def driver_dir(self) -> str: + """Directory containing the driver script. + + Returns: + str: Path to directory containing the driver script + """ + return os.path.join(SM_DRIVERS_LOCAL_PATH, "drivers") + + @property + def driver_script(self) -> str: + """Name of the driver script. + + Returns: + str: Name of the driver script + """ + return "mpi_driver.py" diff --git a/src/sagemaker/modules/templates.py b/src/sagemaker/modules/templates.py index fba60dda47..9dfef646ed 100644 --- a/src/sagemaker/modules/templates.py +++ b/src/sagemaker/modules/templates.py @@ -21,17 +21,12 @@ EXECUTE_BASIC_SCRIPT_DRIVER = """ echo "Running Basic Script driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/basic_script_driver.py +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/basic_script_driver.py """ -EXEUCTE_TORCHRUN_DRIVER = """ -echo "Running Torchrun driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/torchrun_driver.py -""" - -EXECUTE_MPI_DRIVER = """ -echo "Running MPI driver" -$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/mpi_driver.py +EXEUCTE_DISTRIBUTED_DRIVER = """ +echo "Running {driver_name} Driver" +$SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/drivers/{driver_script} """ TRAIN_SCRIPT_TEMPLATE = """ diff --git a/src/sagemaker/modules/train/container_drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/__init__.py index 18557a2eb5..864f3663b8 100644 --- a/src/sagemaker/modules/train/container_drivers/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/__init__.py @@ -10,5 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Sagemaker modules container_drivers directory.""" +"""Sagemaker modules container drivers directory.""" from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/common/__init__.py b/src/sagemaker/modules/train/container_drivers/common/__init__.py new file mode 100644 index 0000000000..aab88c6b97 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/common/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - common directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/utils.py b/src/sagemaker/modules/train/container_drivers/common/utils.py similarity index 98% rename from src/sagemaker/modules/train/container_drivers/utils.py rename to src/sagemaker/modules/train/container_drivers/common/utils.py index e939a6e0b8..c07aa1359a 100644 --- a/src/sagemaker/modules/train/container_drivers/utils.py +++ b/src/sagemaker/modules/train/container_drivers/common/utils.py @@ -99,10 +99,10 @@ def read_hyperparameters_json(hyperparameters_json: Dict[str, Any] = HYPERPARAME return hyperparameters_dict -def get_process_count(distributed_dict: Dict[str, Any]) -> int: +def get_process_count(process_count: Optional[int] = None) -> int: """Get the number of processes to run on each node in the training job.""" return ( - int(distributed_dict.get("process_count_per_node", 0)) + process_count or int(os.environ.get("SM_NUM_GPUS", 0)) or int(os.environ.get("SM_NUM_NEURONS", 0)) or 1 diff --git a/src/sagemaker/modules/train/container_drivers/drivers/__init__.py b/src/sagemaker/modules/train/container_drivers/drivers/__init__.py new file mode 100644 index 0000000000..a44e7e81a9 --- /dev/null +++ b/src/sagemaker/modules/train/container_drivers/drivers/__init__.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Sagemaker modules container drivers - drivers directory.""" +from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py similarity index 88% rename from src/sagemaker/modules/train/container_drivers/basic_script_driver.py rename to src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py index cb0278bc9f..0b086a8e4f 100644 --- a/src/sagemaker/modules/train/container_drivers/basic_script_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/basic_script_driver.py @@ -13,16 +13,19 @@ """This module is the entry point for the Basic Script Driver.""" from __future__ import absolute_import +import os import sys +import json import shlex +from pathlib import Path from typing import List -from utils import ( +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, get_python_executable, - read_source_code_json, - read_hyperparameters_json, execute_commands, write_failure_file, hyperparameters_to_cli_args, @@ -31,11 +34,10 @@ def create_commands() -> List[str]: """Create the commands to execute.""" - source_code = read_source_code_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + hyperparameters = json.loads(os.environ["SM_HPS"]) python_executable = get_python_executable() - entry_script = source_code["entry_script"] args = hyperparameters_to_cli_args(hyperparameters) if entry_script.endswith(".py"): commands = [python_executable, entry_script] diff --git a/src/sagemaker/modules/train/container_drivers/mpi_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py similarity index 83% rename from src/sagemaker/modules/train/container_drivers/mpi_driver.py rename to src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py index dceb748cc0..9946272617 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/mpi_driver.py @@ -16,18 +16,8 @@ import os import sys import json +from pathlib import Path -from utils import ( - logger, - read_source_code_json, - read_distributed_json, - read_hyperparameters_json, - hyperparameters_to_cli_args, - get_process_count, - execute_commands, - write_failure_file, - USER_CODE_PATH, -) from mpi_utils import ( start_sshd_daemon, bootstrap_master_node, @@ -38,6 +28,16 @@ ) +sys.path.insert(0, str(Path(__file__).parent.parent)) +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + logger, + hyperparameters_to_cli_args, + get_process_count, + execute_commands, + write_failure_file, +) + + def main(): """Main function for the MPI driver script. @@ -58,9 +58,9 @@ def main(): 5. Exit """ - source_code = read_source_code_json() - distribution = read_distributed_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) sm_current_host = os.environ["SM_CURRENT_HOST"] sm_hosts = json.loads(os.environ["SM_HOSTS"]) @@ -77,7 +77,8 @@ def main(): host_list = json.loads(os.environ["SM_HOSTS"]) host_count = int(os.environ["SM_HOST_COUNT"]) - process_count = get_process_count(distribution) + process_count = int(distributed_config["process_count_per_node"] or 0) + process_count = get_process_count(process_count) if process_count > 1: host_list = ["{}:{}".format(host, process_count) for host in host_list] @@ -86,8 +87,8 @@ def main(): host_count=host_count, host_list=host_list, num_processes=process_count, - additional_options=distribution.get("mpi_additional_options", []), - entry_script_path=os.path.join(USER_CODE_PATH, source_code["entry_script"]), + additional_options=distributed_config["mpi_additional_options"] or [], + entry_script_path=entry_script, ) args = hyperparameters_to_cli_args(hyperparameters) diff --git a/src/sagemaker/modules/train/container_drivers/mpi_utils.py b/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py similarity index 97% rename from src/sagemaker/modules/train/container_drivers/mpi_utils.py rename to src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py index 00ddc815cd..ec9e1fcef9 100644 --- a/src/sagemaker/modules/train/container_drivers/mpi_utils.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/mpi_utils.py @@ -14,12 +14,23 @@ from __future__ import absolute_import import os +import sys import subprocess import time + +from pathlib import Path from typing import List import paramiko -from utils import SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, get_python_executable, logger + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + SM_EFA_NCCL_INSTANCES, + SM_EFA_RDMA_INSTANCES, + get_python_executable, + logger, +) FINISHED_STATUS_FILE = "/tmp/done.algo-1" READY_FILE = "/tmp/ready.%s" diff --git a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py b/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py similarity index 87% rename from src/sagemaker/modules/train/container_drivers/torchrun_driver.py rename to src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py index 666479ec84..7fcfabe05d 100644 --- a/src/sagemaker/modules/train/container_drivers/torchrun_driver.py +++ b/src/sagemaker/modules/train/container_drivers/drivers/torchrun_driver.py @@ -15,20 +15,20 @@ import os import sys +import json +from pathlib import Path from typing import List, Tuple -from utils import ( +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 logger, - read_source_code_json, - read_distributed_json, - read_hyperparameters_json, hyperparameters_to_cli_args, get_process_count, get_python_executable, execute_commands, write_failure_file, - USER_CODE_PATH, SM_EFA_NCCL_INSTANCES, SM_EFA_RDMA_INSTANCES, ) @@ -65,11 +65,12 @@ def setup_env(): def create_commands(): """Create the Torch Distributed command to execute""" - source_code = read_source_code_json() - distribution = read_distributed_json() - hyperparameters = read_hyperparameters_json() + entry_script = os.environ["SM_ENTRY_SCRIPT"] + distributed_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + hyperparameters = json.loads(os.environ["SM_HPS"]) - process_count = get_process_count(distribution) + process_count = int(distributed_config["process_count_per_node"] or 0) + process_count = get_process_count(process_count) host_count = int(os.environ["SM_HOST_COUNT"]) torch_cmd = [] @@ -94,7 +95,7 @@ def create_commands(): ] ) - torch_cmd.extend([os.path.join(USER_CODE_PATH, source_code["entry_script"])]) + torch_cmd.extend([entry_script]) args = hyperparameters_to_cli_args(hyperparameters) torch_cmd += args diff --git a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py index 1abbce4067..f04c5b17a0 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/__init__.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/__init__.py @@ -10,5 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Sagemaker modules scripts directory.""" +"""Sagemaker modules container drivers - scripts directory.""" from __future__ import absolute_import diff --git a/src/sagemaker/modules/train/container_drivers/scripts/environment.py b/src/sagemaker/modules/train/container_drivers/scripts/environment.py index ea6abac425..0ce24c55d8 100644 --- a/src/sagemaker/modules/train/container_drivers/scripts/environment.py +++ b/src/sagemaker/modules/train/container_drivers/scripts/environment.py @@ -19,12 +19,17 @@ import json import os import sys +from pathlib import Path import logging -parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -sys.path.insert(0, parent_dir) +sys.path.insert(0, str(Path(__file__).parent.parent)) -from utils import safe_serialize, safe_deserialize # noqa: E402 # pylint: disable=C0413 +from common.utils import ( # noqa: E402 # pylint: disable=C0413,E0611 + safe_serialize, + safe_deserialize, + read_distributed_json, + read_source_code_json, +) # Initialize logger SM_LOG_LEVEL = os.environ.get("SM_LOG_LEVEL", 20) @@ -42,6 +47,8 @@ SM_OUTPUT_DIR = "/opt/ml/output" SM_OUTPUT_FAILURE = "/opt/ml/output/failure" SM_OUTPUT_DATA_DIR = "/opt/ml/output/data" +SM_SOURCE_DIR_PATH = "/opt/ml/input/data/code" +SM_DRIVER_DIR_PATH = "/opt/ml/input/data/sm_drivers/drivers" SM_MASTER_ADDR = "algo-1" SM_MASTER_PORT = 7777 @@ -158,6 +165,17 @@ def set_env( "SM_MASTER_PORT": SM_MASTER_PORT, } + # SourceCode and DistributedConfig Environment Variables + source_code = read_source_code_json() + if source_code: + env_vars["SM_SOURCE_DIR"] = SM_SOURCE_DIR_PATH + env_vars["SM_ENTRY_SCRIPT"] = source_code.get("entry_script", "") + + distributed = read_distributed_json() + if distributed: + env_vars["SM_DRIVER_DIR"] = SM_DRIVER_DIR_PATH + env_vars["SM_DISTRIBUTED_CONFIG"] = distributed + # Data Channels channels = list(input_data_config.keys()) for channel in channels: diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 31decfaca9..5ab79aebdf 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -70,7 +70,7 @@ ) from sagemaker.modules.local_core.local_container import _LocalContainer -from sagemaker.modules.distributed import Torchrun, MPI, DistributedConfig +from sagemaker.modules.distributed import Torchrun, DistributedConfig from sagemaker.modules.utils import ( _get_repo_name_from_image, _get_unique_name, @@ -94,8 +94,7 @@ from sagemaker.modules.templates import ( TRAIN_SCRIPT_TEMPLATE, EXECUTE_BASE_COMMANDS, - EXECUTE_MPI_DRIVER, - EXEUCTE_TORCHRUN_DRIVER, + EXEUCTE_DISTRIBUTED_DRIVER, EXECUTE_BASIC_SCRIPT_DRIVER, ) from sagemaker.telemetry.telemetry_logging import _telemetry_emitter @@ -153,7 +152,7 @@ class ModelTrainer(BaseModel): source_code (Optional[SourceCode]): The source code configuration. This is used to configure the source code for running the training job. - distributed (Optional[Union[MPI, Torchrun]]): + distributed (Optional[Union[DistributedConfig]]): The distributed runner for the training job. This is used to configure a distributed training job. If specifed, ``source_code`` must also be provided. @@ -212,7 +211,7 @@ class ModelTrainer(BaseModel): role: Optional[str] = None base_job_name: Optional[str] = None source_code: Optional[SourceCode] = None - distributed: Optional[Union[MPI, Torchrun]] = None + distributed: Optional[Union[DistributedConfig]] = None compute: Optional[Compute] = None networking: Optional[Networking] = None stopping_condition: Optional[StoppingCondition] = None @@ -534,12 +533,17 @@ def train( container_arguments = None if self.source_code: if self.training_mode == Mode.LOCAL_CONTAINER: - drivers_dir = TemporaryDirectory( - prefix=os.path.join(self.local_container_root + "/") - ) + tmp_dir = TemporaryDirectory(prefix=os.path.join(self.local_container_root + "/")) else: - drivers_dir = TemporaryDirectory() - shutil.copytree(SM_DRIVERS_LOCAL_PATH, drivers_dir.name, dirs_exist_ok=True) + tmp_dir = TemporaryDirectory() + # Copy everything under container_drivers/ to a temporary directory + shutil.copytree(SM_DRIVERS_LOCAL_PATH, tmp_dir.name, dirs_exist_ok=True) + + # If distributed is provided, overwrite code under /drivers + if self.distributed: + distributed_driver_dir = self.distributed.driver_dir + driver_dir = os.path.join(tmp_dir.name, "drivers") + shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True) # If source code is provided, create a channel for the source code # The source code will be mounted at /opt/ml/input/data/code in the container @@ -552,7 +556,7 @@ def train( input_data_config.append(source_code_channel) self._prepare_train_script( - tmp_dir=drivers_dir, + tmp_dir=tmp_dir, source_code=self.source_code, distributed=self.distributed, ) @@ -561,13 +565,13 @@ def train( mp_parameters = self.distributed.smp._to_mp_hyperparameters() string_hyper_parameters.update(mp_parameters) - self._write_source_code_json(tmp_dir=drivers_dir, source_code=self.source_code) - self._write_distributed_json(tmp_dir=drivers_dir, distributed=self.distributed) + self._write_source_code_json(tmp_dir=tmp_dir, source_code=self.source_code) + self._write_distributed_json(tmp_dir=tmp_dir, distributed=self.distributed) # Create an input channel for drivers packaged by the sdk sm_drivers_channel = self.create_input_data_channel( channel_name=SM_DRIVERS, - data_source=drivers_dir.name, + data_source=tmp_dir.name, key_prefix=input_data_key_prefix, ) input_data_config.append(sm_drivers_channel) @@ -769,7 +773,7 @@ def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: Sour """Write the source code configuration to a JSON file.""" file_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON) with open(file_path, "w") as f: - dump = source_code.model_dump(exclude_none=True) if source_code else {} + dump = source_code.model_dump() if source_code else {} f.write(json.dumps(dump)) def _write_distributed_json( @@ -780,7 +784,7 @@ def _write_distributed_json( """Write the distributed runner configuration to a JSON file.""" file_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON) with open(file_path, "w") as f: - dump = distributed.model_dump(exclude_none=True) if distributed else {} + dump = distributed.model_dump() if distributed else {} f.write(json.dumps(dump)) def _prepare_train_script( @@ -817,13 +821,10 @@ def _prepare_train_script( if base_command: execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command) elif distributed: - distribution_type = distributed._type - if distribution_type == "mpi": - execute_driver = EXECUTE_MPI_DRIVER - elif distribution_type == "torchrun": - execute_driver = EXEUCTE_TORCHRUN_DRIVER - else: - raise ValueError(f"Unsupported distribution type: {distribution_type}.") + execute_driver = EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name=distributed.__class__.__name__, + driver_script=distributed.driver_script, + ) elif source_code.entry_script and not source_code.command and not distributed: if not source_code.entry_script.endswith((".py", ".sh")): raise ValueError( diff --git a/tests/data/modules/custom_drivers/driver.py b/tests/data/modules/custom_drivers/driver.py new file mode 100644 index 0000000000..e2a1fc7a52 --- /dev/null +++ b/tests/data/modules/custom_drivers/driver.py @@ -0,0 +1,34 @@ +import json +import os +import subprocess +import sys + + +def main(): + driver_config = json.loads(os.environ["SM_DISTRIBUTED_CONFIG"]) + process_count_per_node = driver_config["process_count_per_node"] + assert process_count_per_node != None + + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + assert isinstance(hps, dict) + + source_dir = os.environ["SM_SOURCE_DIR"] + assert source_dir == "/opt/ml/input/data/code" + sm_drivers_dir = os.environ["SM_DRIVER_DIR"] + assert sm_drivers_dir == "/opt/ml/input/data/sm_drivers/drivers" + + entry_script = os.environ["SM_ENTRY_SCRIPT"] + assert entry_script != None + + python = sys.executable + + command = [python, entry_script] + print(f"Running command: {command}") + subprocess.run(command, check=True) + + +if __name__ == "__main__": + print("Running custom driver script") + main() + print("Finished running custom driver script") diff --git a/tests/data/modules/scripts/entry_script.py b/tests/data/modules/scripts/entry_script.py new file mode 100644 index 0000000000..3c972bd956 --- /dev/null +++ b/tests/data/modules/scripts/entry_script.py @@ -0,0 +1,19 @@ +import json +import os +import time + + +def main(): + hps = json.loads(os.environ["SM_HPS"]) + assert hps != None + print(f"Hyperparameters: {hps}") + + print("Running pseudo training script") + for epochs in range(hps["epochs"]): + print(f"Epoch: {epochs}") + time.sleep(1) + print("Finished running pseudo training script") + + +if __name__ == "__main__": + main() diff --git a/tests/integ/sagemaker/modules/train/test_model_trainer.py b/tests/integ/sagemaker/modules/train/test_model_trainer.py index cd298402b2..cb5cfb10f1 100644 --- a/tests/integ/sagemaker/modules/train/test_model_trainer.py +++ b/tests/integ/sagemaker/modules/train/test_model_trainer.py @@ -17,7 +17,7 @@ from sagemaker.modules.train import ModelTrainer from sagemaker.modules.configs import SourceCode, Compute -from sagemaker.modules.distributed import MPI, Torchrun +from sagemaker.modules.distributed import MPI, Torchrun, DistributedConfig EXPECTED_HYPERPARAMETERS = { "integer": 1, @@ -106,3 +106,35 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session): ) model_trainer.train() + + +def test_custom_distributed_driver(modules_sagemaker_session): + class CustomDriver(DistributedConfig): + process_count_per_node: int = None + + @property + def driver_dir(self) -> str: + return f"{DATA_DIR}/modules/custom_drivers" + + @property + def driver_script(self) -> str: + return "driver.py" + + source_code = SourceCode( + source_dir=f"{DATA_DIR}/modules/scripts", + entry_script="entry_script.py", + ) + + hyperparameters = {"epochs": 10} + + custom_driver = CustomDriver(process_count_per_node=2) + + model_trainer = ModelTrainer( + sagemaker_session=modules_sagemaker_session, + training_image=DEFAULT_CPU_IMAGE, + hyperparameters=hyperparameters, + source_code=source_code, + distributed=custom_driver, + base_job_name="custom-distributed-driver", + ) + model_trainer.train() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py index 30d6dfdf6c..a3f54ad439 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/scripts/test_enviornment.py @@ -21,12 +21,10 @@ from sagemaker.modules.train.container_drivers.scripts.environment import ( set_env, - log_key_value, log_env_variables, - mask_sensitive_info, HIDDEN_VALUE, ) -from sagemaker.modules.train.container_drivers.utils import safe_serialize, safe_deserialize +from sagemaker.modules.train.container_drivers.common.utils import safe_serialize, safe_deserialize RESOURCE_CONFIG = dict( current_host="algo-1", @@ -75,6 +73,15 @@ }, } +SOURCE_CODE = { + "source_dir": "code", + "entry_script": "train.py", +} + +DISTRIBUTED_CONFIG = { + "process_count_per_node": 2, +} + OUTPUT_FILE = os.path.join(os.path.dirname(__file__), "sm_training.env") # flake8: noqa @@ -89,6 +96,10 @@ export SM_LOG_LEVEL='20' export SM_MASTER_ADDR='algo-1' export SM_MASTER_PORT='7777' +export SM_SOURCE_DIR='/opt/ml/input/data/code' +export SM_ENTRY_SCRIPT='train.py' +export SM_DRIVER_DIR='/opt/ml/input/data/sm_drivers/drivers' +export SM_DISTRIBUTED_CONFIG='{"process_count_per_node": 2}' export SM_CHANNEL_TRAIN='/opt/ml/input/data/train' export SM_CHANNEL_VALIDATION='/opt/ml/input/data/validation' export SM_CHANNELS='["train", "validation"]' @@ -112,6 +123,14 @@ """ +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_source_code_json", + return_value=SOURCE_CODE, +) +@patch( + "sagemaker.modules.train.container_drivers.scripts.environment.read_distributed_json", + return_value=DISTRIBUTED_CONFIG, +) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_cpus", return_value=8) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_gpus", return_value=0) @patch("sagemaker.modules.train.container_drivers.scripts.environment.num_neurons", return_value=0) @@ -124,7 +143,13 @@ side_effect=safe_deserialize, ) def test_set_env( - mock_safe_deserialize, mock_safe_serialize, mock_num_cpus, mock_num_gpus, mock_num_neurons + mock_safe_deserialize, + mock_safe_serialize, + mock_num_neurons, + mock_num_gpus, + mock_num_cpus, + mock_read_distributed_json, + mock_read_source_code_json, ): with patch.dict(os.environ, {"TRAINING_JOB_NAME": "test-job"}): set_env( @@ -137,6 +162,8 @@ def test_set_env( mock_num_cpus.assert_called_once() mock_num_gpus.assert_called_once() mock_num_neurons.assert_called_once() + mock_read_distributed_json.assert_called_once() + mock_read_source_code_json.assert_called_once() with open(OUTPUT_FILE, "r") as f: env_file = f.read().strip() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py index a1a84da1ab..a752360981 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_driver.py @@ -15,13 +15,14 @@ import os import sys +import json from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() sys.modules["mpi_utils"] = MagicMock() -from sagemaker.modules.train.container_drivers import mpi_driver # noqa: E402 +from sagemaker.modules.train.container_drivers.drivers import mpi_driver # noqa: E402 DUMMY_MPI_COMMAND = [ @@ -40,12 +41,7 @@ "script.py", ] -DUMMY_SOURCE_CODE = { - "source_code": "source_code", - "entry_script": "script.py", -} DUMMY_DISTRIBUTED = { - "_type": "mpi", "process_count_per_node": 2, "mpi_additional_options": [ "--verbose", @@ -62,17 +58,18 @@ "SM_HOSTS": '["algo-1", "algo-2"]', "SM_MASTER_ADDR": "algo-1", "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "/opt/ml/input/data/code/script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_env_vars_to_file") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.start_sshd_daemon") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_master_node") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_worker_node") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.hyperparameters_to_cli_args") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_mpirun_command") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.execute_commands") def test_mpi_driver_worker( mock_execute_commands, mock_get_mpirun_command, @@ -81,12 +78,8 @@ def test_mpi_driver_worker( mock_bootstrap_master_node, mock_start_sshd_daemon, mock_write_env_vars_to_file, - mock_read_source_code_json, - mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] - mock_read_source_code_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mpi_driver.main() @@ -106,19 +99,20 @@ def test_mpi_driver_worker( "SM_HOSTS": '["algo-1", "algo-2"]', "SM_MASTER_ADDR": "algo-1", "SM_HOST_COUNT": "2", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_distributed_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.read_source_code_json") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_env_vars_to_file") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.start_sshd_daemon") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_master_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.bootstrap_worker_node") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_process_count") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.hyperparameters_to_cli_args") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.get_mpirun_command") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.execute_commands") -@patch("sagemaker.modules.train.container_drivers.mpi_driver.write_status_file_to_workers") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_env_vars_to_file") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.start_sshd_daemon") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_master_node") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.bootstrap_worker_node") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_process_count") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.hyperparameters_to_cli_args") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.get_mpirun_command") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.execute_commands") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_driver.write_status_file_to_workers") def test_mpi_driver_master( mock_write_status_file_to_workers, mock_execute_commands, @@ -129,12 +123,8 @@ def test_mpi_driver_master( mock_bootstrap_master_node, mock_start_sshd_daemon, mock_write_env_vars_to_file, - mock_read_source_code_config_json, - mock_read_distributed_json, ): mock_hyperparameters_to_cli_args.return_value = [] - mock_read_source_code_config_json.return_value = DUMMY_SOURCE_CODE - mock_read_distributed_json.return_value = DUMMY_DISTRIBUTED mock_get_mpirun_command.return_value = DUMMY_MPI_COMMAND mock_get_process_count.return_value = 2 mock_execute_commands.return_value = (0, "") diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py index 2328b1ace5..6c9f2545f0 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_mpi_utils.py @@ -27,7 +27,7 @@ mock_utils.get_python_executable = Mock(return_value="/usr/bin/python") with patch.dict("sys.modules", {"utils": mock_utils}): - from sagemaker.modules.train.container_drivers.mpi_utils import ( + from sagemaker.modules.train.container_drivers.drivers.mpi_utils import ( CustomHostKeyPolicy, _can_connect, write_status_file_to_workers, @@ -65,7 +65,7 @@ def test_custom_host_key_policy_invalid_hostname(): @patch("paramiko.SSHClient") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") def test_can_connect_success(mock_logger, mock_ssh_client): """Test successful SSH connection.""" mock_client = Mock() @@ -81,7 +81,7 @@ def test_can_connect_success(mock_logger, mock_ssh_client): @patch("paramiko.SSHClient") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") def test_can_connect_failure(mock_logger, mock_ssh_client): """Test SSH connection failure.""" mock_client = Mock() @@ -97,7 +97,7 @@ def test_can_connect_failure(mock_logger, mock_ssh_client): @patch("subprocess.run") -@patch("sagemaker.modules.train.container_drivers.mpi_utils.logger") +@patch("sagemaker.modules.train.container_drivers.drivers.mpi_utils.logger") def test_write_status_file_to_workers_failure(mock_logger, mock_run): """Test failed status file writing to workers with retry timeout.""" mock_run.side_effect = subprocess.CalledProcessError(1, "ssh") diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py index 4cff07a0c0..bfd26001c4 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_torchrun_driver.py @@ -15,38 +15,36 @@ import os import sys +import json from unittest.mock import patch, MagicMock sys.modules["utils"] = MagicMock() -from sagemaker.modules.train.container_drivers import torchrun_driver # noqa: E402 +from sagemaker.modules.train.container_drivers.drivers import torchrun_driver # noqa: E402 -DUMMY_SOURCE_CODE = { - "source_code": "source_code", - "entry_script": "script.py", -} - -DUMMY_distributed = {"_type": "torchrun", "process_count_per_node": 2} +DUMMY_DISTRIBUTED = {"process_count_per_node": 2} @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) def test_get_base_pytorch_command_torchrun(mock_pytorch_version, mock_get_python_executable): assert torchrun_driver.get_base_pytorch_command() == ["torchrun"] @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_python_executable", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_python_executable", return_value="python3", ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(1, 8) + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + return_value=(1, 8), ) def test_get_base_pytorch_command_torch_distributed_launch( mock_pytorch_version, mock_get_python_executable @@ -62,38 +60,29 @@ def test_get_base_pytorch_command_torch_distributed_launch( "SM_CURRENT_INSTANCE_TYPE": "ml.p4d.24xlarge", "SM_NETWORK_INTERFACE_NAME": "eth0", "SM_HOST_COUNT": "1", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH", - "/opt/ml/input/data/code", -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2 + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_process_count", + return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json", - return_value=DUMMY_SOURCE_CODE, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", - return_value=DUMMY_distributed, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_single_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_json, - mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, mock_get_process_count, @@ -102,7 +91,7 @@ def test_create_commands_single_node( "torchrun", "--nnodes=1", "--nproc_per_node=2", - "/opt/ml/input/data/code/script.py", + "script.py", ] command = torchrun_driver.create_commands() @@ -118,38 +107,29 @@ def test_create_commands_single_node( "SM_MASTER_ADDR": "algo-1", "SM_MASTER_PORT": "7777", "SM_CURRENT_HOST_RANK": "0", + "SM_HPS": json.dumps({}), + "SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED), + "SM_ENTRY_SCRIPT": "script.py", }, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.USER_CODE_PATH", - "/opt/ml/input/data/code", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_process_count", + return_value=2, ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_process_count", return_value=2 + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.pytorch_version", + return_value=(2, 0), ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.pytorch_version", return_value=(2, 0) -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.get_base_pytorch_command", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.get_base_pytorch_command", return_value=["torchrun"], ) @patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_source_code_json", - return_value=DUMMY_SOURCE_CODE, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.read_distributed_json", - return_value=DUMMY_distributed, -) -@patch( - "sagemaker.modules.train.container_drivers.torchrun_driver.hyperparameters_to_cli_args", + "sagemaker.modules.train.container_drivers.drivers.torchrun_driver.hyperparameters_to_cli_args", return_value=[], ) def test_create_commands_multi_node( mock_hyperparameters_to_cli_args, - mock_read_distributed_json, - mock_read_source_code_json, mock_get_base_pytorch_command, mock_pytorch_version, mock_get_process_count, @@ -161,7 +141,7 @@ def test_create_commands_multi_node( "--master_addr=algo-1", "--master_port=7777", "--node_rank=0", - "/opt/ml/input/data/code/script.py", + "script.py", ] command = torchrun_driver.create_commands() diff --git a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py index aba97996b0..beff06e8d8 100644 --- a/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py +++ b/tests/unit/sagemaker/modules/train/container_drivers/test_utils.py @@ -12,11 +12,13 @@ # language governing permissions and limitations under the License. """Container Utils Unit Tests.""" from __future__ import absolute_import +import os -from sagemaker.modules.train.container_drivers.utils import ( +from sagemaker.modules.train.container_drivers.common.utils import ( safe_deserialize, safe_serialize, hyperparameters_to_cli_args, + get_process_count, ) SM_HPS = { @@ -119,3 +121,18 @@ def test_safe_serialize_empty_data(): assert safe_serialize("") == "" assert safe_serialize([]) == "[]" assert safe_serialize({}) == "{}" + + +def test_get_process_count(): + assert get_process_count() == 1 + assert get_process_count(2) == 2 + os.environ["SM_NUM_GPUS"] = "4" + assert get_process_count() == 4 + os.environ["SM_NUM_GPUS"] = "0" + os.environ["SM_NUM_NEURONS"] = "8" + assert get_process_count() == 8 + os.environ["SM_NUM_NEURONS"] = "0" + assert get_process_count() == 1 + del os.environ["SM_NUM_GPUS"] + del os.environ["SM_NUM_NEURONS"] + assert get_process_count() == 1 diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 093da20ab8..8dc4d754b7 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -65,7 +65,7 @@ ) from sagemaker.modules.distributed import Torchrun, SMP, MPI from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg -from sagemaker.modules.templates import EXEUCTE_TORCHRUN_DRIVER, EXECUTE_MPI_DRIVER +from sagemaker.modules.templates import EXEUCTE_DISTRIBUTED_DRIVER from tests.unit import DATA_DIR DEFAULT_BASE_NAME = "dummy-image-job" @@ -410,7 +410,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ { "source_code": DEFAULT_SOURCE_CODE, "distributed": Torchrun(), - "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), "expected_hyperparameters": {}, }, { @@ -423,7 +425,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ tensor_parallel_degree=5, ) ), - "expected_template": EXEUCTE_TORCHRUN_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="Torchrun", driver_script="torchrun_driver.py" + ), "expected_hyperparameters": { "mp_parameters": json.dumps( { @@ -440,7 +444,9 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_ "distributed": MPI( custom_mpi_options=["-x", "VAR1", "-x", "VAR2"], ), - "expected_template": EXECUTE_MPI_DRIVER, + "expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format( + driver_name="MPI", driver_script="mpi_driver.py" + ), "expected_hyperparameters": {}, }, ], @@ -497,21 +503,15 @@ def test_train_with_distributed_config( assert os.path.exists(expected_runner_json_path) with open(expected_runner_json_path, "r") as f: runner_json_content = f.read() - assert test_case["distributed"].model_dump(exclude_none=True) == ( - json.loads(runner_json_content) - ) + assert test_case["distributed"].model_dump() == (json.loads(runner_json_content)) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump(exclude_none=True) == ( - json.loads(source_code_json_content) - ) + assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content)) assert os.path.exists(expected_source_code_json_path) with open(expected_source_code_json_path, "r") as f: source_code_json_content = f.read() - assert test_case["source_code"].model_dump(exclude_none=True) == ( - json.loads(source_code_json_content) - ) + assert test_case["source_code"].model_dump() == (json.loads(source_code_json_content)) finally: shutil.rmtree(tmp_dir.name) assert not os.path.exists(tmp_dir.name)