Fix bug when subset of model parameters is passed into optimizer with FSDP #3502

Merged: 16 commits, Aug 1, 2024
Changes from 8 commits
67 changes: 38 additions & 29 deletions composer/distributed/dist_strategy.py
@@ -171,9 +171,14 @@ def _recreate_fsdp_param_groups_from_unwrapped_opt_info(
for fsdp_name, param in fsdp_wrapped_named_params:

unwrapped_name = clean_tensor_name(fsdp_name)
# need to have a 1:1 mapping between a fsdp param name and the non-wrapped vanilla param name
retrieved_group_num = non_wrapped_param_names_to_group_num[unwrapped_name]
group_num_to_optimizer_info[retrieved_group_num]['params'].append(param)

# since we are iterating over all model.named_parameters() after fsdp wrapping, we need to check
# if the parameter was included in the optimizer param_group pre fsdp wrapping, in order to support
# passing a subset of model params in the optimizer
if unwrapped_name in non_wrapped_param_names_to_group_num:
# need to have a 1:1 mapping between a fsdp param name and the non-wrapped vanilla param name
retrieved_group_num = non_wrapped_param_names_to_group_num[unwrapped_name]
group_num_to_optimizer_info[retrieved_group_num]['params'].append(param)

# return sorted optimizer info groups
return [group_num_to_optimizer_info[num] for num in sorted(group_num_to_optimizer_info.keys())]
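For context, here is a minimal, self-contained sketch of the recreation logic in this hunk (plain PyTorch, not Composer's actual helper; the function name and the wrapper-prefix string are illustrative assumptions): parameters that were never handed to the optimizer are simply skipped when the post-wrapping groups are rebuilt, which is what allows an optimizer built over a subset of `model.parameters()`.

```python
from typing import Iterable
import torch

def recreate_param_groups(
    wrapped_named_params: Iterable[tuple[str, torch.nn.Parameter]],
    name_to_group_num: dict[str, int],
    group_num_to_info: dict[int, dict],
) -> list[dict]:
    # Start each group from its saved kwargs (lr, weight_decay, ...) with an empty param list.
    groups = {num: {**info, 'params': []} for num, info in group_num_to_info.items()}
    for wrapped_name, param in wrapped_named_params:
        # Strip the FSDP wrapper prefix to recover the pre-wrapping parameter name
        # (assumed prefix shown here; Composer uses clean_tensor_name for this).
        clean_name = wrapped_name.replace('_fsdp_wrapped_module.', '')
        # The fix: only re-add parameters that were in the optimizer before wrapping;
        # previously this lookup raised KeyError for any parameter left out of the optimizer.
        if clean_name in name_to_group_num:
            groups[name_to_group_num[clean_name]]['params'].append(param)
    return [groups[num] for num in sorted(groups)]
```

The `sorted(...)` at the end preserves the original group ordering, so group-specific hyperparameters stay aligned with the groups the user defined.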
@@ -202,6 +207,7 @@ def prepare_fsdp_module(
precision: Precision,
device: Device,
auto_microbatching: bool,
using_tp: bool = False,
te_rng_seed: int = 1234,
) -> None:
"""Prepare a module (assumed ComposerModel) and optimizer for use with :class:`torch.distributed.fsdp.FullyShardedDataParallel`.
@@ -213,6 +219,7 @@ def prepare_fsdp_module(
precision: (Precision): The precision being used by the Trainer, used to fill in defaults for FSDP `mixed_precision` settings.
device (Device): The device being used by the Trainer.
auto_microbatching (bool, optional): Whether or not auto microbatching is enabled.
using_tp (bool, optional): Whether the model has been wrapped with Tensor Parallelism, in which case only a single optimizer param group is supported.
te_rng_seed(int): The seed to use for the Transformer Engine activation checkpointing RNG. Defaults to 1234.
"""
# Check sync_module_states is True for mixed initialization or HSDP
@@ -247,11 +254,10 @@ def sync_hook(*args):
raise RuntimeError('CUDA out of memory encountered on a different rank')

# Necessary variables for optimizers with multiple param groups in FSDP
num_param_groups = None
param_name_to_group_num = None
group_num_to_param_group_info = None

group_num_to_opt_group_info = None
optimizer_specific_info = None

if optimizers:
optimizers_tuple = ensure_tuple(optimizers)
if len(optimizers_tuple) != 1:
@@ -261,37 +267,41 @@ def sync_hook(*args):
# that will be recreated at the end of prepare_fsdp_module
optim = optimizers_tuple[0]

num_param_groups = len(optim.param_groups)
if num_param_groups > 1:
if not fsdp_config.use_orig_params:
raise RuntimeError(
'Multiple optimizer groups with FSDP are only supported with '
'use_orig_params=True.',
)
# optimizer.param_groups do not contain parameter names which are needed
# to keep track of the different parameters in each group
# so we use the pointers between model.parameters() and model.named_parameters()
# to get the names of the parameters within optimizer.param_groups
param_pointer_to_param_name = {id(p): n for n, p in model.named_parameters()}
if fsdp_config.use_orig_params and not using_tp:
# this code block stores information about param groups pre-fsdp wrapping in order to recreate them post-wrapping
# to do so, it relies on the ptrs of the model.parameters() in a model and the names of the params
# for this to work, use_orig_params=True, as we need the names of the params post-wrapping
# TP is not supported, as the underlying parameters in the model differ from the params in the param groups after being dtensorified

ptr_to_param_name = {id(p): n for n, p in model.named_parameters()}
param_name_to_group_num = {}
group_num_to_param_group_info = {}
group_num_to_opt_group_info = {}
for group_num in range(len(optim.param_groups)):
# Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
# group = optim.param_groups[group_num]
for param_num in range(len(optim.param_groups[group_num]['params'])):
# Need to in-line to avoid a reference which causes FSDP to allocate extra GPU memory
# param = optim.param_groups[group_num]['params'][param_num]
param_name_to_group_num[param_pointer_to_param_name[id(
optim.param_groups[group_num]['params'][param_num],
)]] = group_num
param_ptr = id(optim.param_groups[group_num]['params'][param_num])
if param_ptr not in ptr_to_param_name:
raise ValueError('The same model must be passed to the optimizer and trainer.')
param_name_to_group_num[ptr_to_param_name[param_ptr]] = group_num

# this includes optimizer-specific values like lr, eps
# this will be used as the kwargs for the optim param groups later
optimizer_specific_group_info = {
k: v for k, v in optim.param_groups[group_num].items() if k != 'params'
}
group_num_to_param_group_info[group_num] = optimizer_specific_group_info
group_num_to_opt_group_info[group_num] = optimizer_specific_group_info
else:
if len(optim.param_groups) > 1:
raise RuntimeError(
'Multiple optimizer groups with FSDP are not supported with tensor parallelism and/or use_orig_params=False.',
)

if len(optim.param_groups[0]['params']) != len(list(model.parameters())):
raise ValueError(
'Passing in a subset of model parameters to the optimizer is not supported with tensor parallelism and/or use_orig_params=False.',
)

optimizer_specific_info = {k: v for k, v in optim.param_groups[0].items() if k != 'params'}

optim.param_groups.clear()
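As a companion sketch (again illustrative, not Composer's code), this is the bookkeeping the block above performs before FSDP flattens the parameters: map each optimizer parameter back to its name via `id()`, record its group number and the group's non-`params` kwargs, and fail fast if the optimizer holds tensors that do not belong to the model. Note that the real implementation indexes `optim.param_groups` in place rather than iterating over the params directly, to avoid holding references that would cause FSDP to allocate extra GPU memory.

```python
import torch

def capture_param_group_info(
    model: torch.nn.Module,
    optim: torch.optim.Optimizer,
) -> tuple[dict[str, int], dict[int, dict]]:
    # Optimizer param groups store tensors, not names, so recover names via object identity.
    ptr_to_name = {id(p): n for n, p in model.named_parameters()}
    name_to_group_num: dict[str, int] = {}
    group_num_to_info: dict[int, dict] = {}
    for group_num, group in enumerate(optim.param_groups):
        for param in group['params']:
            if id(param) not in ptr_to_name:
                # Mirrors the new check above: the optimizer must be built from this model.
                raise ValueError('The same model must be passed to the optimizer and trainer.')
            name_to_group_num[ptr_to_name[id(param)]] = group_num
        # Keep everything except 'params' (lr, momentum, weight_decay, ...) as the group's kwargs.
        group_num_to_info[group_num] = {k: v for k, v in group.items() if k != 'params'}
    return name_to_group_num, group_num_to_info
```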
@@ -711,15 +721,14 @@ def _check_fn(module: torch.nn.Module) -> bool:
optim = ensure_tuple(optimizers)[0]
optim.param_groups.clear()

assert num_param_groups is not None
if num_param_groups > 1:
if fsdp_config.use_orig_params and not using_tp:
assert param_name_to_group_num is not None
assert group_num_to_param_group_info is not None
assert group_num_to_opt_group_info is not None

param_groups = _recreate_fsdp_param_groups_from_unwrapped_opt_info(
model.named_parameters(),
param_name_to_group_num,
group_num_to_param_group_info,
group_num_to_opt_group_info,
)
for param_group in param_groups:
optim.add_param_group(param_group)
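Putting the two sketches above together, a hypothetical end-to-end flow for the `use_orig_params=True`, no-tensor-parallelism path (using a toy `SimpleNet`, with the FSDP wrapping step itself elided): capture the group info, clear the optimizer's groups before wrapping, and re-add the recreated groups afterwards via `add_param_group`.

```python
import torch
from torch import nn

class SimpleNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.fc1(x))

model = SimpleNet()
# Only fc1's parameters are handed to the optimizer -- the subset case this PR fixes.
optim = torch.optim.SGD(model.fc1.parameters(), lr=0.1)

name_to_group_num, group_num_to_info = capture_param_group_info(model, optim)
optim.param_groups.clear()

# ... FSDP wrapping of `model` would happen here (omitted in this sketch) ...

for group in recreate_param_groups(model.named_parameters(), name_to_group_num, group_num_to_info):
    optim.add_param_group(group)

assert len(optim.param_groups) == 1
assert len(optim.param_groups[0]['params']) == 2  # fc1.weight and fc1.bias only
```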
11 changes: 10 additions & 1 deletion composer/trainer/trainer.py
@@ -1675,6 +1675,7 @@ def __init__(
precision,
device,
auto_microbatching,
self.state.tp_config is not None,
self.state.seed,
)

@@ -1838,7 +1839,15 @@ def __init__(
):
# Init with globally fixed seed so all HSDP replicas have the same initial weights
with reproducibility.seed_context(self.state.rank_zero_seed):
prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching)
prepare_fsdp_module(
model,
optimizers,
self.state.fsdp_config,
precision,
device,
auto_microbatching,
self.state.tp_config is not None,
)

self.engine.run_event(Event.AFTER_LOAD)

59 changes: 59 additions & 0 deletions tests/trainer/test_fsdp.py
@@ -1,6 +1,8 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import copy
import gc
from unittest.mock import MagicMock

import pytest
@@ -234,6 +236,63 @@ def test_fsdp_process_group(world_size: int):
trainer.fit()


@pytest.mark.gpu
@world_size(2)
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse('2'),
reason='FSDP use_orig_params requires torch 2.0 or higher',
)
def test_fsdp_subset_of_params_in_opt(world_size: int):
model = SimpleModel()
dataset = RandomClassificationDataset(size=10)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.fc1.parameters(), lr=0.01)
unwrapped_optimizer = copy.deepcopy(optimizer)

trainer = Trainer(
model=model,
optimizers=optimizer,
train_dataloader=dataloader,
parallelism_config={
'fsdp': {
'use_orig_params': True,
},
},
max_duration='3ba',
)

with trainer.state.model.module.summon_full_params(trainer.state.model.module):
nb_parameters_before_fsdp = len(unwrapped_optimizer.param_groups[0]['params'])
nb_parameters_after_fsdp = len(trainer.state.optimizers[0].param_groups[0]['params'])

assert nb_parameters_before_fsdp == nb_parameters_after_fsdp


@pytest.mark.gpu
@world_size(2)
def test_fsdp_subset_of_params_in_opt_without_orig_params(world_size: int):
model = SimpleModel()
dataset = RandomClassificationDataset(size=10)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.fc1.parameters(), lr=0.01)

expected_error = 'Passing in a subset of model parameters to the optimizer is not supported with tensor parallelism and/or use_orig_params=False.'

with pytest.raises(ValueError, match=expected_error):
_ = Trainer(
model=model,
optimizers=optimizer,
train_dataloader=dataloader,
parallelism_config={
'fsdp': {
'use_orig_params': False,
},
},
max_duration='3ba',
)
gc.collect()


class SimpleMLP(ComposerModel):

def __init__(self, num_features: int = 128, device: str = 'cuda'):
89 changes: 86 additions & 3 deletions tests/trainer/test_fsdp_param_groups.py
@@ -11,7 +11,7 @@

from composer.trainer.trainer import Trainer
from composer.utils import dist, misc
from tests.common import RandomClassificationDataset, SimpleModel, device, world_size
from tests.common import EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, device, world_size


@pytest.mark.parametrize('mixed_precision', ['DEFAULT'])
@@ -23,15 +23,15 @@ def test_fsdp_param_groups_without_orig_params(mixed_precision: str, device: str
# Ensure that FSDP with 'use_orig_params=False' raises an exception when passing in an optimizer
# with multiple param groups
num_classes = 10
model = SimpleModel(num_features=1, num_classes=num_classes)
model = SimpleModel(num_features=2, num_classes=num_classes)
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))

# Create a different parameter per group
param_groups = [{'params': param, 'lr': (0.1 + 0.1 * i)} for i, param in enumerate(model.parameters())]
optimizer = torch.optim.SGD(param_groups, lr=0)

expected_error = 'Multiple optimizer groups with FSDP are only supported with use_orig_params=True.'
expected_error = 'Multiple optimizer groups with FSDP are not supported with tensor parallelism and/or use_orig_params=False.'

with pytest.raises(RuntimeError, match=expected_error):
_ = Trainer(
@@ -122,3 +122,86 @@ def test_fsdp_with_param_groups(mixed_precision: str, device: str, reentrant: bo
assert id(unwrapped_param) != id(wrapped_param)

assert unwrapped_param_group['lr'] == wrapped_param_group['lr']


@pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
@pytest.mark.parametrize('reentrant', [True, False])
@pytest.mark.filterwarnings('ignore::UserWarning')
@device('gpu')
@world_size(2)
@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse('2'),
reason='FSDP use_orig_params requires torch 2.0 or higher',
)
def test_fsdp_with_param_groups_with_subset_of_params_in_opt(
mixed_precision: str,
device: str,
reentrant: bool,
world_size: int,
):
"""
Test that an optimizer whose param groups contain only a subset of the model's parameters is correctly FSDP-wrapped.
"""
num_classes = 10

# Note that the EmbeddedWeightTiedModel is used instead of SimpleModel to ensure that some of the model parameters
# are excluded from the optimizer
model = EmbeddedWeightTiedModel(num_features=num_classes)
dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))

# create a different group per parameter
param_groups = [{
'params': model.net1.fc1.parameters(),
'lr': 0.1,
}, {
'params': model.net2.fc2.parameters(),
'lr': 0.5,
}]

optimizer = torch.optim.SGD(param_groups)
unwrapped_optimizer = copy.deepcopy(optimizer)

optimizer_groups_pre_fsdp = optimizer.param_groups

trainer = Trainer(
model=model,
optimizers=optimizer,
train_dataloader=dataloader,
parallelism_config={
'fsdp': {
'activation_checkpointing_reentrant': reentrant,
'mixed_precision': mixed_precision,
},
},
max_duration='3ba',
device=device,
)
trainer.fit()

assert misc.is_model_fsdp(trainer.state.model)
trainer_optimizer = trainer.state.optimizers[0]
assert len(trainer_optimizer.param_groups) > 1
assert len(trainer_optimizer.param_groups) == len(optimizer_groups_pre_fsdp)

with trainer.state.model.module.summon_full_params(trainer.state.model.module): # type: ignore
for unwrapped_param_group, wrapped_param_group in zip(
unwrapped_optimizer.param_groups,
trainer_optimizer.param_groups,
):

unwrapped_param_list = unwrapped_param_group['params']
wrapped_param_list = wrapped_param_group['params']

assert len(unwrapped_param_list) == 1
assert len(wrapped_param_list) == 1

unwrapped_param = unwrapped_param_list[0]
wrapped_param = wrapped_param_list[0]

assert unwrapped_param.shape == wrapped_param.shape

# the underlying tensor is different because it has been recreated when FSDP wraps the model
assert id(unwrapped_param) != id(wrapped_param)

assert unwrapped_param_group['lr'] == wrapped_param_group['lr']