diff --git a/composer/distributed/dist_strategy.py b/composer/distributed/dist_strategy.py
index be81652881..dcf15ca9f2 100644
--- a/composer/distributed/dist_strategy.py
+++ b/composer/distributed/dist_strategy.py
@@ -171,9 +171,14 @@ def _recreate_fsdp_param_groups_from_unwrapped_opt_info(

     for fsdp_name, param in fsdp_wrapped_named_params:
         unwrapped_name = clean_tensor_name(fsdp_name)
-        # need to have a 1:1 mapping between a fsdp param name and the non-wrapped vanilla param name
-        retrieved_group_num = non_wrapped_param_names_to_group_num[unwrapped_name]
-        group_num_to_optimizer_info[retrieved_group_num]['params'].append(param)
+
+        # Since we are iterating over all model.named_parameters() after FSDP wrapping, we need to check
+        # whether the parameter was included in the optimizer param_group pre-FSDP wrapping, in order to support
+        # passing a subset of the model params to the optimizer
+        if unwrapped_name in non_wrapped_param_names_to_group_num:
+            # Need to have a 1:1 mapping between an FSDP param name and the non-wrapped vanilla param name
+            retrieved_group_num = non_wrapped_param_names_to_group_num[unwrapped_name]
+            group_num_to_optimizer_info[retrieved_group_num]['params'].append(param)

     # return sorted optimizer info groups
     return [group_num_to_optimizer_info[num] for num in sorted(group_num_to_optimizer_info.keys())]
@@ -181,9 +186,23 @@

 def prepare_tp_module(
     model: torch.nn.Module,
+    optimizers: Optional[Union[torch.optim.Optimizer, Sequence[torch.optim.Optimizer]]],
     tp_config: TPConfig,
 ) -> None:
     """Prepare a module (assumed ComposerModel) for use with tensor parallel."""
+    optimizers_tuple = ensure_tuple(optimizers)
+    if len(optimizers_tuple) != 1:
+        raise NotImplementedError(f'Only one optimizer is supported; found {len(optimizers_tuple)} optimizers')
+
+    optim = optimizers_tuple[0]
+    if len(optim.param_groups) > 1:
+        raise RuntimeError('Multiple optimizer groups are not supported with tensor parallelism.',)
+
+    if len(optim.param_groups[0]['params']) != len(list(model.parameters())):
+        raise ValueError(
+            'Passing in a subset of model parameters to the optimizer is not supported with tensor parallelism.',
+        )
+
     from torch.distributed.tensor.parallel import parallelize_module

     device_mesh = tp_config.device_mesh
@@ -247,11 +266,10 @@ def sync_hook(*args):
             raise RuntimeError('CUDA out of memory encountered on a different rank')

     # Necessary variables for optimizers with multiple param groups in FSDP
-    num_param_groups = None
     param_name_to_group_num = None
-    group_num_to_param_group_info = None
+    group_num_to_opt_group_info = None
+    single_param_group_opt_info = None

-    optimizer_specific_info = None
     if optimizers:
         optimizers_tuple = ensure_tuple(optimizers)
         if len(optimizers_tuple) != 1:
@@ -261,38 +279,41 @@ def sync_hook(*args):
         # that will be recreated at the end of prepare_fsdp_module
         optim = optimizers_tuple[0]

-        num_param_groups = len(optim.param_groups)
-        if num_param_groups > 1:
-            if not fsdp_config.use_orig_params:
-                raise RuntimeError(
-                    'Multiple optimizer groups with FSDP are only supported with '
-                    'use_orig_params=True.',
-                )
-            # optimizer.param_groups do not contain parameter names which are needed
-            # to keep track of the different parameters in each group
-            # so we use the pointers between model.parameters() and model.named_parameters()
-            # to get the names of the parameters within optimizer.param_groups
-            param_pointer_to_param_name = {id(p): n for n, p in model.named_parameters()}
+        # Simplest case - single param group & all model params stored in optimizer
+        if len(optim.param_groups) == 1 and len(optim.param_groups[0]['params']) == len(list(model.parameters())):
+            single_param_group_opt_info = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}
+        elif fsdp_config.use_orig_params:
+            # This block stores information about the param groups pre-FSDP wrapping so they can be recreated post-wrapping.
+            # To do so, it relies on the pointers (ids) of model.parameters() and on the param names.
+            # This requires use_orig_params=True, as we need the names of the params post-wrapping.
+            # TP is not supported, as the underlying parameters in the model differ from the params in the param groups after being dtensorified.
+
+            ptr_to_param_name = {id(p): n for n, p in model.named_parameters()}
             param_name_to_group_num = {}
-            group_num_to_param_group_info = {}
+            group_num_to_opt_group_info = {}
             for group_num in range(len(optim.param_groups)):
                 # Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
                 # group = optim.param_groups[group_num]
                 for param_num in range(len(optim.param_groups[group_num]['params'])):
-                    # Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
-                    # param = optim.param_groups[group_num]['params'][param_num]
-                    param_name_to_group_num[param_pointer_to_param_name[id(
-                        optim.param_groups[group_num]['params'][param_num],
-                    )]] = group_num
+                    param_ptr = id(optim.param_groups[group_num]['params'][param_num])
+                    if param_ptr not in ptr_to_param_name:
+                        raise ValueError('The same model must be passed to the optimizer and trainer.')
+                    param_name_to_group_num[ptr_to_param_name[param_ptr]] = group_num

                 # this includes optimizer-specific values like lr, eps
                 # this will be used as the kwargs for the optim param groups later
                 optimizer_specific_group_info = {
                     k: v for k, v in optim.param_groups[group_num].items() if k != 'params'
                 }
-                group_num_to_param_group_info[group_num] = optimizer_specific_group_info
+                group_num_to_opt_group_info[group_num] = optimizer_specific_group_info
         else:
-            optimizer_specific_info = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}
+            if len(optim.param_groups) > 1:
+                raise RuntimeError('Multiple optimizer groups with FSDP are not supported with use_orig_params=False.',)
+
+            if len(optim.param_groups[0]['params']) != len(list(model.parameters())):
+                raise ValueError(
+                    'Passing in a subset of model parameters to the optimizer is not supported with use_orig_params=False.',
+                )

         optim.param_groups.clear()
         optim.state.clear()
@@ -711,19 +732,17 @@ def _check_fn(module: torch.nn.Module) -> bool:

         optim = ensure_tuple(optimizers)[0]
         optim.param_groups.clear()
-        assert num_param_groups is not None
-        if num_param_groups > 1:
+        if single_param_group_opt_info is not None:
+            single_param_group_opt_info.update({'params': list(model.parameters())})
+            optim.add_param_group(single_param_group_opt_info)
+        elif fsdp_config.use_orig_params:
             assert param_name_to_group_num is not None
-            assert group_num_to_param_group_info is not None
+            assert group_num_to_opt_group_info is not None
             param_groups = _recreate_fsdp_param_groups_from_unwrapped_opt_info(
                 model.named_parameters(),
                 param_name_to_group_num,
-                group_num_to_param_group_info,
+                group_num_to_opt_group_info,
             )
             for param_group in param_groups:
                 optim.add_param_group(param_group)
-        else:
-            assert optimizer_specific_info is not None
-            optimizer_specific_info.update({'params': list(model.parameters())})
-            optim.add_param_group(optimizer_specific_info)
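
The bookkeeping in prepare_fsdp_module works by mapping each optimizer parameter to its name via id() before wrapping, remembering the per-group hyperparameters, and then re-bucketing model.named_parameters() by name after wrapping, skipping any parameter the optimizer never owned. A minimal standalone sketch of that idea, using only a toy torch model and hypothetical variable names (no Composer or FSDP involved):

import torch

# Toy model and an optimizer that covers only a subset of its parameters,
# split across two param groups (hypothetical example, not Composer code).
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
optim = torch.optim.SGD([
    {'params': model[0].parameters(), 'lr': 0.1},
    {'params': model[2].weight, 'lr': 0.5},  # model[2].bias is deliberately excluded
])

# Pre-wrapping bookkeeping: map each parameter's id() to its name, then record
# which group each optimizer parameter belongs to, keyed by name.
ptr_to_name = {id(p): n for n, p in model.named_parameters()}
name_to_group = {}
group_info = {}
for group_num, group in enumerate(optim.param_groups):
    for p in group['params']:
        name_to_group[ptr_to_name[id(p)]] = group_num
    group_info[group_num] = {k: v for k, v in group.items() if k != 'params'}

# Post-wrapping reconstruction: walk named_parameters() again and re-bucket
# by name, skipping parameters that were never in the optimizer.
rebuilt = {num: {**info, 'params': []} for num, info in group_info.items()}
for name, p in model.named_parameters():
    if name in name_to_group:
        rebuilt[name_to_group[name]]['params'].append(p)

print({num: len(g['params']) for num, g in rebuilt.items()})  # {0: 2, 1: 1}

In the FSDP case the second pass runs over the wrapped module, which is why use_orig_params=True is required: only then do the post-wrapping parameter names still line up with the pre-wrapping ones.
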
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index c9c47cea03..56699e1fb3 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1661,6 +1661,7 @@ def __init__(
             with reproducibility.seed_context(self.state.rank_zero_seed):
                 prepare_tp_module(
                     model,
+                    optimizers,
                     self.state.tp_config,
                 )

diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py
index b077d22131..7f17641f1f 100644
--- a/tests/trainer/test_fsdp.py
+++ b/tests/trainer/test_fsdp.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0

+import copy
 from unittest.mock import MagicMock

 import pytest
@@ -234,6 +235,62 @@ def test_fsdp_process_group(world_size: int):
     trainer.fit()


+@pytest.mark.gpu
+@world_size(2)
+@pytest.mark.skipif(
+    version.parse(torch.__version__) < version.parse('2'),
+    reason='FSDP use_orig_params requires torch 2.0 or higher',
+)
+def test_fsdp_subset_of_params_in_opt(world_size: int):
+    model = SimpleModel()
+    dataset = RandomClassificationDataset(size=10)
+    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
+    optimizer = torch.optim.SGD(model.fc1.parameters(), lr=0.01)
+    unwrapped_optimizer = copy.deepcopy(optimizer)
+
+    trainer = Trainer(
+        model=model,
+        optimizers=optimizer,
+        train_dataloader=dataloader,
+        parallelism_config={
+            'fsdp': {
+                'use_orig_params': True,
+            },
+        },
+        max_duration='3ba',
+    )
+
+    with trainer.state.model.module.summon_full_params(trainer.state.model.module):
+        nb_parameters_before_fsdp = len(unwrapped_optimizer.param_groups[0]['params'])
+        nb_parameters_after_fsdp = len(trainer.state.optimizers[0].param_groups[0]['params'])
+
+        assert nb_parameters_before_fsdp == nb_parameters_after_fsdp
+
+
+@pytest.mark.gpu
+@world_size(2)
+def test_fsdp_subset_of_params_in_opt_without_orig_params(world_size: int):
+    model = SimpleModel()
+    dataset = RandomClassificationDataset(size=10)
+    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
+    optimizer = torch.optim.SGD(model.fc1.parameters(), lr=0.01)
+
+    expected_error = 'Passing in a subset of model parameters to the optimizer is not supported with use_orig_params=False.'
+
+    with pytest.raises(ValueError, match=expected_error):
+        _ = Trainer(
+            model=model,
+            optimizers=optimizer,
+            train_dataloader=dataloader,
+            parallelism_config={
+                'fsdp': {
+                    'use_orig_params': False,
+                },
+            },
+            max_duration='3ba',
+        )
+
+
 class SimpleMLP(ComposerModel):

     def __init__(self, num_features: int = 128, device: str = 'cuda'):
diff --git a/tests/trainer/test_fsdp_param_groups.py b/tests/trainer/test_fsdp_param_groups.py
index 7cbd52520e..8315ec4be2 100644
--- a/tests/trainer/test_fsdp_param_groups.py
+++ b/tests/trainer/test_fsdp_param_groups.py
@@ -2,7 +2,6 @@
 # SPDX-License-Identifier: Apache-2.0

 import copy
-import gc

 import pytest
 import torch
@@ -11,7 +10,7 @@

 from composer.trainer.trainer import Trainer
 from composer.utils import dist, misc
-from tests.common import RandomClassificationDataset, SimpleModel, device, world_size
+from tests.common import EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, device, world_size


 @pytest.mark.parametrize('mixed_precision', ['DEFAULT'])
@@ -23,7 +22,7 @@ def test_fsdp_param_groups_without_orig_params(mixed_precision: str, device: str
     # Ensure that FSDP with 'use_orig_params=False' raises an exception when passing in an optimizer
     # with multiple param groups
     num_classes = 10
-    model = SimpleModel(num_features=1, num_classes=num_classes)
+    model = SimpleModel(num_features=2, num_classes=num_classes)
     dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
     dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))

@@ -31,7 +30,7 @@ def test_fsdp_param_groups_without_orig_params(mixed_precision: str, device: str
     param_groups = [{'params': param, 'lr': (0.1 + 0.1 * i)} for i, param in enumerate(model.parameters())]
     optimizer = torch.optim.SGD(param_groups, lr=0)

-    expected_error = 'Multiple optimizer groups with FSDP are only supported with use_orig_params=True.'
+    expected_error = 'Multiple optimizer groups with FSDP are not supported with use_orig_params=False.'

     with pytest.raises(RuntimeError, match=expected_error):
         _ = Trainer(
@@ -48,7 +47,6 @@ def test_fsdp_param_groups_without_orig_params(mixed_precision: str, device: str
             max_duration='3ba',
             device=device,
         )
-    gc.collect()


 @pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
@@ -122,3 +120,86 @@ def test_fsdp_with_param_groups(mixed_precision: str, device: str, reentrant: bo
             assert id(unwrapped_param) != id(wrapped_param)

             assert unwrapped_param_group['lr'] == wrapped_param_group['lr']
+
+
+@pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
+@pytest.mark.parametrize('reentrant', [True, False])
+@pytest.mark.filterwarnings('ignore::UserWarning')
+@device('gpu')
+@world_size(2)
+@pytest.mark.skipif(
+    version.parse(torch.__version__) < version.parse('2'),
+    reason='FSDP use_orig_params requires torch 2.0 or higher',
+)
+def test_fsdp_with_param_groups_with_subset_of_params_in_opt(
+    mixed_precision: str,
+    device: str,
+    reentrant: bool,
+    world_size: int,
+):
+    """
+    Test that an optimizer with param groups covering only a subset of the model parameters is correctly FSDP wrapped.
+    """
+    num_classes = 10
+
+    # Note that the EmbeddedWeightTiedModel is used instead of SimpleModel to ensure that some of the model parameters
+    # are excluded from the optimizer
+    model = EmbeddedWeightTiedModel(num_features=num_classes)
+    dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
+    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
+
+    # create a different group per parameter
+    param_groups = [{
+        'params': model.net1.fc1.parameters(),
+        'lr': 0.1,
+    }, {
+        'params': model.net2.fc2.parameters(),
+        'lr': 0.5,
+    }]
+
+    optimizer = torch.optim.SGD(param_groups)
+    unwrapped_optimizer = copy.deepcopy(optimizer)
+
+    optimizer_groups_pre_fsdp = optimizer.param_groups
+
+    trainer = Trainer(
+        model=model,
+        optimizers=optimizer,
+        train_dataloader=dataloader,
+        parallelism_config={
+            'fsdp': {
+                'activation_checkpointing_reentrant': reentrant,
+                'mixed_precision': mixed_precision,
+            },
+        },
+        max_duration='3ba',
+        device=device,
+    )
+    trainer.fit()
+
+    assert misc.is_model_fsdp(trainer.state.model)
+    trainer_optimizer = trainer.state.optimizers[0]
+    assert len(trainer_optimizer.param_groups) > 1
+    assert len(trainer_optimizer.param_groups) == len(optimizer_groups_pre_fsdp)
+
+    with trainer.state.model.module.summon_full_params(trainer.state.model.module):  # type: ignore
+        for unwrapped_param_group, wrapped_param_group in zip(
+            unwrapped_optimizer.param_groups,
+            trainer_optimizer.param_groups,
+        ):
+
+            unwrapped_param_list = unwrapped_param_group['params']
+            wrapped_param_list = wrapped_param_group['params']
+
+            assert len(unwrapped_param_list) == 1
+            assert len(wrapped_param_list) == 1
+
+            unwrapped_param = unwrapped_param_list[0]
+            wrapped_param = wrapped_param_list[0]
+
+            assert unwrapped_param.shape == wrapped_param.shape
+
+            # the underlying tensor is different because it has been recreated when FSDP wraps the model
+            assert id(unwrapped_param) != id(wrapped_param)
+
+            assert unwrapped_param_group['lr'] == wrapped_param_group['lr']
diff --git a/tests/trainer/test_tp.py b/tests/trainer/test_tp.py
index bfee2e13c9..95b84d32c6 100644
--- a/tests/trainer/test_tp.py
+++ b/tests/trainer/test_tp.py
@@ -47,3 +47,78 @@ def test_tp_train(world_size: int):
     )

     trainer.fit()
+
+
+@pytest.mark.gpu
+@world_size(4)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
+@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
+def test_tp_with_param_groups(world_size: int):
+    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+
+    model = SimpleModel()
+    dataset = RandomClassificationDataset(size=8)
+    dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset))
+    optimizer = torch.optim.SGD([{
+        'params': model.fc1.parameters(),
+        'lr': 0.1,
+    }, {
+        'params': model.fc2.parameters(),
+        'lr': 0.5,
+    }])
+
+    layer_plan = {
+        'fc1': ColwiseParallel(),
+        'fc2': RowwiseParallel(),
+    }
+
+    expected_error = 'Multiple optimizer groups are not supported with tensor parallelism.'
+
+    with pytest.raises(RuntimeError, match=expected_error):
+        _ = Trainer(
+            model=model,
+            optimizers=optimizer,
+            train_dataloader=dataloader,
+            parallelism_config={
+                'tp': {
+                    'layer_plan': layer_plan,
+                    'tensor_parallel_degree': 2,
+                },
+                'fsdp': {},
+            },
+            max_duration='3ba',
+        )
+
+
+@pytest.mark.gpu
+@world_size(4)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.3'), reason='requires PyTorch 2.3+')
+@pytest.mark.filterwarnings(r'ignore:.*\(TP\) is experimental.*:FutureWarning')
+def test_tp_with_subset_of_params(world_size: int):
+    from torch.distributed.tensor.parallel import ColwiseParallel
+
+    model = SimpleModel()
+    dataset = RandomClassificationDataset(size=8)
+    dataloader = DataLoader(dataset, batch_size=2, sampler=dist.get_sampler(dataset))
+    optimizer = torch.optim.SGD(model.fc1.parameters(), lr=0.1)
+
+    layer_plan = {
+        'fc1': ColwiseParallel(),
+    }
+
+    expected_error = 'Passing in a subset of model parameters to the optimizer is not supported with tensor parallelism.'
+
+    with pytest.raises(ValueError, match=expected_error):
+        _ = Trainer(
+            model=model,
+            optimizers=optimizer,
+            train_dataloader=dataloader,
+            parallelism_config={
+                'tp': {
+                    'layer_plan': layer_plan,
+                    'tensor_parallel_degree': 2,
+                },
+                'fsdp': {},
+            },
+            max_duration='3ba',
+        )
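
The configuration exercised by the new FSDP tests boils down to the pattern below. This is a minimal sketch, not part of the patch, assuming Composer's Trainer plus the SimpleModel and RandomClassificationDataset test helpers imported exactly as in tests/trainer/test_fsdp.py above, and a multi-GPU launch (e.g. via the composer launcher) as implied by the world_size markers:

import torch
from torch.utils.data import DataLoader

from composer.trainer.trainer import Trainer
from composer.utils import dist
from tests.common import RandomClassificationDataset, SimpleModel

# Optimizer over a subset of the model's parameters (only fc1). With
# use_orig_params=True, prepare_fsdp_module rebuilds the param group from the
# surviving named parameters instead of assuming full coverage of the model.
model = SimpleModel()
dataset = RandomClassificationDataset(size=10)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.fc1.parameters(), lr=0.01)

trainer = Trainer(
    model=model,
    optimizers=optimizer,
    train_dataloader=dataloader,
    parallelism_config={'fsdp': {'use_orig_params': True}},
    max_duration='3ba',
)
trainer.fit()

With use_orig_params=False, or with a 'tp' block in parallelism_config, the same setup is rejected with the RuntimeError/ValueError messages asserted in the tests above.
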