diff --git a/src/accelerate/commands/config/cluster.py b/src/accelerate/commands/config/cluster.py
index a653b9d6a50..8cb61b84651 100644
--- a/src/accelerate/commands/config/cluster.py
+++ b/src/accelerate/commands/config/cluster.py
@@ -371,6 +371,74 @@ def get_cluster_input():
             "What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
             default="main",
         )
+
+    else:
+        main_training_function = "main"
+
+    if distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_GPU]:
+        device_type = str(distributed_type).split(".")[1].replace("MULTI_", "") + "(s)"
+        num_processes = _ask_field(
+            f"How many {device_type} should be used for distributed training? [1]:",
+            lambda x: int(x),
+            default=1,
+            error_message="Please enter an integer.",
+        )
+    elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
+        num_processes = _ask_field(
+            "How many GPU(s) should be used for distributed training? [1]:",
+            lambda x: int(x),
+            default=1,
+            error_message="Please enter an integer.",
+        )
+    elif distributed_type == DistributedType.TPU:
+        num_processes = _ask_field(
+            "How many TPU core(s) should be used for distributed training (if using pods, on each pod)? [8]:",
+            lambda x: int(x),
+            default=8,
+            error_message="Please enter an integer.",
+        )
+    else:
+        num_processes = 1
+
+    if distributed_type in [DistributedType.MULTI_GPU, DistributedType.NO] and not use_cpu:
+        gpu_ids = _ask_field(
+            "What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:",
+            default="all",
+        )
+
+    if distributed_type != DistributedType.TPU:
+        if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
+            mixed_precision = "no"
+        else:
+            mixed_precision = _ask_options(
+                "Do you wish to use FP16 or BF16 (mixed precision)?",
+                ["no", "fp16", "bf16"],
+                _convert_mixed_precision,
+            )
+    else:
+        mixed_precision = _ask_options(
+            "Do you wish to use BF16 (mixed precision)?",
+            ["no", "bf16"],
+            _convert_mixed_precision,
+        )
+
+    if use_dynamo and mixed_precision == "no" and not use_cpu:
+        print(
+            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
+        )
+
+    downcast_bf16 = "no"
+    tpu_vm = None
+    tpu_env = []
+    use_cluster = False
+    tpu_use_sudo = False
+
+    if distributed_type == DistributedType.TPU:
+        if mixed_precision == "bf16":
+            downcast_bf16 = _ask_field(
+                "Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
+            )
+        use_cluster = _ask_field(
             "Are you using a TPU cluster? [yes/NO]: ",
             _convert_yes_no_to_bool,
             default=False,
             error_message="Please enter yes or no.",
         )
@@ -426,60 +494,20 @@ def get_cluster_input():
                 default=False,
                 error_message="Please enter yes or no.",
             )
-
-    else:
-        main_training_function = "main"
-
-    if distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_GPU, DistributedType.TPU]:
-        machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
-        if machine_type == "TPU":
-            machine_type += " cores"
-        else:
-            machine_type += "(s)"
-        num_processes = _ask_field(
-            f"How many {machine_type} should be used for distributed training? [1]:",
-            lambda x: int(x),
-            default=1,
-            error_message="Please enter an integer.",
-        )
-    elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
-        num_processes = _ask_field(
-            "How many GPU(s) should be used for distributed training? [1]:",
-            lambda x: int(x),
-            default=1,
-            error_message="Please enter an integer.",
-        )
-    else:
-        num_processes = 1
-
-    if distributed_type in [DistributedType.MULTI_GPU, DistributedType.NO] and not use_cpu:
-        gpu_ids = _ask_field(
-            "What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:",
-            default="all",
-        )
-
-    if distributed_type != DistributedType.TPU:
-        if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
-            mixed_precision = "no"
-        else:
-            mixed_precision = _ask_options(
-                "Do you wish to use FP16 or BF16 (mixed precision)?",
-                ["no", "fp16", "bf16"],
-                _convert_mixed_precision,
+            tpu_use_sudo = _ask_field(
+                "To run a python script in your TPU environment should `sudo` be used? [yes/NO]: ",
+                default=None,
+                error_message="Please enter yes or no.",
             )
-    else:
-        mixed_precision = "no"
+            tpu_vm = _ask_field(
+                "If not using an instance group, what are the names of the Compute VM instances to be used, seperated by a comma: ",
+                default="",
+            ).split(",")
-
-    if use_dynamo and mixed_precision == "no" and not use_cpu:
-        print(
-            "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
-        )
-
-    downcast_bf16 = "no"
-    if distributed_type == DistributedType.TPU and mixed_precision == "bf16":
-        downcast_bf16 = _ask_field(
-            "Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
-        )
+            tpu_env = _ask_field(
+                "What environment variables do you wish to set in each pod, seperated by a comma: ",
+                default="",
+            ).split(",")
 
     return ClusterConfig(
         compute_environment=ComputeEnvironment.LOCAL_MACHINE,
@@ -501,6 +529,10 @@ def get_cluster_input():
         same_network=same_network,
         tpu_name=tpu_name,
         tpu_zone=tpu_zone,
+        tpu_use_sudo=tpu_use_sudo,
+        tpu_vm=tpu_vm,
+        tpu_env=tpu_env,
+        tpu_cluster=use_cluster,
         commands=commands,
         command_file=command_file,
         dynamo_backend=dynamo_backend,
diff --git a/src/accelerate/commands/config/config_args.py b/src/accelerate/commands/config/config_args.py
index ba492802e46..d34a54a5d16 100644
--- a/src/accelerate/commands/config/config_args.py
+++ b/src/accelerate/commands/config/config_args.py
@@ -161,8 +161,12 @@ class ClusterConfig(BaseConfig):
     # args for TPU pods
     tpu_name: str = None
     tpu_zone: str = None
+    tpu_cluster: bool = False
+    tpu_use_sudo: bool = False
     command_file: str = None
     commands: List[str] = None
+    tpu_vm: List[str] = None
+    tpu_env: List[str] = None
 
     def __post_init__(self):
         if self.deepspeed_config is None:
diff --git a/src/accelerate/commands/launch.py b/src/accelerate/commands/launch.py
index 8eefb9d7ff3..1df3ffe6cbd 100644
--- a/src/accelerate/commands/launch.py
+++ b/src/accelerate/commands/launch.py
@@ -44,6 +44,7 @@
     is_sagemaker_available,
     is_torch_version,
     patch_environment,
+    prepare_tpu,
 )
 from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
 from accelerate.utils.dataclasses import SageMakerDistributedType
@@ -283,7 +284,7 @@ def launch_command_parser(subparsers=None):
         help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.",
     )
 
-    # tpu arguments
+    # TPU arguments
     tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.")
     tpu_args.add_argument(
         "--main_training_function",
@@ -296,6 +297,37 @@ def launch_command_parser(subparsers=None):
         action="store_true",
         help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.",
     )
+    tpu_args.add_argument(
+        "--tpu_cluster",
+        action="store_true",
+        help="Whether to use a GCP TPU pod for training.",
+    )
+    tpu_args.add_argument(
+        "--no_tpu_cluster",
+        action="store_false",
+        dest="tpu_cluster",
+        help="Should not be passed explicitly, this is for internal use only.",
+    )
+    tpu_args.add_argument(
+        "--tpu_use_sudo",
+        action="store_true",
+        help="Whether to use sudo when running the TPU training script.",
+    )
+    tpu_args.add_argument(
+        "--vm",
+        type=str,
+        action="append",
+        help=(
+            "List of single Compute VM instance names. "
+            "If not provided we assume usage of instance groups. For TPU pods."
+        ),
+    )
+    tpu_args.add_argument(
+        "--env",
+        type=str,
+        action="append",
+        help="List of environment variables to set on the Compute VM instances. For TPU pods.",
+    )
 
     # DeepSpeed arguments
     deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.")
@@ -654,7 +686,7 @@ def multi_gpu_launcher(args):
         raise NotImplementedError("Multi-node training requires pytorch>=1.9.0")
 
     debug = getattr(args, "debug", False)
-    args = _filter_args(args)
+    args = _filter_args(args, distrib_run.get_args_parser())
     with patch_environment(**current_env):
         try:
             distrib_run.run(args)
@@ -776,7 +808,7 @@ def deepspeed_launcher(args):
         raise NotImplementedError("Multi-node training requires pytorch>=1.9.0")
 
     debug = getattr(args, "debug", False)
-    args = _filter_args(args)
+    args = _filter_args(args, distrib_run.get_args_parser())
    with patch_environment(**current_env):
         try:
             distrib_run.run(args)
@@ -795,13 +827,7 @@ def tpu_launcher(args):
     if args.no_python:
         raise ValueError("--no_python cannot be used with TPU launcher")
 
-    if args.mixed_precision == "bf16":
-        if args.downcast_bf16:
-            current_env["XLA_USE_BF16"] = "0"
-            current_env["XLA_DOWNCAST_BF16"] = "1"
-        else:
-            current_env["XLA_USE_BF16"] = "1"
-            current_env["XLA_DOWNCAST_BF16"] = "0"
+    args, current_env = prepare_tpu(args, current_env)
 
     if args.module:
         mod_name = args.training_script
@@ -826,6 +852,59 @@ def tpu_launcher(args):
     xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
 
 
+def tpu_pod_launcher(args):
+    from torch_xla.distributed import xla_dist
+
+    current_env = {}
+    args, current_env = prepare_tpu(args, current_env, True)
+    debug = getattr(args, "debug", False)
+
+    training_script = args.training_script
+    training_script_args = args.training_script_args
+    new_args = _filter_args(
+        args, xla_dist.get_args_parser(), ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"]
+    )
+
+    if args.tpu_use_sudo:
+        new_cmd = ["sudo"]
+    else:
+        new_cmd = []
+
+    new_cmd += [
+        "accelerate-launch",
+        "--tpu",
+        "--no_tpu_cluster",
+        "--num_processes",
+        str(args.num_processes),
+        "--main_training_function",
+        str(args.main_training_function),
+        training_script,
+    ] + training_script_args
+
+    new_args.positional = new_cmd
+    bad_flags = ""
+    for arg in vars(new_args):
+        if arg.startswith("docker_"):
+            value = getattr(new_args, arg)
+            if value != "" and value is not None:
+                bad_flags += f'{arg}="{value}"\n'
+    if bad_flags != "":
+        raise ValueError(
+            f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}"
+        )
+    new_args.env = [f"{k}={v}" for k, v in current_env.items()]
+    new_args.env.append("ACCELERATE_IN_TPU_POD=1")
+    try:
+        xla_dist.resolve_and_execute(new_args)
+    except:
+        if is_rich_available() and debug:
+            console = get_console()
+            console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]")
+            console.print_exception(suppress=[__file__], show_locals=False)
+        else:
+            raise
+
+
 def _convert_nargs_to_dict(nargs: List[str]) -> Dict[str, str]:
     if len(nargs) < 0:
         return {}
@@ -1001,6 +1080,7 @@ def launch_command(args):
         if (
             not args.multi_gpu
             and not args.tpu
+            and not args.tpu_cluster
             and not args.mps
             and not args.use_deepspeed
             and not args.use_fsdp
@@ -1009,6 +1089,7 @@ def launch_command(args):
             args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
             args.multi_gpu = defaults.distributed_type == DistributedType.MULTI_GPU
             args.tpu = defaults.distributed_type == DistributedType.TPU
+            args.tpu_cluster = defaults.tpu_cluster and args.tpu
             args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
             args.mps = defaults.distributed_type == DistributedType.MPS
             args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
@@ -1097,7 +1178,10 @@ def launch_command(args):
     elif args.multi_gpu and not args.cpu:
         multi_gpu_launcher(args)
     elif args.tpu and not args.cpu:
-        tpu_launcher(args)
+        if args.tpu_cluster:
+            tpu_pod_launcher(args)
+        else:
+            tpu_launcher(args)
     elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
         sagemaker_launcher(defaults, args)
     else:
diff --git a/src/accelerate/commands/tpu.py b/src/accelerate/commands/tpu.py
index 6b90770c750..0db53363ec4 100644
--- a/src/accelerate/commands/tpu.py
+++ b/src/accelerate/commands/tpu.py
@@ -51,6 +51,11 @@ def tpu_command_parser(subparsers=None):
         help="The zone of the TPU to use. If not specified, will use the zone specified in the config file.",
     )
     pod_args = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.")
+    pod_args.add_argument(
+        "--use_alpha",
+        action="store_true",
+        help="Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.",
+    )
     pod_args.add_argument(
         "--command_file",
         default=None,
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 87f891021f0..d02ea7abbcd 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -98,7 +98,7 @@
         HfDeepSpeedConfig,
     )
 
-from .launch import PrepareForLaunch, _filter_args, get_launch_prefix
+from .launch import PrepareForLaunch, _filter_args, get_launch_prefix, prepare_tpu
 from .megatron_lm import (
     AbstractTrainStep,
     BertTrainStep,
diff --git a/src/accelerate/utils/launch.py b/src/accelerate/utils/launch.py
index c65df08dcfe..e888404ca6d 100644
--- a/src/accelerate/utils/launch.py
+++ b/src/accelerate/utils/launch.py
@@ -35,14 +35,11 @@ def get_launch_prefix():
     return cmd
 
 
-def _filter_args(args):
+def _filter_args(args, parser, default_args=[]):
     """
     Filters out all `accelerate` specific args
     """
-    if is_torch_version(">=", "1.9.0"):
-        import torch.distributed.run as distrib_run
-    distrib_args = distrib_run.get_args_parser()
-    new_args, _ = distrib_args.parse_known_args()
+    new_args, _ = parser.parse_known_args(default_args)
 
     for key, value in vars(args).items():
         if key in vars(new_args).keys():
@@ -50,6 +47,27 @@ def _filter_args(args):
     return new_args
 
 
+def prepare_tpu(args, current_env, pod=False):
+    """
+    Prepares and returns an environment with the correct TPU environment variables.
+    """
+    current_env["XLA_USE_BF16"] = "0"
+    current_env["XLA_DOWNCAST_BF16"] = "0"
+    if args.mixed_precision == "bf16":
+        if args.downcast_bf16:
+            current_env["XLA_DOWNCAST_BF16"] = "1"
+        else:
+            current_env["XLA_USE_BF16"] = "1"
+    if pod:
+        # Take explicit args and set them up for XLA
+        args.vm = args.tpu_vm
+        args.tpu = args.tpu_name
+    # elif not os.environ.get("ACCELERATE_IN_TPU_POD", "0") == "1":
+    #     # `xla_dist` will take care of this on pods
+    #     current_env["XRT_TPU_CONFIG"] = "localservice;0;localhost:51011"
+    return args, current_env
+
+
 def env_var_path_add(env_var_name, path_to_add):
     """
     Extends a path-based environment variable's value with a new path and returns the updated value. It's up to the
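
For reference, here is a minimal, illustrative sketch (not part of the patch) of what the new `prepare_tpu` helper introduced above is expected to do. It assumes an `accelerate` install that already contains this change; the `Namespace` fields and every value in it are hypothetical, chosen only to mirror the launcher arguments the helper reads.

```python
# Illustrative only: exercises the `prepare_tpu` helper added in the patch above.
# Assumes an accelerate build that includes this change; all field values are hypothetical.
from argparse import Namespace

from accelerate.utils import prepare_tpu

# Mimic the launcher args that `prepare_tpu` reads: mixed_precision / downcast_bf16 drive the
# XLA env vars, while tpu_name / tpu_vm are copied onto args.tpu / args.vm when pod=True.
args = Namespace(
    mixed_precision="bf16",
    downcast_bf16=True,
    tpu_name="my-tpu-pod",            # hypothetical TPU name
    tpu_vm=["tpu-vm-0", "tpu-vm-1"],  # hypothetical Compute VM instance names
)

args, env = prepare_tpu(args, {}, pod=True)
print(env)                # {'XLA_USE_BF16': '0', 'XLA_DOWNCAST_BF16': '1'}
print(args.tpu, args.vm)  # my-tpu-pod ['tpu-vm-0', 'tpu-vm-1']
```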