Tensor Parallelism v2 #3335

Merged 20 commits on May 30, 2024
5 changes: 3 additions & 2 deletions composer/callbacks/checkpoint_saver.py
@@ -468,8 +468,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
is_deepspeed,
keep_placeholders=True,
).lstrip('/')
-assert state.sharded_ckpt_prefix_dir is not None
-remote_prefix = state.sharded_ckpt_prefix_dir
+assert state.fsdp_config is not None
+remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir']
+assert remote_prefix is not None
ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp)
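For reference, a standalone sketch of how the fsdp_config-provided prefix now feeds the remote checkpoint name; the config values, base name, and checkpoint filename below are hypothetical placeholders, not the saver's actual constants.

import os
import pathlib

fsdp_config = {'sharded_ckpt_prefix_dir': 'ep{epoch}-ba{batch}'}  # hypothetical config
remote_file_name = 'my-run/checkpoints/latest-rank0.pt'           # hypothetical base name

remote_prefix = fsdp_config['sharded_ckpt_prefix_dir']
ckpt_filename = 'placeholder.distcp'  # stands in for checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
print(remote_file_name)  # my-run/checkpoints/ep{epoch}-ba{batch}/placeholder.distcp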
249 changes: 183 additions & 66 deletions composer/core/state.py

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions composer/distributed/__init__.py
@@ -0,0 +1,25 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Distributed training."""

from composer.distributed.deepspeed import fix_batch_precision_for_deepspeed, parse_deepspeed_config
from composer.distributed.dist_strategy import (
DDPSyncStrategy,
ddp_sync_context,
prepare_ddp_module,
prepare_fsdp_module,
prepare_tp_module,
)
from composer.distributed.mosaic_fsdp import set_fsdp_default

__all__ = [
'fix_batch_precision_for_deepspeed',
'parse_deepspeed_config',
'DDPSyncStrategy',
'ddp_sync_context',
'prepare_ddp_module',
'prepare_fsdp_module',
'prepare_tp_module',
'set_fsdp_default',
]
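A minimal sketch of importing the relocated helpers through the new package path, using only names re-exported in the __all__ above:

from composer.distributed import prepare_fsdp_module, prepare_tp_module, set_fsdp_default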
composer/distributed/deepspeed.py
@@ -13,7 +13,7 @@
from composer.core import Batch, Precision, State
from composer.utils import dist, map_collection

-__all__ = ['_fix_batch_precision_for_deepspeed', '_parse_deepspeed_config']
+__all__ = ['fix_batch_precision_for_deepspeed', 'parse_deepspeed_config']


def _add_batch_config(config: Dict[str, Any], state: State):
@@ -105,7 +105,7 @@ def _add_precision_config(config: Dict[str, Any], state: State):
config['bf16'] = cast(Dict[str, Any], {'enabled': True})


-def _parse_deepspeed_config(
+def parse_deepspeed_config(
config: Dict[str, Any],
state: State,
) -> Dict[str, Any]:
@@ -160,7 +160,7 @@ def _convert_fp32_tensor_to_bf16(tensor: torch.Tensor):
return tensor


-def _fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch:
+def fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch:
"""Ensures that a batch is properly formatted for DeepSpeed precisions, if active.

.. note:: Just because the precision is set to FP16 doesn't mean the entire batch can
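Call sites that previously imported the underscore-prefixed DeepSpeed helpers would switch to the public names, e.g. (sketch):

from composer.distributed.deepspeed import fix_batch_precision_for_deepspeed, parse_deepspeed_config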
composer/distributed/dist_strategy.py
@@ -23,18 +23,17 @@

from composer.core import Precision, State
from composer.devices import Device
-from composer.trainer.meta_safe_apply import meta_safe_apply
-from composer.trainer.mosaic_fsdp import patch_pytorch
-from composer.trainer.mosaic_fsdp_utils import (
+from composer.distributed.meta_safe_apply import meta_safe_apply
+from composer.distributed.mosaic_fsdp import (
BACKWARD_PREFETCH_MAP,
SHARDING_MAP,
-_set_custom_fsdp_module_kwargs,
get_cpu_offload,
get_mixed_precision,
+set_custom_fsdp_module_kwargs,
)
from composer.utils import StringEnum, dist, ensure_tuple

-__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module']
+__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module', 'prepare_tp_module']

log = logging.getLogger(__name__)

@@ -142,35 +141,6 @@ def prepare_ddp_module(module: torch.nn.Module, find_unused_parameters: bool) ->
)


-def set_fsdp_default(fsdp_config: Dict[str, Any]):
-    """Modify fsdp_config to set default values for missing keys."""
-    fsdp_config.setdefault('activation_checkpointing', False)
-    fsdp_config.setdefault('activation_checkpointing_reentrant', True)
-    fsdp_config.setdefault('activation_cpu_offload', False)
-    fsdp_config.setdefault('te_checkpoint_wrapper', False)
-    fsdp_config.setdefault('te_shard_fp8_weight', False)
-    fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST')
-    fsdp_config.setdefault('backward_prefetch_limit', 1)
-    fsdp_config.setdefault('cpu_offload', False)
-    fsdp_config.setdefault('forward_prefetch', False)
-    fsdp_config.setdefault('forward_prefetch_limit', 1)
-    fsdp_config.setdefault('ignored_modules', None)
-    fsdp_config.setdefault('keep_low_precision_grads', False)
-    fsdp_config.setdefault('limit_all_gathers', True)
-    fsdp_config.setdefault('load_monolith_rank0_only', False)
-    fsdp_config.setdefault('load_planner', None)
-    fsdp_config.setdefault('mixed_precision', 'DEFAULT')
-    fsdp_config.setdefault('process_group', None)
-    fsdp_config.setdefault('save_planner', None)
-    fsdp_config.setdefault('sharded_ckpt_prefix_dir', 'ep{epoch}-ba{batch}')
-    fsdp_config.setdefault('sharding_strategy', 'FULL_SHARD')
-    fsdp_config.setdefault('state_dict_type', 'full')
-    fsdp_config.setdefault('sync_module_states', False)
-    fsdp_config.setdefault('use_orig_params', True)
-    fsdp_config.setdefault('verbose', False)
-    return fsdp_config


def _recreate_fsdp_param_groups_from_unwrapped_opt_info(
fsdp_wrapped_named_params: Iterator[Tuple[str, torch.nn.Parameter]],
non_wrapped_param_names_to_group_num: Dict[str, int],
@@ -209,6 +179,22 @@ def _recreate_fsdp_param_groups_from_unwrapped_opt_info(
return [group_num_to_optimizer_info[num] for num in sorted(group_num_to_optimizer_info.keys())]


+def prepare_tp_module(
+    model: torch.nn.Module,
+    tp_config: Dict[str, Any],
+) -> None:
+    """Prepare a module (assumed ComposerModel) for use with tensor parallel."""
+    from torch.distributed.tensor.parallel import parallelize_module
+
+    device_mesh = tp_config['device_mesh']
+    layer_plan = tp_config['layer_plan']
+    parallelize_module(
+        module=model,
+        device_mesh=device_mesh,
+        parallelize_plan=layer_plan,
+    )


def prepare_fsdp_module(
model: torch.nn.Module,
optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]],
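The prepare_tp_module helper added above simply forwards a device_mesh and a layer_plan to torch's parallelize_module. A minimal usage sketch, intended to run under torchrun with 2 ranks; the submodule names, layer sizes, and 2-way mesh are assumptions for illustration, not part of this PR:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

from composer.distributed import prepare_tp_module

model = nn.Sequential()
model.add_module('fc1', nn.Linear(1024, 4096))
model.add_module('fc2', nn.Linear(4096, 1024))

tp_config = {
    'device_mesh': init_device_mesh('cuda', (2,)),  # 2-way tensor parallelism
    'layer_plan': {
        'fc1': ColwiseParallel(),  # shard fc1 along its output features
        'fc2': RowwiseParallel(),  # shard fc2 along its input features
    },
}
prepare_tp_module(model, tp_config)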
@@ -229,10 +215,6 @@ def prepare_fsdp_module(
auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
te_rng_seed(int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234.
"""
-patch_pytorch()
-
-set_fsdp_default(fsdp_config)
-
# Check sync_module_states is True for mixed initialization or HSDP
if fsdp_config['sync_module_states'] == False:
rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
@@ -319,31 +301,26 @@ def sync_hook(*args):
sharding_strategy = SHARDING_MAP[sharding_map_key]

kwargs = {}
-if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'):
-    if 'device_mesh' in fsdp_config:
-        device_mesh_size = len(fsdp_config['device_mesh'])
-        if sharding_strategy in [
-            ShardingStrategy.FULL_SHARD,
-            ShardingStrategy.SHARD_GRAD_OP,
-            ShardingStrategy.NO_SHARD,
-        ] and device_mesh_size != 1:
-            raise ValueError(
-                f'FSDP sharding strategy {sharding_map_key.upper()} requires a device mesh '
-                f'of size 1 but got device mesh size of {device_mesh_size}.',
-            )
-        elif sharding_strategy in [
-            ShardingStrategy.HYBRID_SHARD,
-            ShardingStrategy._HYBRID_SHARD_ZERO2,
-        ] and device_mesh_size != 2:
-            raise ValueError(
-                f'FSDP sharding strategy {sharding_map_key.upper()} requires a device mesh '
-                f'of size 2 but got device mesh size of {device_mesh_size}.',
-            )
-        from torch.distributed._tensor import init_device_mesh
-        kwargs['device_mesh'] = init_device_mesh(
-            'cuda',
-            tuple([int(x) for x in fsdp_config['device_mesh']]),
-        )
+if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0') and 'device_mesh' in fsdp_config:
+    if fsdp_config['process_group'] is not None:
+        warnings.warn(
+            'process_group and device_mesh are set for FSDP, so ignoring device_mesh. Please set process_group to None.',
+        )
+    else:
+        ndim = fsdp_config['device_mesh'].ndim
+        if ndim == 1 and sharding_strategy == ShardingStrategy.HYBRID_SHARD:
+            sharding_strategy = ShardingStrategy.FULL_SHARD
+            warnings.warn('HYBRID_SHARD is not supported with 1D device mesh. Using FULL_SHARD instead.')
+        elif ndim == 1 and sharding_strategy == ShardingStrategy._HYBRID_SHARD_ZERO2:
+            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
+            warnings.warn('_HYBRID_SHARD_ZERO2 is not supported with 1D device mesh. Using SHARD_GRAD_OP instead.')
+        elif ndim == 2 and sharding_strategy == ShardingStrategy.SHARD_GRAD_OP:
+            sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
+            warnings.warn('SHARD_GRAD_OP is not supported with 2D device mesh. Using _HYBRID_SHARD_ZERO2 instead.')
+        elif ndim == 2 and sharding_strategy == ShardingStrategy.FULL_SHARD:
+            sharding_strategy = ShardingStrategy.HYBRID_SHARD
+            warnings.warn('FULL_SHARD is not supported with 2D device mesh. Using HYBRID_SHARD instead.')
+        kwargs['device_mesh'] = fsdp_config['device_mesh']

cpu_offload = get_cpu_offload(cpu_offload=fsdp_config['cpu_offload'])
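A condensed restatement of the remapping above, as a sketch rather than code from this PR: with a 1-D mesh the hybrid strategies fall back to their single-node equivalents, and with a 2-D mesh the single-node strategies are promoted to their hybrid equivalents.

from torch.distributed.fsdp import ShardingStrategy

# mesh ndim -> {configured strategy: strategy actually used}
STRATEGY_REMAP = {
    1: {
        ShardingStrategy.HYBRID_SHARD: ShardingStrategy.FULL_SHARD,
        ShardingStrategy._HYBRID_SHARD_ZERO2: ShardingStrategy.SHARD_GRAD_OP,
    },
    2: {
        ShardingStrategy.SHARD_GRAD_OP: ShardingStrategy._HYBRID_SHARD_ZERO2,
        ShardingStrategy.FULL_SHARD: ShardingStrategy.HYBRID_SHARD,
    },
}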

@@ -382,7 +359,7 @@ def sync_hook(*args):
process_group = None
if fsdp_config['process_group'] is not None:
process_group_dict = {'process_group': fsdp_config['process_group']}
-process_group = _set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group']
+process_group = set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group']
backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config['backward_prefetch'].upper()]
activation_checkpointing = fsdp_config['activation_checkpointing']
activation_cpu_offload = fsdp_config['activation_cpu_offload']
@@ -556,7 +533,7 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
ret = obj.fsdp_wrap_fn(module)
if isinstance(ret, dict):
-ret = _set_custom_fsdp_module_kwargs(ret, process_group_cache)
+ret = set_custom_fsdp_module_kwargs(ret, process_group_cache)
if ret and auto_microbatching:
module.register_forward_hook(sync_hook)
module.register_full_backward_hook(sync_hook)