Tensor Parallelism v2 #3335

Merged 20 commits on May 30, 2024
Changes from 7 commits
5 changes: 3 additions & 2 deletions composer/callbacks/checkpoint_saver.py
@@ -468,8 +468,9 @@ def _save_checkpoint(self, state: State, logger: Logger):
is_deepspeed,
keep_placeholders=True,
).lstrip('/')
assert state.sharded_ckpt_prefix_dir is not None
remote_prefix = state.sharded_ckpt_prefix_dir
assert state.fsdp_config is not None
remote_prefix = state.fsdp_config['sharded_ckpt_prefix_dir']
assert remote_prefix is not None
ckpt_filename = checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
remote_file_name = os.path.join(pathlib.Path(remote_file_name).parent, remote_prefix, ckpt_filename)
remote_file_name = format_name_with_dist_and_time(remote_file_name, state.run_name, state.timestamp)
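For readers following the checkpoint change above: the remote prefix for sharded checkpoints is now read from `fsdp_config['sharded_ckpt_prefix_dir']` rather than a dedicated `state.sharded_ckpt_prefix_dir` attribute. Below is a minimal sketch of how the remote name is assembled, with hypothetical values standing in for `remote_file_name`, the prefix, and `checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME`.

```python
import os
import pathlib

# Hypothetical values for illustration only.
remote_file_name = 'checkpoints/my-run/ep1-ba100-rank0.pt'
remote_prefix = 'ep{epoch}-ba{batch}'                  # fsdp_config['sharded_ckpt_prefix_dir']
ckpt_filename = '<distributed-checkpoint-filename>'    # checkpoint._TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME

remote_file_name = os.path.join(
    pathlib.Path(remote_file_name).parent,  # 'checkpoints/my-run'
    remote_prefix,
    ckpt_filename,
)
# -> 'checkpoints/my-run/ep{epoch}-ba{batch}/<distributed-checkpoint-filename>'
```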
215 changes: 158 additions & 57 deletions composer/core/state.py
@@ -17,6 +17,7 @@
import torch
import torch.nn.modules.utils
from packaging import version
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullOptimStateDictConfig,
@@ -30,8 +31,6 @@
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Metric

from composer.utils.warnings import VersionedDeprecationWarning

if version.parse(torch.__version__) >= version.parse('2.3.0'):
from torch.amp.grad_scaler import GradScaler # type: ignore
else:
@@ -44,6 +43,8 @@
from composer.core.time import Time, Timestamp, TimeUnit, ensure_time
from composer.devices import Device
from composer.utils import (
ParallelismType,
VersionedDeprecationWarning,
batch_get,
batch_set,
dist,
@@ -194,6 +195,79 @@ def _ensure_backwards_compatible_checkpointing(state_dict: Dict[str, Any]):
return state


def _create_device_mesh(
device: Device,
fsdp_config: Optional[Dict[str, Any]],
tp_config: Optional[Dict[str, Any]],
) -> Optional[DeviceMesh]:
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
# Device mesh has correctness issues before torch 2.3.0
return None

if fsdp_config is None:
return None

# Gather dimensions and names for the device mesh
dims: List[int] = []
names: List[str] = []
if fsdp_config['data_parallel_replicate_degree'] is not None:
dims.append(fsdp_config['data_parallel_replicate_degree'])
names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
dims.append(fsdp_config['data_parallel_shard_degree'])
names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
if tp_config is not None:
dims.append(tp_config['tensor_parallel_degree'])
names.append(ParallelismType.TENSOR_PARALLEL.value)

# Fill in the unspecified dimensions
product_of_dims = 1
unspecified_dim_names = []
for dim, name in zip(dims, names):
if dim != -1:
product_of_dims *= dim
else:
unspecified_dim_names.append(name)
if len(unspecified_dim_names) > 1:
raise ValueError(
f'Found multiple parallelism dimensions set to -1: {unspecified_dim_names}. '
'Only one dimension may be -1; it is filled in automatically so the product of degrees matches the world size.',
)
elif len(unspecified_dim_names) == 1:
if product_of_dims > dist.get_world_size():
raise ValueError(
f'The product of the specified parallelism degrees {product_of_dims} exceeds the world size '
f'{dist.get_world_size()}. Please ensure the product of the specified parallelism degrees matches the '
f'world size. Currently specified degrees are {names=}, {dims=}. One dimension can also be left as -1, '
'which will automatically be set so the product matches the world size.',
)
remaining_dimension = dist.get_world_size() // product_of_dims
if remaining_dimension * product_of_dims != dist.get_world_size():
raise ValueError(
f'World size {dist.get_world_size()} is not divisible by the product of the specified '
'parallelism degrees. Please ensure the product of the specified parallelism degrees '
'matches the world size.',
)
for i, dim in enumerate(dims):
if dim == -1:
dims[i] = remaining_dimension
log.info(f'Automatically setting {names[i]} to have parallelization degree {remaining_dimension}.')
break
else:
if product_of_dims != dist.get_world_size():
raise ValueError(
f'World size {dist.get_world_size()} does not equal the product of the specified parallelism degrees '
f'{product_of_dims}. Please ensure the product of the specified parallelism degrees matches the world '
f'size. Currently specified degrees are {names=}, {dims=}. One dimension can also be left as -1, which '
'will automatically be specified to ensure the product matches the world size.',
)

device_type = device.name
if device_type == 'gpu':
device_type = 'cuda'

return init_device_mesh(device_type=device_type, mesh_shape=tuple(dims), mesh_dim_names=tuple(names))


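To make the mesh construction above concrete, here is a small illustrative sketch (not part of the diff) for an 8-GPU run with `data_parallel_shard_degree=-1` and `tensor_parallel_degree=2`: the -1 dimension is filled in as 8 // 2 = 4, and `init_device_mesh` is called with the resulting shape and names. The import uses the public `torch.distributed.device_mesh` path, equivalent to the `torch.distributed._tensor.device_mesh` import in this diff, and the dimension names are written out literally under the assumption that they match the `ParallelismType` enum values.

```python
# Sketch only: assumes torch>=2.3 and an 8-rank job launched with torchrun.
from torch.distributed.device_mesh import init_device_mesh

world_size = 8
tp_degree = 2
dp_shard_degree = world_size // tp_degree  # the -1 dimension resolves to 4

mesh = init_device_mesh(
    device_type='cuda',
    mesh_shape=(dp_shard_degree, tp_degree),                    # (4, 2)
    mesh_dim_names=('data_parallel_shard', 'tensor_parallel'),
)

# State then hands named slices of this mesh to the FSDP and TP configs:
fsdp_mesh = mesh['data_parallel_shard']  # 1-D mesh over the sharding ranks
tp_mesh = mesh['tensor_parallel']        # 1-D mesh over the tensor-parallel ranks
```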
_STATE_DICT_SERIALIZED_ATTRIBUTES = [
# List of attributes that are serialized with state_dict
# Only the attributes listed in state.serialized_attributes will actually be saved.
@@ -255,8 +329,7 @@ class State(Serializable):
algorithms (Algorithm | Sequence[Algorithm], optional): The algorithms used for training.
callbacks (Callback | Sequence[Callback], optional): The callbacks used for training.
deepspeed_config (Dict[str, Any], optional): The configuration dictionary for deepspeed.
fsdp_config (Dict[str, Any], optional): The configuration dictionary for FSDP.
fsdp_auto_wrap (bool, optional): Whether to automatically wrap the model with FSDP.
parallelism_config (Dict[str, Any], optional): The configuration dictionary for parallelism.

Attributes:
batch (types.Batch): The batch. This will be the entire batch during the :attr:`.Event.AFTER_DATALOADER`, or a
@@ -423,8 +496,7 @@ def __init__(

# Distributed training configs
deepspeed_config: Optional[Dict[str, Any]] = None,
fsdp_config: Optional[Dict[str, Any]] = None,
fsdp_auto_wrap: bool = True,
parallelism_config: Optional[Dict[str, Any]] = None,
):
self.rank_zero_seed = rank_zero_seed
self.model = model
@@ -468,20 +540,88 @@ def __init__(
self.profiler: Optional[Profiler] = None

self.deepspeed_config = deepspeed_config
self.fsdp_config = fsdp_config
self.fsdp_auto_wrap = fsdp_auto_wrap
parallelism_config = parallelism_config or {}
self.fsdp_config = parallelism_config.get('fsdp', None)
self.tp_config = parallelism_config.get('tp', None)

self._validate_parallelism_configs()

self.device_mesh: Optional[DeviceMesh] = _create_device_mesh(self.device, self.fsdp_config, self.tp_config)
if self.fsdp_config is not None and self.device_mesh is not None:
fsdp_mesh_dim_names = []
if self.device_mesh.mesh_dim_names is not None and ParallelismType.DATA_PARALLEL_REPLICATE.value in self.device_mesh.mesh_dim_names:
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_REPLICATE.value)
fsdp_mesh_dim_names.append(ParallelismType.DATA_PARALLEL_SHARD.value)
self.fsdp_config['device_mesh'] = self.device_mesh[tuple(fsdp_mesh_dim_names)] # type: ignore
if self.tp_config is not None and self.device_mesh is not None:
self.tp_config['device_mesh'] = self.device_mesh[ParallelismType.TENSOR_PARALLEL.value]

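For orientation, a hedged sketch of the configuration dictionary this constructor now expects: the top-level `fsdp` and `tp` keys and the degree and FSDP option names below are taken from this diff, while the particular values are hypothetical.

```python
# Hypothetical parallelism_config, shaped the way State reads it above.
parallelism_config = {
    'fsdp': {
        'data_parallel_shard_degree': -1,        # filled in automatically from the world size
        'data_parallel_replicate_degree': None,  # no replicate dimension in the mesh
        'use_orig_params': True,                 # required when TP is enabled (validated in _validate_parallelism_configs)
        'state_dict_type': 'sharded',
    },
    'tp': {
        'tensor_parallel_degree': 2,
    },
}

# Passed through the new keyword argument:
# state = State(..., parallelism_config=parallelism_config)
```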
# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]

self.train_metrics: Optional[Dict[str, Metric]] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _validate_parallelism_configs(self):
# Validate TP config
if self.tp_config is not None:
warnings.warn('Tensor parallelism (TP) is experimental and may change in future versions.', FutureWarning)
if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.3.0'):
raise ValueError('Tensor parallelism (TP) requires torch>=2.3.0.')
if self.fsdp_config is None:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP to be enabled. '
'An empty `fsdp_config` can be specified to enable FSDP with '
'default settings. Additionally, PyTorch currently errors if FSDP '
'data_parallel_shard_degree is not at least 2.',
)
if not self.fsdp_config['use_orig_params']:
raise ValueError(
'Tensor parallelism (TP) currently requires FSDP with use_orig_params=True, '
'which is the default and recommended setting.',
)

# Load monolith rank0 only
if self.load_monolith_rank0_only:
assert fsdp_config is not None
if self.tp_config is not None:
raise ValueError('load_fsdp_monolith_rank0_only is not compatible with tensor parallelism (TP).')
assert self.fsdp_config is not None
error_message = ''
if fsdp_config['sync_module_states'] == False:
if self.fsdp_config['sync_module_states'] == False:
error_message += textwrap.dedent(
"load_monolith_rank0_only requires fsdp_config['sync_module_states'] to be True. "
"Either set fsdp_config['sync_module_states'] = True or set load_monolith_rank0_only = False. ",
)
# Broadcast rank 0 meta check to all ranks so error can be raised on all ranks
rank0_on_meta = 0
if dist.get_global_rank() == 0 and next(model.parameters()).device.type == 'meta':
if dist.get_global_rank() == 0 and next(self.model.parameters()).device.type == 'meta':
rank0_on_meta = 1
rank0_on_meta_tensor = self.device.tensor_to_device(torch.tensor([rank0_on_meta], dtype=torch.uint8))
dist.all_reduce(rank0_on_meta_tensor, reduce_operation='MAX')
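The meta-device check above follows a standard collective pattern: each rank contributes a 0/1 flag, and an all-reduce with MAX lets every rank learn whether any rank (here, rank 0) raised it, so the error can then be raised on all ranks instead of deadlocking. A self-contained sketch of the same pattern with plain `torch.distributed` (process-group initialization, e.g. via torchrun, is assumed and not shown):

```python
import torch
import torch.distributed as dist

def flag_set_on_any_rank(flag: bool, device: torch.device) -> bool:
    """Return True on every rank if any rank passed flag=True."""
    t = torch.tensor([1 if flag else 0], dtype=torch.uint8, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.MAX)  # MAX acts as a logical OR over ranks
    return bool(t.item())

# Usage in the spirit of the check above: rank 0 reports whether its parameters
# are still on the 'meta' device, and all ranks agree on whether to error out.
```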
@@ -494,10 +634,7 @@ def __init__(
if error_message != '':
raise ValueError(error_message)

self.sharded_ckpt_prefix_dir: Optional[str] = None
if self.fsdp_config is not None:
self.sharded_ckpt_prefix_dir = self.fsdp_config['sharded_ckpt_prefix_dir']

# Validate FSDP state dict type
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
if self.fsdp_state_dict_type == 'local':
raise ValueError(
@@ -521,39 +658,6 @@ def __init__(
),
)

# Set defaults for transient variables (to make pyright happy)
self.batch: Any = None
self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

# These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
# All other attributes will not be serialized.
# For simplicity, omit the leading underscore for private attributes.
# For example, even though the optimizers are stored on the state
# as the "_optimizers" attribute, here we specify just "optimizers"
self.serialized_attributes = [
'model',
'optimizers',
'schedulers',
'algorithms',
'callbacks',
'scaler',
'timestamp',
'rank_zero_seed',
'train_metrics',
'eval_metrics',
'run_name',
'dataset_state',
]

self.train_metrics: Optional[Dict[str, Metric]] = {}
self.eval_metrics: Dict[str, Dict[str, Metric]] = {}
self.train_metric_values: Dict[str, float] = {}
self.eval_metric_values: Dict[str, float] = {}
self.total_loss_dict: Dict[str, float] = {}

self.metric_outputs: Dict[str, Any] = {}

def _dataset_of(self, dataloader: Optional[Union[Evaluator, DataSpec, DataLoader, Iterable]]) -> Optional[Dataset]:
"""Get the dataset contained by the given dataloader-like object.

@@ -794,12 +898,8 @@ def fsdp_sharded_state_dict_enabled(self):

@property
def fsdp_device_mesh(self):
if self.fsdp_enabled:
if not hasattr(self.model, 'model') or not hasattr(self.model.model, '_device_mesh'):
return None
return self.model.model._device_mesh
else:
return None
warnings.warn(VersionedDeprecationWarning('fsdp_device_mesh is deprecated. Use device_mesh instead.', '0.24'))
return self.device_mesh

@property
def load_fsdp_monolith_rank0_only(self):
@@ -814,8 +914,8 @@ def load_fsdp_monolith_rank0_only(self):
@property
def load_monolith_rank0_only(self):
return (
self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config['state_dict_type'] == 'full' and
self.fsdp_config['load_monolith_rank0_only'] == True
self.fsdp_config is not None and self.fsdp_config['auto_wrap'] and
self.fsdp_config['state_dict_type'] == 'full' and self.fsdp_config['load_monolith_rank0_only'] == True
)

def _get_integrations_state_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -1289,8 +1389,9 @@ def load_model_state(
if self.load_monolith_rank0_only:
assert self.fsdp_config is not None
log.info('Wrapping model with FSDP after loading model_state.')
from composer.trainer.dist_strategy import prepare_fsdp_module
with reproducibility.seed_context(self.rank_zero_seed):
from composer.distributed import prepare_fsdp_module

prepare_fsdp_module(
self.model,
self.optimizers,
25 changes: 25 additions & 0 deletions composer/distributed/__init__.py
@@ -0,0 +1,25 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Distributed training."""

from composer.distributed.deepspeed import fix_batch_precision_for_deepspeed, parse_deepspeed_config
from composer.distributed.dist_strategy import (
DDPSyncStrategy,
ddp_sync_context,
prepare_ddp_module,
prepare_fsdp_module,
prepare_tp_module,
)
from composer.distributed.mosaic_fsdp import set_fsdp_default

__all__ = [
'fix_batch_precision_for_deepspeed',
'parse_deepspeed_config',
'DDPSyncStrategy',
'ddp_sync_context',
'prepare_ddp_module',
'prepare_fsdp_module',
'prepare_tp_module',
'set_fsdp_default',
]
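A short usage note on the new package (a sketch assuming a Composer build that includes this PR): the previously private helpers are now public and importable from `composer.distributed`, matching the `__all__` list above.

```python
from composer.distributed import (
    DDPSyncStrategy,
    fix_batch_precision_for_deepspeed,
    parse_deepspeed_config,
    prepare_fsdp_module,
    prepare_tp_module,
)
```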
composer/distributed/deepspeed.py
@@ -13,7 +13,7 @@
from composer.core import Batch, Precision, State
from composer.utils import dist, map_collection

__all__ = ['_fix_batch_precision_for_deepspeed', '_parse_deepspeed_config']
__all__ = ['fix_batch_precision_for_deepspeed', 'parse_deepspeed_config']


def _add_batch_config(config: Dict[str, Any], state: State):
@@ -105,7 +105,7 @@ def _add_precision_config(config: Dict[str, Any], state: State):
config['bf16'] = cast(Dict[str, Any], {'enabled': True})


def _parse_deepspeed_config(
def parse_deepspeed_config(
config: Dict[str, Any],
state: State,
) -> Dict[str, Any]:
@@ -160,7 +160,7 @@ def _convert_fp32_tensor_to_bf16(tensor: torch.Tensor):
return tensor


def _fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch:
def fix_batch_precision_for_deepspeed(batch: Batch, precision: Precision) -> Batch:
"""Ensures that a batch is properly formatted for DeepSpeed precisions, if active.

.. note:: Just because the precision is set to FP16 doesn't mean the entire batch can