Introduce TPU pod launcher #815

Closed
wants to merge 29 commits into main from tpu-pod-launch
Changes from all commits (29 commits):
95e41a0
TPU pod launcher
muellerzr Nov 2, 2022
b6d3f79
Should be working now, just need final steps
muellerzr Nov 2, 2022
1c17bdc
Filter args
muellerzr Nov 2, 2022
adb2309
Remove redundancy
muellerzr Nov 2, 2022
9552778
Working WIP!
muellerzr Nov 7, 2022
962b771
Merge branch 'main' into tpu-pod-launch
muellerzr Nov 7, 2022
c694064
Fix arg
muellerzr Nov 7, 2022
f46396b
Merge branch 'tpu-pod-launch' of https://github.com/huggingface/accel…
muellerzr Nov 7, 2022
2d4b837
rm print
muellerzr Nov 7, 2022
b07d053
Try with no_tpu_cluster
muellerzr Nov 7, 2022
8164087
Switch to python3, use different branch
muellerzr Nov 7, 2022
ef36125
Try with just this
muellerzr Nov 8, 2022
c98c71d
With python
muellerzr Nov 8, 2022
0f0567d
It's working!
muellerzr Nov 15, 2022
5fdd72f
Merge branch 'main' into tpu-pod-launch
muellerzr Nov 15, 2022
e8b694b
Fixed up CLI, needs a change before final merge and ci redo is in
muellerzr Nov 15, 2022
f03fbd3
Merge branch 'main' into tpu-pod-launch
muellerzr Nov 15, 2022
60665b5
Merge with main
muellerzr Nov 15, 2022
f3ace09
Better doc
muellerzr Nov 15, 2022
9623e06
machine_type -> device_type
muellerzr Nov 17, 2022
7d3066f
Enable bf16 on TPUs through config
muellerzr Nov 17, 2022
f5eb40c
Rm XRT_TPU_CONFIG for now
muellerzr Dec 14, 2022
ccb5291
New version
muellerzr Dec 19, 2022
d8add9a
Merge branch 'tpu-pod-launch' of https://github.com/huggingface/accel…
muellerzr Dec 19, 2022
6bf5cb8
Add training function
muellerzr Dec 19, 2022
679f636
With sudo option
muellerzr Dec 19, 2022
d423b7b
Push fix
muellerzr Dec 19, 2022
7effca3
Update with alpha option
muellerzr Feb 7, 2023
8a0b5c4
Add use_sudo
muellerzr Feb 8, 2023
136 changes: 84 additions & 52 deletions src/accelerate/commands/config/cluster.py
@@ -371,6 +371,74 @@ def get_cluster_input():
"What is the name of the function in your script that should be launched in all parallel scripts? [main]: ",
default="main",
)

else:
main_training_function = "main"

if distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_GPU]:
device_type = str(distributed_type).split(".")[1].replace("MULTI_", "") + "(s)"
num_processes = _ask_field(
f"How many {device_type} should be used for distributed training? [1]:",
lambda x: int(x),
default=1,
error_message="Please enter an integer.",
)
elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
num_processes = _ask_field(
"How many GPU(s) should be used for distributed training? [1]:",
lambda x: int(x),
default=1,
error_message="Please enter an integer.",
)
elif distributed_type == DistributedType.TPU:
num_processes = _ask_field(
"How many TPU core(s) should be used for distributed training (if using pods, on each pod)? [8]:",
lambda x: int(x),
default=8,
error_message="Please enter an integer.",
)
else:
num_processes = 1

if distributed_type in [DistributedType.MULTI_GPU, DistributedType.NO] and not use_cpu:
gpu_ids = _ask_field(
"What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:",
default="all",
)

if distributed_type != DistributedType.TPU:
if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
mixed_precision = "no"
else:
mixed_precision = _ask_options(
"Do you wish to use FP16 or BF16 (mixed precision)?",
["no", "fp16", "bf16"],
_convert_mixed_precision,
)
else:
mixed_precision = _ask_options(
"Do you wish to use BF16 (mixed precision)?",
["no", "bf16"],
_convert_mixed_precision,
)

if use_dynamo and mixed_precision == "no" and not use_cpu:
print(
"Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
)

downcast_bf16 = "no"
tpu_vm = None
tpu_env = []
use_cluster = False
tpu_use_sudo = False

if distributed_type == DistributedType.TPU:
if mixed_precision == "bf16":
downcast_bf16 = _ask_field(
"Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
)

use_cluster = _ask_field(
"Are you using a TPU cluster? [yes/NO]: ",
_convert_yes_no_to_bool,
@@ -426,60 +494,20 @@ def get_cluster_input():
default=False,
error_message="Please enter yes or no.",
)

else:
main_training_function = "main"

if distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_GPU, DistributedType.TPU]:
machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
if machine_type == "TPU":
machine_type += " cores"
else:
machine_type += "(s)"
num_processes = _ask_field(
f"How many {machine_type} should be used for distributed training? [1]:",
lambda x: int(x),
default=1,
error_message="Please enter an integer.",
)
elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
num_processes = _ask_field(
"How many GPU(s) should be used for distributed training? [1]:",
lambda x: int(x),
default=1,
error_message="Please enter an integer.",
)
else:
num_processes = 1

if distributed_type in [DistributedType.MULTI_GPU, DistributedType.NO] and not use_cpu:
gpu_ids = _ask_field(
"What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:",
default="all",
)

if distributed_type != DistributedType.TPU:
if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config:
mixed_precision = "no"
else:
mixed_precision = _ask_options(
"Do you wish to use FP16 or BF16 (mixed precision)?",
["no", "fp16", "bf16"],
_convert_mixed_precision,
tpu_use_sudo = _ask_field(
"To run a python script in your TPU environment should `sudo` be used? [yes/NO]: ",
default=None,
error_message="Please enter yes or no.",
)
else:
mixed_precision = "no"
tpu_vm = _ask_field(
"If not using an instance group, what are the names of the Compute VM instances to be used, seperated by a comma: ",
default="",
).split(",")

if use_dynamo and mixed_precision == "no" and not use_cpu:
print(
"Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts."
)

downcast_bf16 = "no"
if distributed_type == DistributedType.TPU and mixed_precision == "bf16":
downcast_bf16 = _ask_field(
"Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no"
)
tpu_env = _ask_field(
"What environment variables do you wish to set in each pod, seperated by a comma: ",
default="",
).split(",")

return ClusterConfig(
compute_environment=ComputeEnvironment.LOCAL_MACHINE,
@@ -501,6 +529,10 @@ def get_cluster_input():
same_network=same_network,
tpu_name=tpu_name,
tpu_zone=tpu_zone,
tpu_use_sudo=tpu_use_sudo,
tpu_vm=tpu_vm,
tpu_env=tpu_env,
tpu_cluster=use_cluster,
commands=commands,
command_file=command_file,
dynamo_backend=dynamo_backend,
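Since the new VM and environment-variable prompts above return their answers via `.split(",")`, `tpu_vm` and `tpu_env` are stored as plain Python lists. A small, purely illustrative example of that behaviour (the instance names and variables below are made up):

# Illustration only: how the comma-separated answers above end up stored.
tpu_vm = "tpu-vm-0,tpu-vm-1".split(",")            # -> ["tpu-vm-0", "tpu-vm-1"]
tpu_env = "SOME_VAR=1,OTHER_VAR=abc".split(",")    # -> ["SOME_VAR=1", "OTHER_VAR=abc"]
no_answer = "".split(",")                          # -> [""]; the default empty answer yields a one-element list, not []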
4 changes: 4 additions & 0 deletions src/accelerate/commands/config/config_args.py
@@ -161,8 +161,12 @@ class ClusterConfig(BaseConfig):
# args for TPU pods
tpu_name: str = None
tpu_zone: str = None
tpu_cluster: bool = False
tpu_use_sudo: bool = False
command_file: str = None
commands: List[str] = None
tpu_vm: List[str] = None
tpu_env: List[str] = None

def __post_init__(self):
if self.deepspeed_config is None:
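For reference, a hedged sketch of what the pod-related fields on ClusterConfig could hold after running `accelerate config` for a TPU pod; only the field names come from the class above, and every value is a made-up example.

# Hypothetical values for the TPU pod fields defined on ClusterConfig above.
example_tpu_pod_fields = {
    "tpu_name": "my-tpu-pod",                # existing field
    "tpu_zone": "us-central1-a",             # existing field
    "tpu_cluster": True,                     # new: launch across the whole pod
    "tpu_use_sudo": False,                   # new: prefix the remote python command with sudo
    "tpu_vm": ["tpu-vm-0", "tpu-vm-1"],      # new: explicit Compute VM instances
    "tpu_env": ["SOME_VAR=1"],               # new: env vars to set on each pod host
    "command_file": None,
    "commands": None,
}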
106 changes: 95 additions & 11 deletions src/accelerate/commands/launch.py
@@ -44,6 +44,7 @@
is_sagemaker_available,
is_torch_version,
patch_environment,
prepare_tpu,
)
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS
from accelerate.utils.dataclasses import SageMakerDistributedType
@@ -283,7 +284,7 @@ def launch_command_parser(subparsers=None):
help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.",
)

# tpu arguments
# TPU arguments
tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.")
tpu_args.add_argument(
"--main_training_function",
@@ -296,6 +297,37 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.",
)
tpu_args.add_argument(
"--tpu_cluster",
action="store_true",
help="Whether to use a GCP TPU pod for training.",
)
tpu_args.add_argument(
"--no_tpu_cluster",
action="store_false",
dest="tpu_cluster",
help="Should not be passed explicitly, this is for internal use only.",
)
tpu_args.add_argument(
"--tpu_use_sudo",
action="store_true",
help="Whether to use sudo when running the TPU training script.",
)
tpu_args.add_argument(
"--vm",
type=str,
action="append",
help=(
"List of single Compute VM instance names. "
"If not provided we assume usage of instance groups. For TPU pods."
),
)
tpu_args.add_argument(
"--env",
type=str,
action="append",
help="List of environment variables to set on the Compute VM instances. For TPU pods.",
)

# DeepSpeed arguments
deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.")
@@ -654,7 +686,7 @@ def multi_gpu_launcher(args):
raise NotImplementedError("Multi-node training requires pytorch>=1.9.0")

debug = getattr(args, "debug", False)
args = _filter_args(args)
args = _filter_args(args, distrib_run.get_args_parser())
with patch_environment(**current_env):
try:
distrib_run.run(args)
@@ -776,7 +808,7 @@ def deepspeed_launcher(args):
raise NotImplementedError("Multi-node training requires pytorch>=1.9.0")

debug = getattr(args, "debug", False)
args = _filter_args(args)
args = _filter_args(args, distrib_run.get_args_parser())
with patch_environment(**current_env):
try:
distrib_run.run(args)
@@ -795,13 +827,7 @@ def tpu_launcher(args):
if args.no_python:
raise ValueError("--no_python cannot be used with TPU launcher")

if args.mixed_precision == "bf16":
if args.downcast_bf16:
current_env["XLA_USE_BF16"] = "0"
current_env["XLA_DOWNCAST_BF16"] = "1"
else:
current_env["XLA_USE_BF16"] = "1"
current_env["XLA_DOWNCAST_BF16"] = "0"
args, current_env = prepare_tpu(args, current_env)

if args.module:
mod_name = args.training_script
@@ -826,6 +852,59 @@ def tpu_launcher(args):
xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)


def tpu_pod_launcher(args):
from torch_xla.distributed import xla_dist

current_env = {}
args, current_env = prepare_tpu(args, current_env, True)
debug = getattr(args, "debug", False)

training_script = args.training_script
training_script_args = args.training_script_args
new_args = _filter_args(
args, xla_dist.get_args_parser(), ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"]
)

if args.tpu_use_sudo:
new_cmd = ["sudo"]
else:
new_cmd = []

new_cmd += [
"accelerate-launch",
"--tpu",
"--no_tpu_cluster",
"--num_processes",
str(args.num_processes),
"--main_training_function",
str(args.main_training_function),
training_script,
] + training_script_args

new_args.positional = new_cmd
bad_flags = ""
for arg in vars(new_args):
if arg.startswith("docker_"):
value = getattr(new_args, arg)
if value != "" and value is not None:
bad_flags += f'{arg}="{value}"\n'
if bad_flags != "":
raise ValueError(
f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}"
)
new_args.env = [f"{k}={v}" for k, v in current_env.items()]
new_args.env.append("ACCELERATE_IN_TPU_POD=1")
try:
xla_dist.resolve_and_execute(new_args)
except:
if is_rich_available() and debug:
console = get_console()
console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]")
console.print_exception(suppress=[__file__], show_locals=False)
else:
raise


def _convert_nargs_to_dict(nargs: List[str]) -> Dict[str, str]:
if len(nargs) < 0:
return {}
@@ -1001,6 +1080,7 @@ def launch_command(args):
if (
not args.multi_gpu
and not args.tpu
and not args.tpu_cluster
and not args.mps
and not args.use_deepspeed
and not args.use_fsdp
@@ -1009,6 +1089,7 @@ def launch_command(args):
args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
args.multi_gpu = defaults.distributed_type == DistributedType.MULTI_GPU
args.tpu = defaults.distributed_type == DistributedType.TPU
args.tpu_cluster = defaults.tpu_cluster and args.tpu
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
args.mps = defaults.distributed_type == DistributedType.MPS
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
@@ -1097,7 +1178,10 @@ def launch_command(args):
elif args.multi_gpu and not args.cpu:
multi_gpu_launcher(args)
elif args.tpu and not args.cpu:
tpu_launcher(args)
if args.tpu_cluster:
tpu_pod_launcher(args)
else:
tpu_launcher(args)
elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
sagemaker_launcher(defaults, args)
else:
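To make the wrapping done by `tpu_pod_launcher` above concrete, this is roughly what `new_cmd` ends up containing for a hypothetical run with `tpu_use_sudo=True`, 8 processes, a `main` training function, and a script `train.py`; all of these values are illustrative, not taken from the PR.

# Illustrative only: the command tpu_pod_launcher passes to torch_xla's xla_dist
# for a hypothetical configuration (sudo enabled, 8 processes, script train.py).
new_cmd = [
    "sudo",
    "accelerate-launch",
    "--tpu",
    "--no_tpu_cluster",
    "--num_processes", "8",
    "--main_training_function", "main",
    "train.py",
    "--learning_rate", "1e-3",   # any extra training-script args are appended as-is
]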
5 changes: 5 additions & 0 deletions src/accelerate/commands/tpu.py
@@ -51,6 +51,11 @@ def tpu_command_parser(subparsers=None):
help="The zone of the TPU to use. If not specified, will use the zone specified in the config file.",
)
pod_args = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.")
pod_args.add_argument(
"--use_alpha",
action="store_true",
help="Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.",
)
pod_args.add_argument(
"--command_file",
default=None,
2 changes: 1 addition & 1 deletion src/accelerate/utils/__init__.py
@@ -98,7 +98,7 @@
HfDeepSpeedConfig,
)

from .launch import PrepareForLaunch, _filter_args, get_launch_prefix
from .launch import PrepareForLaunch, _filter_args, get_launch_prefix, prepare_tpu
from .megatron_lm import (
AbstractTrainStep,
BertTrainStep,
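`prepare_tpu` is imported here and used by both TPU launchers in launch.py. A minimal sketch of what it is assumed to do, based on the inline bf16 logic it replaces inside `tpu_launcher` (the real helper in `accelerate.utils.launch` may do more, e.g. pod-specific setup when its third argument is True):

# Sketch only, not the actual implementation in accelerate.utils.launch.
def prepare_tpu_sketch(args, current_env, pod=False):
    """Translate the bf16 options into the XLA environment variables."""
    if args.mixed_precision == "bf16":
        if args.downcast_bf16:
            # torch.float is cast to bfloat16, torch.double stays float32.
            current_env["XLA_USE_BF16"] = "0"
            current_env["XLA_DOWNCAST_BF16"] = "1"
        else:
            # Both float and double tensors are cast to bfloat16.
            current_env["XLA_USE_BF16"] = "1"
            current_env["XLA_DOWNCAST_BF16"] = "0"
    # When pod=True, additional pod-level environment setup is assumed to happen here.
    return args, current_env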