From 903845d0b81d5e8aa785605a2c69662d0082119b Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Fri, 17 May 2024 14:20:28 -0700 Subject: [PATCH 01/15] first commit untested --- composer/checkpoint/state_dict.py | 101 ++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 5417188466..41f546418e 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -152,3 +152,104 @@ def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_st with FSDP.state_dict_type(model, state_dict_type=state_dict_type, state_dict_config=state_dict_config): model_state_dict = model.state_dict() return model_state_dict + +def _get_optim_state_dict_with_fsdp_context_manager(model: nn.Module, + optimizer: torch.optim.Optimizer, + sharded_state_dict: bool, + cpu_offload: bool) -> Dict[str, Any]: + """Get the optimizer state dict with the FSDP context manager. + + Args: + model: The model containing the parameters that the optimizer is optimizing. + optimizer: The optimizer to get the state dict from. + sharded_state_dict: Whether the optimizer state dict should be sharded or not. If True, every rank returns the state dict of its shards. + If False, then rank 0 returns the state dict of the entire optimizer. + cpu_offload: Whether to offload the state dict to CPU. + + Returns: + The state dict of the optimizer. + + """ + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullStateDictConfig, + ShardedStateDictConfig, + StateDictType, + FullOptimStateDictConfig, + ShardedOptimStateDictConfig, + ) + state_dict_type = StateDictType.SHARDED_STATE_DICT if sharded_state_dict else StateDictType.FULL_STATE_DICT + + state_dict_config = ShardedStateDictConfig(offload_to_cpu=cpu_offload, + ) if sharded_state_dict else FullStateDictConfig( + rank0_only=True, + offload_to_cpu=cpu_offload, + ) + optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=cpu_offload) if sharded_state_dict else FullOptimStateDictConfig( + rank0_only=True, + offload_to_cpu=cpu_offload) + with FSDP.state_dict_type(model, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config): + optim_state_dict = FSDP.optim_state_dict(model, optimizer) + return optim_state_dict + + +def get_optim_state_dict( + model: Union[ComposerModel, nn.Module], + optimizer: torch.optim.Optimizer, + sharded_state_dict: bool, + precision: str, + include_keys: Optional[Union[str, Sequence[str]]] = None, + ignore_keys: Optional[Union[str, Sequence[str]]] = None, + cpu_offload: Optional[bool] = None, + ) -> Dict[str, Any]: + """Generate the state dict of the optimizer. + + Args: + model: The model containing the parameters that the optimizer is optimizing. + optimizer: The optimizer to get the state dict from. + sharded: Whether the optimizer is sharded or not. If True, every rank returns the state dict of its shards. + If False, then rank 0 returns the state dict of the entire optimizer. + precision: The precision of the optimizer. + include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None. + ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None. + cpu_offload: Whether to offload the state dict to CPU. 
If None, it is set to True if FSDP is enabled with non-sharded state dict and False otherwise. + + Returns: + The state dict of the optimizer. + """ + if include_keys is not None and ignore_keys is not None: + raise ValueError(f'Both {include_keys=} and {ignore_keys=} cannot be non-None.') + + is_fsdp = _is_model_fsdp(model) + if not is_fsdp and sharded_state_dict: + raise ValueError('Sharded optim state dict can only be generated for FSDP models.') + + cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict) + log.debug('Extracting optim state dict') + if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized(): + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict + log.debug('Calling torch get_optimizer_state_dict...') + optim_state_dict = get_optimizer_state_dict( + model=model, + optimizers=optimizer, + submodules=None, # We extract submodules below + options=StateDictOptions( + full_state_dict=not sharded_state_dict, + cpu_offload=cpu_offload, + ), + ) + else: + if is_fsdp: + log.debug('Calling legacy FSDP context manager to get optim state dict...') + optim_state_dict = _get_optim_state_dict_with_fsdp_context_manager(model, + optimizer, + sharded_state_dict, + cpu_offload) + else: + optim_state_dict = optimizer.state_dict() + + + optim_state_dict = _cast_state_dict_to_precision(state_dict=optim_state_dict, + precision=precision) From f1c434bee39d9cc295e18fbbbe245a4b0a9d4421 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Wed, 22 May 2024 22:46:08 +0000 Subject: [PATCH 02/15] unsharded unit tests --- composer/checkpoint/__init__.py | 4 +- composer/checkpoint/state_dict.py | 45 ++++++++-- tests/checkpoint/test_state_dict.py | 130 +++++++++++++++++++++++++++- tests/common/compare.py | 14 +-- tests/common/models.py | 6 +- 5 files changed, 182 insertions(+), 17 deletions(-) diff --git a/composer/checkpoint/__init__.py b/composer/checkpoint/__init__.py index be9c380c2d..fb8e8fc19d 100644 --- a/composer/checkpoint/__init__.py +++ b/composer/checkpoint/__init__.py @@ -3,8 +3,10 @@ """Module for checkpointing API.""" -from composer.checkpoint.state_dict import get_model_state_dict +from composer.checkpoint.state_dict import (get_model_state_dict, + get_optim_state_dict) __all__ = [ 'get_model_state_dict', + 'get_optim_state_dict' ] diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 41f546418e..1b8289f710 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -18,6 +18,8 @@ log = logging.getLogger(__name__) +__all__ = ['get_model_state_dict', 'get_optim_state_dict'] + def get_model_state_dict( model: Union[ComposerModel, nn.Module], @@ -198,8 +200,8 @@ def _get_optim_state_dict_with_fsdp_context_manager(model: nn.Module, def get_optim_state_dict( model: Union[ComposerModel, nn.Module], optimizer: torch.optim.Optimizer, - sharded_state_dict: bool, - precision: str, + sharded_state_dict: bool = False, + precision: str = 'fp32', include_keys: Optional[Union[str, Sequence[str]]] = None, ignore_keys: Optional[Union[str, Sequence[str]]] = None, cpu_offload: Optional[bool] = None, @@ -249,7 +251,40 @@ def get_optim_state_dict( cpu_offload) else: optim_state_dict = optimizer.state_dict() - - optim_state_dict = _cast_state_dict_to_precision(state_dict=optim_state_dict, - precision=precision) + if ignore_keys is not None: + optim_state_dict = _remove_keys_from_optim_state_dict(optim_state_dict, model, ignore_keys) + if 
include_keys is not None: + optim_state_dict = _extract_keys_from_optim_state_dict(optim_state_dict, model, include_keys) + + for param_ind, param_state_dict in optim_state_dict['state'].items(): + optim_state_dict['state'][param_ind] = _cast_state_dict_to_precision(param_state_dict, precision) + return optim_state_dict + + +def _remove_keys_from_optim_state_dict(optim_state_dict: Dict[str, Any], model: Union[ComposerModel, nn.Module], ignore_keys: Union[str, Sequence[str]]): + if isinstance(ignore_keys, str): + ignore_keys = [ignore_keys] + + # optim_state_dict['state'] is a dictionary mapping the param_ind (0,1,2,..., len(model.parameters())-1) + # to the optimizer state for that parameter e.g. 'step', 'exp_avg', 'exp_avg_sq'. + # The param_ind ordering is determined by passing model.parameters() + # to the optimizer. The underlying generator for model.parameters() is model.named_parameters() + # so we need to use model.named_parameters() instead of model.state_dict().keys() to match fqn to ind correctly. + param_inds = list(optim_state_dict['state'].keys()) + for param_ind, (param_fqn, _) in zip(param_inds, model.named_parameters()): + if any([fnmatch.fnmatch(param_fqn, ignore_key) for ignore_key in ignore_keys]): + optim_state_dict['state'].pop(param_ind) + + return optim_state_dict + +def _extract_keys_from_optim_state_dict(optim_state_dict: Dict[str, Any], model: Union[ComposerModel, nn.Module], include_keys: Union[str, Sequence[str]]): + if isinstance(include_keys, str): + include_keys = [include_keys] + param_inds = list(optim_state_dict['state'].keys()) + # See comment in _remove_keys_from_optim_state_dict. + for param_ind, (param_fqn, _) in zip(param_inds, model.named_parameters()): + if not any([fnmatch.fnmatch(param_fqn, include_key) for include_key in include_keys]): + optim_state_dict['state'].pop(param_ind) + + return optim_state_dict \ No newline at end of file diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 8b40c83bcc..73abfaa204 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -8,12 +8,13 @@ from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from composer.checkpoint import get_model_state_dict +from composer.checkpoint import get_model_state_dict, get_optim_state_dict from composer.utils import dist from tests.common.compare import deep_compare from tests.common.markers import world_size from tests.common.models import EvenSimplerMLP, SimpleComposerMLP - +from torch.optim import adam +import fnmatch @pytest.mark.gpu @pytest.mark.parametrize('use_composer_model', [True, False]) @@ -230,3 +231,128 @@ def test_get_model_state_dict_precision_unsharded_model(precision: str, use_comp ) for tens in model_state_dict.values(): assert tens.dtype == precision + +def _init_model_and_optimizer(use_composer_model: bool, num_classes=3, batch_size = 5, num_features = 8, take_step=True): + + if use_composer_model: + model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device='cuda') + loss_fn = model._loss_fn + else: + model = EvenSimplerMLP(num_features=num_features, + num_out_features=num_classes, + device='cuda') + loss_fn = torch.nn.CrossEntropyLoss() + + inputs = torch.randn(batch_size, num_features, device='cuda') + targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device='cuda', dtype=torch.long) + batch = (inputs, targets) if use_composer_model else inputs + outputs = model(batch) + optimizer = 
adam.Adam(model.parameters()) + loss = loss_fn(outputs, targets) + loss.backward() + if take_step: + optimizer.step() + return model, optimizer + + +@pytest.mark.gpu +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_state_dict_unsharded_model(use_composer_model: bool): + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=False) + + # Before ever taking a step it should be empty. + optim_state_dict = get_optim_state_dict(model, optimizer) + assert optim_state_dict['state'] == optimizer.state == {} + + optimizer.step() + optim_state_dict = get_optim_state_dict(model, optimizer) + + # Dict mapping parameter index to optimizer state for that parameter. + osd_state = optim_state_dict['state'] + # Dict mapping parameter itself to optimizer state for that parameter. + optim_state = optimizer.state + + # Make sure optimizer state is the same between the state dict and the optimizer object. + for osd_param_state, opt_param_state in zip(osd_state.values(), optim_state.values()): + deep_compare(osd_param_state, opt_param_state) + + # Make sure the optimizer state in the state dict is the same shape as the parameter it corresponds to. + params = list(model.parameters()) + for param_ind, param_state in osd_state.items(): + param = params[param_ind] + assert param.shape == param_state['exp_avg'].shape + assert param.shape == param_state['exp_avg_sq'].shape + + # Make sure param groups between the state dict and the optimizer object are the same. + for osd_group, opt_group in zip(optim_state_dict['param_groups'], optimizer.param_groups): + # Only params should differ between the two. + # * in the optimizer state dict params will be indices into the model's parameters list. + # * in the optimizer object params will be the actual parameter tensors. 
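
For reference, the index-versus-tensor distinction described in the comment above can be reproduced on any plain optimizer; a minimal standalone sketch (not part of this patch), using a throwaway two-layer model:

import torch

net = torch.nn.Sequential(torch.nn.Linear(8, 8, bias=False), torch.nn.Linear(8, 3, bias=False))
opt = torch.optim.Adam(net.parameters())

sd = opt.state_dict()
print(sd['param_groups'][0]['params'])                    # integer indices, e.g. [0, 1]
print([p.shape for p in opt.param_groups[0]['params']])   # the live parameter tensors (their shapes here)

This is why the deep_compare call below ignores the 'params' key when comparing param groups.
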
+ deep_compare(osd_group, opt_group, ignore_keys=['params']) + + +@pytest.mark.gpu +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_state_dict_include(use_composer_model: bool): + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, + take_step=True) + fqns = [param_fqn for param_fqn, _ in model.named_parameters()] + include_keys=['module.0.weight'] + optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) + expected_optim_state_keys = [] + for fqn in fqns: + if any([fnmatch.fnmatch(fqn, include_key) for include_key in include_keys]): + expected_optim_state_keys.append(fqns.index(fqn)) + assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) + + include_keys=['module.2*'] + optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) + expected_optim_state_keys = [] + for fqn in fqns: + if any([fnmatch.fnmatch(fqn, include_key) for include_key in include_keys]): + expected_optim_state_keys.append(fqns.index(fqn)) + assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) + + +@pytest.mark.gpu +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_state_dict_ignore(use_composer_model: bool): + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, + take_step=True) + fqns = [param_fqn for param_fqn, _ in model.named_parameters()] + ignore_keys=['module.0*'] + optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) + expected_optim_state_keys = [] + for fqn in fqns: + if not any([fnmatch.fnmatch(fqn, ignore_key) for ignore_key in ignore_keys]): + expected_optim_state_keys.append(fqns.index(fqn)) + assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) + + ignore_keys=['module.2.weight'] + optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) + expected_optim_state_keys = [] + for fqn in fqns: + if not any([fnmatch.fnmatch(fqn, ignore_key) for ignore_key in ignore_keys]): + expected_optim_state_keys.append(fqns.index(fqn)) + assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) + + +@pytest.mark.gpu +@pytest.mark.parametrize( + 'precision', + [ + torch.float32, + torch.float16, + torch.bfloat16, + ], +) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_model_state_dict_precision_unsharded_model(precision: str, use_composer_model: bool): + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, + take_step=True) + optim_state_dict = get_optim_state_dict(model, optimizer, precision=precision) + for param_state in optim_state_dict['state'].values(): + assert param_state['exp_avg'].dtype == precision + assert param_state['exp_avg_sq'].dtype == precision + + diff --git a/tests/common/compare.py b/tests/common/compare.py index 942fa67504..72ac496d32 100644 --- a/tests/common/compare.py +++ b/tests/common/compare.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import datetime -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple, Union, Optional import numpy as np import torch @@ -12,7 +12,7 @@ from composer.core.time import TimeUnit -def deep_compare(item1: Any, item2: Any, atol: float = 0.0, rtol: float = 0.0): +def deep_compare(item1: Any, item2: Any, atol: float = 0.0, rtol: float = 0.0, ignore_keys: Optional[List[str]] = None): """Compare two items recursively. 
Supports dicts, lists, tuples, tensors, numpy arrays, Composer Time objects, and callables. Args: @@ -21,10 +21,10 @@ def deep_compare(item1: Any, item2: Any, atol: float = 0.0, rtol: float = 0.0): atol (bool): Atol tolerance for torch tensors and numpy arrays (default: 0.0) rtol (float): Rtol tolerance for torch tensors and numpy arrays (default: 0.0) """ - return _check_item(item1, item2, path='', atol=atol, rtol=rtol) + return _check_item(item1, item2, path='', atol=atol, rtol=rtol, ignore_keys=ignore_keys) -def _check_item(item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: float = 0.0): +def _check_item(item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: float = 0.0, ignore_keys: Optional[List[str]] = None): if item1 is None: assert item2 is None, f'{path} differs: {item1} != {item2}' return @@ -45,7 +45,7 @@ def _check_item(item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: floa return if isinstance(item1, dict): assert isinstance(item2, dict), f'{path} differs: {item1} != {item2}' - _check_dict_recursively(item1, item2, path, atol=atol, rtol=rtol) + _check_dict_recursively(item1, item2, path, atol=atol, rtol=rtol, ignore_keys=ignore_keys) return if isinstance(item1, (tuple, list)): assert isinstance(item2, type(item1)), f'{path} differs: {item1} != {item2}' @@ -89,9 +89,11 @@ def _check_list_recursively( _check_item(item1, item2, path=f'{path}/{i}', atol=atol, rtol=rtol) -def _check_dict_recursively(dict1: Dict[str, Any], dict2: Dict[str, Any], path: str, atol: float, rtol: float): +def _check_dict_recursively(dict1: Dict[str, Any], dict2: Dict[str, Any], path: str, atol: float, rtol: float, ignore_keys: Optional[List[str]] = None): assert len(dict1) == len(dict2), f'{path} differs: {dict1} != {dict2}' for k, val1 in dict1.items(): + if ignore_keys is not None and k in ignore_keys: + continue val2 = dict2[k] # special case fused optimizer to allow comparing a GPU checkpoint with a CPU checkpoint diff --git a/tests/common/models.py b/tests/common/models.py index ea3546ad1a..b67f0e41df 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -117,10 +117,10 @@ def forward(self, x): # are not submodules of EvenSimplerMLP, like they are in SimpleMLP. 
class EvenSimplerMLP(torch.nn.Module): - def __init__(self, num_features: int, device: str): + def __init__(self, num_features: int, device: str, num_out_features: int=3): super().__init__() fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + fc2 = torch.nn.Linear(num_features, num_out_features, device=device, bias=False) self.module = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) @@ -134,7 +134,7 @@ class SimpleComposerMLP(ComposerClassifier): def __init__(self, num_features: int, device: str, num_classes: int = 3): fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) - fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + fc2 = torch.nn.Linear(num_features, num_classes, device=device, bias=False) net = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2) super().__init__(num_classes=num_classes, module=net) From b485cf441e94b9eb9d9bfac7af554980f5b20e6f Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Wed, 22 May 2024 15:50:22 -0700 Subject: [PATCH 03/15] pre-commit --- composer/checkpoint/__init__.py | 5 +- composer/checkpoint/state_dict.py | 76 ++++++++++++++++------------- tests/checkpoint/test_state_dict.py | 45 ++++++++--------- tests/common/compare.py | 15 ++++-- tests/common/models.py | 2 +- 5 files changed, 76 insertions(+), 67 deletions(-) diff --git a/composer/checkpoint/__init__.py b/composer/checkpoint/__init__.py index fb8e8fc19d..509b4f8cea 100644 --- a/composer/checkpoint/__init__.py +++ b/composer/checkpoint/__init__.py @@ -3,10 +3,9 @@ """Module for checkpointing API.""" -from composer.checkpoint.state_dict import (get_model_state_dict, - get_optim_state_dict) +from composer.checkpoint.state_dict import get_model_state_dict, get_optim_state_dict __all__ = [ 'get_model_state_dict', - 'get_optim_state_dict' + 'get_optim_state_dict', ] diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 1b8289f710..7f98b69abb 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -155,10 +155,10 @@ def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_st model_state_dict = model.state_dict() return model_state_dict -def _get_optim_state_dict_with_fsdp_context_manager(model: nn.Module, - optimizer: torch.optim.Optimizer, - sharded_state_dict: bool, - cpu_offload: bool) -> Dict[str, Any]: + +def _get_optim_state_dict_with_fsdp_context_manager( + model: nn.Module, optimizer: torch.optim.Optimizer, sharded_state_dict: bool, cpu_offload: bool +) -> Dict[str, Any]: """Get the optimizer state dict with the FSDP context manager. Args: @@ -167,17 +167,17 @@ def _get_optim_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_state_dict: Whether the optimizer state dict should be sharded or not. If True, every rank returns the state dict of its shards. If False, then rank 0 returns the state dict of the entire optimizer. cpu_offload: Whether to offload the state dict to CPU. - + Returns: The state dict of the optimizer. 
""" from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, FullStateDictConfig, + ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType, - FullOptimStateDictConfig, - ShardedOptimStateDictConfig, ) state_dict_type = StateDictType.SHARDED_STATE_DICT if sharded_state_dict else StateDictType.FULL_STATE_DICT @@ -186,26 +186,28 @@ def _get_optim_state_dict_with_fsdp_context_manager(model: nn.Module, rank0_only=True, offload_to_cpu=cpu_offload, ) - optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=cpu_offload) if sharded_state_dict else FullOptimStateDictConfig( - rank0_only=True, - offload_to_cpu=cpu_offload) - with FSDP.state_dict_type(model, - state_dict_type=state_dict_type, - state_dict_config=state_dict_config, - optim_state_dict_config=optim_state_dict_config): + optim_state_dict_config = ShardedOptimStateDictConfig( + offload_to_cpu=cpu_offload + ) if sharded_state_dict else FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=cpu_offload) + with FSDP.state_dict_type( + model, + state_dict_type=state_dict_type, + state_dict_config=state_dict_config, + optim_state_dict_config=optim_state_dict_config + ): optim_state_dict = FSDP.optim_state_dict(model, optimizer) - return optim_state_dict + return optim_state_dict def get_optim_state_dict( - model: Union[ComposerModel, nn.Module], - optimizer: torch.optim.Optimizer, - sharded_state_dict: bool = False, - precision: str = 'fp32', - include_keys: Optional[Union[str, Sequence[str]]] = None, - ignore_keys: Optional[Union[str, Sequence[str]]] = None, - cpu_offload: Optional[bool] = None, - ) -> Dict[str, Any]: + model: Union[ComposerModel, nn.Module], + optimizer: torch.optim.Optimizer, + sharded_state_dict: bool = False, + precision: str = 'fp32', + include_keys: Optional[Union[str, Sequence[str]]] = None, + ignore_keys: Optional[Union[str, Sequence[str]]] = None, + cpu_offload: Optional[bool] = None, +) -> Dict[str, Any]: """Generate the state dict of the optimizer. Args: @@ -217,7 +219,7 @@ def get_optim_state_dict( include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None. ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None. cpu_offload: Whether to offload the state dict to CPU. If None, it is set to True if FSDP is enabled with non-sharded state dict and False otherwise. - + Returns: The state dict of the optimizer. 
""" @@ -227,7 +229,7 @@ def get_optim_state_dict( is_fsdp = _is_model_fsdp(model) if not is_fsdp and sharded_state_dict: raise ValueError('Sharded optim state dict can only be generated for FSDP models.') - + cpu_offload = cpu_offload if cpu_offload is not None else (is_fsdp and not sharded_state_dict) log.debug('Extracting optim state dict') if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized(): @@ -245,13 +247,12 @@ def get_optim_state_dict( else: if is_fsdp: log.debug('Calling legacy FSDP context manager to get optim state dict...') - optim_state_dict = _get_optim_state_dict_with_fsdp_context_manager(model, - optimizer, - sharded_state_dict, - cpu_offload) + optim_state_dict = _get_optim_state_dict_with_fsdp_context_manager( + model, optimizer, sharded_state_dict, cpu_offload + ) else: optim_state_dict = optimizer.state_dict() - + if ignore_keys is not None: optim_state_dict = _remove_keys_from_optim_state_dict(optim_state_dict, model, ignore_keys) if include_keys is not None: @@ -260,9 +261,11 @@ def get_optim_state_dict( for param_ind, param_state_dict in optim_state_dict['state'].items(): optim_state_dict['state'][param_ind] = _cast_state_dict_to_precision(param_state_dict, precision) return optim_state_dict - - -def _remove_keys_from_optim_state_dict(optim_state_dict: Dict[str, Any], model: Union[ComposerModel, nn.Module], ignore_keys: Union[str, Sequence[str]]): + + +def _remove_keys_from_optim_state_dict( + optim_state_dict: Dict[str, Any], model: Union[ComposerModel, nn.Module], ignore_keys: Union[str, Sequence[str]] +): if isinstance(ignore_keys, str): ignore_keys = [ignore_keys] @@ -278,7 +281,10 @@ def _remove_keys_from_optim_state_dict(optim_state_dict: Dict[str, Any], model: return optim_state_dict -def _extract_keys_from_optim_state_dict(optim_state_dict: Dict[str, Any], model: Union[ComposerModel, nn.Module], include_keys: Union[str, Sequence[str]]): + +def _extract_keys_from_optim_state_dict( + optim_state_dict: Dict[str, Any], model: Union[ComposerModel, nn.Module], include_keys: Union[str, Sequence[str]] +): if isinstance(include_keys, str): include_keys = [include_keys] param_inds = list(optim_state_dict['state'].keys()) @@ -287,4 +293,4 @@ def _extract_keys_from_optim_state_dict(optim_state_dict: Dict[str, Any], model: if not any([fnmatch.fnmatch(param_fqn, include_key) for include_key in include_keys]): optim_state_dict['state'].pop(param_ind) - return optim_state_dict \ No newline at end of file + return optim_state_dict diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 73abfaa204..44ca051cea 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -1,20 +1,21 @@ # Copyright 2024 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import fnmatch from typing import Any, Dict import pytest import torch from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.optim import adam from composer.checkpoint import get_model_state_dict, get_optim_state_dict from composer.utils import dist from tests.common.compare import deep_compare from tests.common.markers import world_size from tests.common.models import EvenSimplerMLP, SimpleComposerMLP -from torch.optim import adam -import fnmatch + @pytest.mark.gpu @pytest.mark.parametrize('use_composer_model', [True, False]) @@ -232,15 +233,14 @@ def test_get_model_state_dict_precision_unsharded_model(precision: str, use_comp for tens in 
model_state_dict.values(): assert tens.dtype == precision -def _init_model_and_optimizer(use_composer_model: bool, num_classes=3, batch_size = 5, num_features = 8, take_step=True): - + +def _init_model_and_optimizer(use_composer_model: bool, num_classes=3, batch_size=5, num_features=8, take_step=True): + if use_composer_model: model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device='cuda') loss_fn = model._loss_fn else: - model = EvenSimplerMLP(num_features=num_features, - num_out_features=num_classes, - device='cuda') + model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device='cuda') loss_fn = torch.nn.CrossEntropyLoss() inputs = torch.randn(batch_size, num_features, device='cuda') @@ -259,14 +259,14 @@ def _init_model_and_optimizer(use_composer_model: bool, num_classes=3, batch_siz @pytest.mark.parametrize('use_composer_model', [True, False]) def test_get_optim_state_dict_unsharded_model(use_composer_model: bool): model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=False) - + # Before ever taking a step it should be empty. optim_state_dict = get_optim_state_dict(model, optimizer) assert optim_state_dict['state'] == optimizer.state == {} optimizer.step() optim_state_dict = get_optim_state_dict(model, optimizer) - + # Dict mapping parameter index to optimizer state for that parameter. osd_state = optim_state_dict['state'] # Dict mapping parameter itself to optimizer state for that parameter. @@ -289,23 +289,22 @@ def test_get_optim_state_dict_unsharded_model(use_composer_model: bool): # * in the optimizer state dict params will be indices into the model's parameters list. # * in the optimizer object params will be the actual parameter tensors. deep_compare(osd_group, opt_group, ignore_keys=['params']) - - + + @pytest.mark.gpu @pytest.mark.parametrize('use_composer_model', [True, False]) def test_get_optim_state_dict_include(use_composer_model: bool): - model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, - take_step=True) + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) fqns = [param_fqn for param_fqn, _ in model.named_parameters()] - include_keys=['module.0.weight'] + include_keys = ['module.0.weight'] optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) expected_optim_state_keys = [] for fqn in fqns: if any([fnmatch.fnmatch(fqn, include_key) for include_key in include_keys]): expected_optim_state_keys.append(fqns.index(fqn)) assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) - - include_keys=['module.2*'] + + include_keys = ['module.2*'] optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) expected_optim_state_keys = [] for fqn in fqns: @@ -317,18 +316,17 @@ def test_get_optim_state_dict_include(use_composer_model: bool): @pytest.mark.gpu @pytest.mark.parametrize('use_composer_model', [True, False]) def test_get_optim_state_dict_ignore(use_composer_model: bool): - model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, - take_step=True) + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) fqns = [param_fqn for param_fqn, _ in model.named_parameters()] - ignore_keys=['module.0*'] + ignore_keys = ['module.0*'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) expected_optim_state_keys = [] for fqn in fqns: if 
not any([fnmatch.fnmatch(fqn, ignore_key) for ignore_key in ignore_keys]): expected_optim_state_keys.append(fqns.index(fqn)) assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) - - ignore_keys=['module.2.weight'] + + ignore_keys = ['module.2.weight'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) expected_optim_state_keys = [] for fqn in fqns: @@ -348,11 +346,8 @@ def test_get_optim_state_dict_ignore(use_composer_model: bool): ) @pytest.mark.parametrize('use_composer_model', [True, False]) def test_get_model_state_dict_precision_unsharded_model(precision: str, use_composer_model: bool): - model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, - take_step=True) + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) optim_state_dict = get_optim_state_dict(model, optimizer, precision=precision) for param_state in optim_state_dict['state'].values(): assert param_state['exp_avg'].dtype == precision assert param_state['exp_avg_sq'].dtype == precision - - diff --git a/tests/common/compare.py b/tests/common/compare.py index 72ac496d32..c9967ff6b5 100644 --- a/tests/common/compare.py +++ b/tests/common/compare.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import datetime -from typing import Any, Dict, List, Tuple, Union, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -24,7 +24,9 @@ def deep_compare(item1: Any, item2: Any, atol: float = 0.0, rtol: float = 0.0, i return _check_item(item1, item2, path='', atol=atol, rtol=rtol, ignore_keys=ignore_keys) -def _check_item(item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: float = 0.0, ignore_keys: Optional[List[str]] = None): +def _check_item( + item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: float = 0.0, ignore_keys: Optional[List[str]] = None +): if item1 is None: assert item2 is None, f'{path} differs: {item1} != {item2}' return @@ -89,7 +91,14 @@ def _check_list_recursively( _check_item(item1, item2, path=f'{path}/{i}', atol=atol, rtol=rtol) -def _check_dict_recursively(dict1: Dict[str, Any], dict2: Dict[str, Any], path: str, atol: float, rtol: float, ignore_keys: Optional[List[str]] = None): +def _check_dict_recursively( + dict1: Dict[str, Any], + dict2: Dict[str, Any], + path: str, + atol: float, + rtol: float, + ignore_keys: Optional[List[str]] = None +): assert len(dict1) == len(dict2), f'{path} differs: {dict1} != {dict2}' for k, val1 in dict1.items(): if ignore_keys is not None and k in ignore_keys: diff --git a/tests/common/models.py b/tests/common/models.py index b67f0e41df..779b18d8b9 100644 --- a/tests/common/models.py +++ b/tests/common/models.py @@ -117,7 +117,7 @@ def forward(self, x): # are not submodules of EvenSimplerMLP, like they are in SimpleMLP. 
class EvenSimplerMLP(torch.nn.Module): - def __init__(self, num_features: int, device: str, num_out_features: int=3): + def __init__(self, num_features: int, device: str, num_out_features: int = 3): super().__init__() fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) fc2 = torch.nn.Linear(num_features, num_out_features, device=device, bias=False) From 0f683b91e33ae2f612bf392ca268b1322b3d7c49 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Wed, 22 May 2024 16:01:03 -0700 Subject: [PATCH 04/15] pre-commit --- composer/checkpoint/state_dict.py | 26 ++++++++++++++++++-------- tests/checkpoint/test_state_dict.py | 2 +- tests/common/compare.py | 9 +++++++-- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 7f98b69abb..a1a6811969 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -89,7 +89,7 @@ def get_model_state_dict( return model_state_dict -def _cast_state_dict_to_precision(state_dict: Dict[str, Any], precision: Union[str, torch.dtype]): +def _cast_state_dict_to_precision(state_dict: Dict[str, Any], precision: Union[str, torch.dtype]) -> Dict[str, Any]: if isinstance(precision, str): precision = STR_TO_DTYPE[precision] @@ -157,7 +157,10 @@ def _get_model_state_dict_with_fsdp_context_manager(model: nn.Module, sharded_st def _get_optim_state_dict_with_fsdp_context_manager( - model: nn.Module, optimizer: torch.optim.Optimizer, sharded_state_dict: bool, cpu_offload: bool + model: nn.Module, + optimizer: torch.optim.Optimizer, + sharded_state_dict: bool, + cpu_offload: bool, ) -> Dict[str, Any]: """Get the optimizer state dict with the FSDP context manager. @@ -187,13 +190,13 @@ def _get_optim_state_dict_with_fsdp_context_manager( offload_to_cpu=cpu_offload, ) optim_state_dict_config = ShardedOptimStateDictConfig( - offload_to_cpu=cpu_offload + offload_to_cpu=cpu_offload, ) if sharded_state_dict else FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=cpu_offload) with FSDP.state_dict_type( model, state_dict_type=state_dict_type, state_dict_config=state_dict_config, - optim_state_dict_config=optim_state_dict_config + optim_state_dict_config=optim_state_dict_config, ): optim_state_dict = FSDP.optim_state_dict(model, optimizer) return optim_state_dict @@ -235,7 +238,7 @@ def get_optim_state_dict( if version.parse(torch.__version__) >= version.parse('2.2.0') and dist.is_initialized(): from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict log.debug('Calling torch get_optimizer_state_dict...') - optim_state_dict = get_optimizer_state_dict( + optim_state_dict: Dict[str, Any] = get_optimizer_state_dict( model=model, optimizers=optimizer, submodules=None, # We extract submodules below @@ -248,7 +251,10 @@ def get_optim_state_dict( if is_fsdp: log.debug('Calling legacy FSDP context manager to get optim state dict...') optim_state_dict = _get_optim_state_dict_with_fsdp_context_manager( - model, optimizer, sharded_state_dict, cpu_offload + model, + optimizer, + sharded_state_dict, + cpu_offload, ) else: optim_state_dict = optimizer.state_dict() @@ -264,7 +270,9 @@ def get_optim_state_dict( def _remove_keys_from_optim_state_dict( - optim_state_dict: Dict[str, Any], model: Union[ComposerModel, nn.Module], ignore_keys: Union[str, Sequence[str]] + optim_state_dict: Dict[str, Any], + model: Union[ComposerModel, nn.Module], + ignore_keys: Union[str, Sequence[str]], ): if isinstance(ignore_keys, str): ignore_keys = 
[ignore_keys] @@ -283,7 +291,9 @@ def _remove_keys_from_optim_state_dict( def _extract_keys_from_optim_state_dict( - optim_state_dict: Dict[str, Any], model: Union[ComposerModel, nn.Module], include_keys: Union[str, Sequence[str]] + optim_state_dict: Dict[str, Any], + model: Union[ComposerModel, nn.Module], + include_keys: Union[str, Sequence[str]], ): if isinstance(include_keys, str): include_keys = [include_keys] diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 44ca051cea..5ed6e29e2a 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -345,7 +345,7 @@ def test_get_optim_state_dict_ignore(use_composer_model: bool): ], ) @pytest.mark.parametrize('use_composer_model', [True, False]) -def test_get_model_state_dict_precision_unsharded_model(precision: str, use_composer_model: bool): +def test_get_optim_state_dict_precision_unsharded_model(precision: str, use_composer_model: bool): model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) optim_state_dict = get_optim_state_dict(model, optimizer, precision=precision) for param_state in optim_state_dict['state'].values(): diff --git a/tests/common/compare.py b/tests/common/compare.py index c9967ff6b5..b97e2345b4 100644 --- a/tests/common/compare.py +++ b/tests/common/compare.py @@ -25,7 +25,12 @@ def deep_compare(item1: Any, item2: Any, atol: float = 0.0, rtol: float = 0.0, i def _check_item( - item1: Any, item2: Any, path: str, rtol: float = 0.0, atol: float = 0.0, ignore_keys: Optional[List[str]] = None + item1: Any, + item2: Any, + path: str, + rtol: float = 0.0, + atol: float = 0.0, + ignore_keys: Optional[List[str]] = None, ): if item1 is None: assert item2 is None, f'{path} differs: {item1} != {item2}' @@ -97,7 +102,7 @@ def _check_dict_recursively( path: str, atol: float, rtol: float, - ignore_keys: Optional[List[str]] = None + ignore_keys: Optional[List[str]] = None, ): assert len(dict1) == len(dict2), f'{path} differs: {dict1} != {dict2}' for k, val1 in dict1.items(): From 64cb443af50d2fb678dd3b5327cad67a9c8fa736 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Fri, 24 May 2024 00:07:32 +0000 Subject: [PATCH 05/15] unit tests WIP --- composer/checkpoint/state_dict.py | 45 +++++++++++------ tests/checkpoint/test_state_dict.py | 76 +++++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 18 deletions(-) diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index a1a6811969..d8d2aa99cd 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -217,7 +217,7 @@ def get_optim_state_dict( model: The model containing the parameters that the optimizer is optimizing. optimizer: The optimizer to get the state dict from. sharded: Whether the optimizer is sharded or not. If True, every rank returns the state dict of its shards. - If False, then rank 0 returns the state dict of the entire optimizer. + If False, then rank 0 returns the state dict of the entire optimizer and all other ranks return an empty dict. precision: The precision of the optimizer. include_keys: The list of keys to exclusively include in the state dict. If None, all keys are included. Both include_keys and ignore_keys cannot be non-None. ignore_keys: The list of keys to ignore in the state dict. If None, no keys are ignored. Both include_keys and ignore_keys cannot be non-None. 
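
For usage context, include_keys and ignore_keys are shell-style glob patterns matched against parameter names, and at most one of the two may be passed; a hedged sketch of the intended call pattern on a small unsharded model (assumes a composer install that includes this PR and no distributed setup):

import torch
from composer.checkpoint import get_optim_state_dict

model = torch.nn.Sequential(torch.nn.Linear(8, 8, bias=False), torch.nn.ReLU(), torch.nn.Linear(8, 3, bias=False))
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(4, 8)).sum().backward()
optimizer.step()

# Keep only the optimizer state for the first layer's weight ('0.weight' is its FQN here).
only_first = get_optim_state_dict(model, optimizer, include_keys=['0.weight'])
# Or drop everything under the last layer instead; glob patterns are allowed.
without_last = get_optim_state_dict(model, optimizer, ignore_keys=['2*'])
# Passing both include_keys and ignore_keys raises a ValueError.
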
@@ -259,13 +259,19 @@ def get_optim_state_dict( else: optim_state_dict = optimizer.state_dict() - if ignore_keys is not None: - optim_state_dict = _remove_keys_from_optim_state_dict(optim_state_dict, model, ignore_keys) - if include_keys is not None: - optim_state_dict = _extract_keys_from_optim_state_dict(optim_state_dict, model, include_keys) - - for param_ind, param_state_dict in optim_state_dict['state'].items(): - optim_state_dict['state'][param_ind] = _cast_state_dict_to_precision(param_state_dict, precision) + # For sharded models with non-sharded state dicts, only rank 0 has the full state dict including all the keys + target_state_dict_on_this_rank = (not sharded_state_dict and dist.get_global_rank() == 0) or sharded_state_dict + + if target_state_dict_on_this_rank: + if ignore_keys is not None: + optim_state_dict = _remove_keys_from_optim_state_dict(optim_state_dict, model, ignore_keys) + if include_keys is not None: + optim_state_dict = _extract_keys_from_optim_state_dict(optim_state_dict, model, include_keys) + + # param_key := index (0,1,2,..., len(model.parameters())-1) for unsharded models. + # param_key := fqn for sharded models. + for param_key, param_state_dict in optim_state_dict['state'].items(): + optim_state_dict['state'][param_key] = _cast_state_dict_to_precision(param_state_dict, precision) return optim_state_dict @@ -277,15 +283,24 @@ def _remove_keys_from_optim_state_dict( if isinstance(ignore_keys, str): ignore_keys = [ignore_keys] - # optim_state_dict['state'] is a dictionary mapping the param_ind (0,1,2,..., len(model.parameters())-1) - # to the optimizer state for that parameter e.g. 'step', 'exp_avg', 'exp_avg_sq'. - # The param_ind ordering is determined by passing model.parameters() + # optim_state_dict['state'] is a dictionary mapping the param_key + # to the optimizer state ( e.g. 'step', 'exp_avg', 'exp_avg_sq') for that parameter. + # For sharded models the param_key is just the fqn for the underlying model parameter, + # but for unsharded models the param_key is an index (0,1,2,..., len(model.parameters())-1) + if _is_model_fsdp(model): + for param_fqn in optim_state_dict['state'].keys(): + if any([fnmatch.fnmatch(param_fqn, ignore_key) for ignore_key in ignore_keys]): + optim_state_dict['state'].pop(param_fqn) + + # The param index ordering is determined by passing model.parameters() # to the optimizer. The underlying generator for model.parameters() is model.named_parameters() # so we need to use model.named_parameters() instead of model.state_dict().keys() to match fqn to ind correctly. 
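
The ordering contract in the comment above is easy to see in isolation; a small sketch (independent of this module) showing that enumerating named_parameters() reproduces the integer keys an unsharded optimizer state dict uses:

import torch

net = torch.nn.Sequential(torch.nn.Linear(8, 8, bias=False), torch.nn.ReLU(), torch.nn.Linear(8, 3, bias=False))
opt = torch.optim.Adam(net.parameters())

# Index -> fully qualified name, in the same order the optimizer saw the parameters.
ind_to_fqn = {ind: fqn for ind, (fqn, _) in enumerate(net.named_parameters())}
print(ind_to_fqn)  # {0: '0.weight', 1: '2.weight'}

After a step, these indices are exactly the keys of opt.state_dict()['state'], which is what lets the loops below translate glob matches on FQNs into indices to pop.
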
- param_inds = list(optim_state_dict['state'].keys()) - for param_ind, (param_fqn, _) in zip(param_inds, model.named_parameters()): - if any([fnmatch.fnmatch(param_fqn, ignore_key) for ignore_key in ignore_keys]): - optim_state_dict['state'].pop(param_ind) + else: + param_inds = list(optim_state_dict['state'].keys()) + for param_ind, (param_fqn, _) in zip(param_inds, model.named_parameters()): + if any([fnmatch.fnmatch(param_fqn, ignore_key) for ignore_key in ignore_keys]): + optim_state_dict['state'].pop(param_ind) + return optim_state_dict diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 5ed6e29e2a..e52a173383 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -15,6 +15,7 @@ from tests.common.compare import deep_compare from tests.common.markers import world_size from tests.common.models import EvenSimplerMLP, SimpleComposerMLP +from composer.models import ComposerModel @pytest.mark.gpu @@ -234,8 +235,18 @@ def test_get_model_state_dict_precision_unsharded_model(precision: str, use_comp assert tens.dtype == precision -def _init_model_and_optimizer(use_composer_model: bool, num_classes=3, batch_size=5, num_features=8, take_step=True): +def _init_model_and_optimizer(use_composer_model: bool, num_classes=3, batch_size=5, num_features=8, take_step=True, + use_fsdp=False, tensor_type='sharded_tensor'): + model, loss_fn = _init_model(use_composer_model, num_classes=num_classes, + batch_size=batch_size, num_features=num_features, + use_fsdp=use_fsdp, tensor_type=tensor_type) + optimizer = _init_optimizer(model, loss_fn, use_composer_model=use_composer_model, num_classes=num_classes, batch_size=batch_size, num_features=num_features, take_step=take_step) + + + return model, optimizer + +def _init_model(use_composer_model: bool=False, num_classes=3, batch_size=5, num_features=8, use_fsdp=False, tensor_type='sharded_tensor'): if use_composer_model: model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device='cuda') loss_fn = model._loss_fn @@ -243,16 +254,36 @@ def _init_model_and_optimizer(use_composer_model: bool, num_classes=3, batch_siz model = EvenSimplerMLP(num_features=num_features, num_out_features=num_classes, device='cuda') loss_fn = torch.nn.CrossEntropyLoss() + if use_fsdp: + fsdp_kwargs: Dict[str, Any] = dict( + use_orig_params=True, + sync_module_states=True, # To enable easy comparison between rank 0 unsharded model and full state dict + ) + + if tensor_type == 'dtensor': + from torch.distributed.device_mesh import init_device_mesh + device_mesh = init_device_mesh('cuda', (2,)) + fsdp_kwargs['device_mesh'] = device_mesh + + model = FSDP( + model, + **fsdp_kwargs, + ) + + return model, loss_fn + +def _init_optimizer(model, loss_fn, use_composer_model: bool=False, num_classes=3, + batch_size=5, num_features=8, take_step=True): inputs = torch.randn(batch_size, num_features, device='cuda') targets = torch.randint(low=0, high=num_classes, size=(batch_size,), device='cuda', dtype=torch.long) batch = (inputs, targets) if use_composer_model else inputs - outputs = model(batch) optimizer = adam.Adam(model.parameters()) + outputs = model(batch) loss = loss_fn(outputs, targets) loss.backward() if take_step: optimizer.step() - return model, optimizer + return optimizer @pytest.mark.gpu @@ -277,6 +308,8 @@ def test_get_optim_state_dict_unsharded_model(use_composer_model: bool): deep_compare(osd_param_state, opt_param_state) # Make sure the optimizer state in the state dict is the same 
shape as the parameter it corresponds to. + # Because model is unsharded the optimizer state should have keys corresponding to the index of the model's parameters. + # e.g. if the model has 3 parameters, the optimizer state dict keys would be (0,1,2). params = list(model.parameters()) for param_ind, param_state in osd_state.items(): param = params[param_ind] @@ -351,3 +384,40 @@ def test_get_optim_state_dict_precision_unsharded_model(precision: str, use_comp for param_state in optim_state_dict['state'].values(): assert param_state['exp_avg'].dtype == precision assert param_state['exp_avg_sq'].dtype == precision + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor']) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_dict_full_for_sharded_model(world_size, tensor_type, use_composer_model: bool): + if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): + pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') + + + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True, use_fsdp=True, tensor_type=tensor_type) + optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=False) + + + with FSDP.summon_full_params(model): + # Make sure the optimizer state in the state dict is the same shape as the parameter it corresponds to. + fqn_to_shape_map = {fqn: param.shape for fqn, param in model.named_parameters()} + if dist.get_global_rank() == 0: + # Because model is sharded, the state dict should have the same keys as the model's parameters. + for fqn, param_state in optim_state_dict['state'].items(): + model_param_shape = fqn_to_shape_map[fqn] + assert model_param_shape == param_state['exp_avg'].shape + assert model_param_shape == param_state['exp_avg_sq'].shape + + +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('tensor_type', ['sharded_tensor', 'dtensor']) +@pytest.mark.parametrize('use_composer_model', [True, False]) +def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_composer_model: bool): + if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): + pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') + + + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True, use_fsdp=True, tensor_type=tensor_type) + optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=True) \ No newline at end of file From 02d921c978989fc9807115a3dc0f580b085bbb47 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Wed, 29 May 2024 22:10:30 +0000 Subject: [PATCH 06/15] add test for sharded state dict --- tests/checkpoint/test_state_dict.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index e52a173383..dc9bedcb50 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -420,4 +420,12 @@ def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_c model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True, use_fsdp=True, tensor_type=tensor_type) - optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=True) \ No newline at end of file + model_state_dict = get_model_state_dict(model, sharded_state_dict=True) + optim_state_dict = get_optim_state_dict(model, optimizer, 
sharded_state_dict=True) + + # Check to make sure on every rank optimizer state name and shape matches model's + fqn_to_shape_map = {fqn: param.shape for fqn, param in model_state_dict.items()} + for fqn, param_state in optim_state_dict['state'].items(): + model_param_shape = fqn_to_shape_map[fqn] + assert model_param_shape == param_state['exp_avg'].shape + assert model_param_shape == param_state['exp_avg_sq'].shape \ No newline at end of file From 2d7fcd0e955ff9b98608ac39631b7f06db3449a3 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Wed, 29 May 2024 15:55:52 -0700 Subject: [PATCH 07/15] pre-commit --- composer/checkpoint/state_dict.py | 21 +++-- tests/checkpoint/test_state_dict.py | 115 ++++++++++++++++++++-------- 2 files changed, 96 insertions(+), 40 deletions(-) diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 149a1a55eb..f6b2d176dd 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -261,7 +261,7 @@ def get_optim_state_dict( # For sharded models with non-sharded state dicts, only rank 0 has the full state dict including all the keys target_state_dict_on_this_rank = (not sharded_state_dict and dist.get_global_rank() == 0) or sharded_state_dict - + if target_state_dict_on_this_rank: if ignore_keys is not None: optim_state_dict = _remove_keys_from_optim_state_dict(optim_state_dict, model, ignore_keys) @@ -289,8 +289,10 @@ def _remove_keys_from_optim_state_dict( # but for unsharded models the param_key is an index (0,1,2,..., len(model.parameters())-1) if _is_model_fsdp(model): for param_fqn in optim_state_dict['state'].keys(): - if any([fnmatch.fnmatch(param_fqn, ignore_key) for ignore_key in ignore_keys]): - optim_state_dict['state'].pop(param_fqn) + for ignore_key in ignore_keys: + if fnmatch.fnmatch(param_fqn, ignore_key): + optim_state_dict['state'].pop(param_fqn) + continue # The param index ordering is determined by passing model.parameters() # to the optimizer. The underlying generator for model.parameters() is model.named_parameters() @@ -298,9 +300,10 @@ def _remove_keys_from_optim_state_dict( else: param_inds = list(optim_state_dict['state'].keys()) for param_ind, (param_fqn, _) in zip(param_inds, model.named_parameters()): - if any([fnmatch.fnmatch(param_fqn, ignore_key) for ignore_key in ignore_keys]): - optim_state_dict['state'].pop(param_ind) - + for ignore_key in ignore_keys: + if fnmatch.fnmatch(param_fqn, ignore_key): + optim_state_dict['state'].pop(param_ind) + continue return optim_state_dict @@ -315,7 +318,9 @@ def _extract_keys_from_optim_state_dict( param_inds = list(optim_state_dict['state'].keys()) # See comment in _remove_keys_from_optim_state_dict. 
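
Since the filtering is fnmatch-based, both exact names and shell-style globs work; a quick standalone illustration using the FQNs of the toy MLP exercised by the tests below ('module.0.weight' and 'module.2.weight'):

import fnmatch

fqns = ['module.0.weight', 'module.2.weight']
print([fqn for fqn in fqns if fnmatch.fnmatch(fqn, 'module.0.weight')])  # ['module.0.weight']
print([fqn for fqn in fqns if fnmatch.fnmatch(fqn, 'module.2*')])        # ['module.2.weight']
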
for param_ind, (param_fqn, _) in zip(param_inds, model.named_parameters()): - if not any([fnmatch.fnmatch(param_fqn, include_key) for include_key in include_keys]): - optim_state_dict['state'].pop(param_ind) + for include_key in include_keys: + if not fnmatch.fnmatch(param_fqn, include_key): + optim_state_dict['state'].pop(param_ind) + continue return optim_state_dict diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index dc9bedcb50..02a948e53c 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -15,7 +15,6 @@ from tests.common.compare import deep_compare from tests.common.markers import world_size from tests.common.models import EvenSimplerMLP, SimpleComposerMLP -from composer.models import ComposerModel @pytest.mark.gpu @@ -235,18 +234,45 @@ def test_get_model_state_dict_precision_unsharded_model(precision: str, use_comp assert tens.dtype == precision -def _init_model_and_optimizer(use_composer_model: bool, num_classes=3, batch_size=5, num_features=8, take_step=True, - use_fsdp=False, tensor_type='sharded_tensor'): - model, loss_fn = _init_model(use_composer_model, num_classes=num_classes, - batch_size=batch_size, num_features=num_features, - use_fsdp=use_fsdp, tensor_type=tensor_type) - - optimizer = _init_optimizer(model, loss_fn, use_composer_model=use_composer_model, num_classes=num_classes, batch_size=batch_size, num_features=num_features, take_step=take_step) +def _init_model_and_optimizer( + use_composer_model: bool, + num_classes=3, + batch_size=5, + num_features=8, + take_step=True, + use_fsdp=False, + tensor_type='sharded_tensor', +): + model, loss_fn = _init_model( + use_composer_model, + num_classes=num_classes, + batch_size=batch_size, + num_features=num_features, + use_fsdp=use_fsdp, + tensor_type=tensor_type, + ) + optimizer = _init_optimizer( + model, + loss_fn, + use_composer_model=use_composer_model, + num_classes=num_classes, + batch_size=batch_size, + num_features=num_features, + take_step=take_step, + ) return model, optimizer -def _init_model(use_composer_model: bool=False, num_classes=3, batch_size=5, num_features=8, use_fsdp=False, tensor_type='sharded_tensor'): + +def _init_model( + use_composer_model: bool = False, + num_classes=3, + batch_size=5, + num_features=8, + use_fsdp=False, + tensor_type='sharded_tensor', +): if use_composer_model: model = SimpleComposerMLP(num_features=num_features, num_classes=num_classes, device='cuda') loss_fn = model._loss_fn @@ -256,9 +282,9 @@ def _init_model(use_composer_model: bool=False, num_classes=3, batch_size=5, num if use_fsdp: fsdp_kwargs: Dict[str, Any] = dict( - use_orig_params=True, - sync_module_states=True, # To enable easy comparison between rank 0 unsharded model and full state dict - ) + use_orig_params=True, + sync_module_states=True, # To enable easy comparison between rank 0 unsharded model and full state dict + ) if tensor_type == 'dtensor': from torch.distributed.device_mesh import init_device_mesh @@ -272,8 +298,16 @@ def _init_model(use_composer_model: bool=False, num_classes=3, batch_size=5, num return model, loss_fn -def _init_optimizer(model, loss_fn, use_composer_model: bool=False, num_classes=3, - batch_size=5, num_features=8, take_step=True): + +def _init_optimizer( + model, + loss_fn, + use_composer_model: bool = False, + num_classes=3, + batch_size=5, + num_features=8, + take_step=True, +): inputs = torch.randn(batch_size, num_features, device='cuda') targets = torch.randint(low=0, high=num_classes, size=(batch_size,), 
device='cuda', dtype=torch.long) batch = (inputs, targets) if use_composer_model else inputs @@ -333,16 +367,20 @@ def test_get_optim_state_dict_include(use_composer_model: bool): optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) expected_optim_state_keys = [] for fqn in fqns: - if any([fnmatch.fnmatch(fqn, include_key) for include_key in include_keys]): - expected_optim_state_keys.append(fqns.index(fqn)) + for include_key in include_keys: + if fnmatch.fnmatch(fqn, include_key): + expected_optim_state_keys.append(fqns.index(fqn)) + continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) include_keys = ['module.2*'] optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) expected_optim_state_keys = [] for fqn in fqns: - if any([fnmatch.fnmatch(fqn, include_key) for include_key in include_keys]): - expected_optim_state_keys.append(fqns.index(fqn)) + for include_key in include_keys: + if fnmatch.fnmatch(fqn, include_key): + expected_optim_state_keys.append(fqns.index(fqn)) + continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) @@ -353,18 +391,24 @@ def test_get_optim_state_dict_ignore(use_composer_model: bool): fqns = [param_fqn for param_fqn, _ in model.named_parameters()] ignore_keys = ['module.0*'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = [] + expected_optim_state_keys = [*fqns] for fqn in fqns: - if not any([fnmatch.fnmatch(fqn, ignore_key) for ignore_key in ignore_keys]): - expected_optim_state_keys.append(fqns.index(fqn)) + for ignore_key in ignore_keys: + if fnmatch.fnmatch(fqn, ignore_key): + expected_optim_state_keys.remove(fqn) + continue + assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) ignore_keys = ['module.2.weight'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = [] + expected_optim_state_keys = [*fqns] for fqn in fqns: - if not any([fnmatch.fnmatch(fqn, ignore_key) for ignore_key in ignore_keys]): - expected_optim_state_keys.append(fqns.index(fqn)) + for ignore_key in ignore_keys: + if fnmatch.fnmatch(fqn, ignore_key): + expected_optim_state_keys.remove(fqn) + continue + assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) @@ -393,11 +437,14 @@ def test_get_optim_state_dict_precision_unsharded_model(precision: str, use_comp def test_get_optim_dict_full_for_sharded_model(world_size, tensor_type, use_composer_model: bool): if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') - - - model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True, use_fsdp=True, tensor_type=tensor_type) - optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=False) + model, optimizer = _init_model_and_optimizer( + use_composer_model=use_composer_model, + take_step=True, + use_fsdp=True, + tensor_type=tensor_type, + ) + optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=False) with FSDP.summon_full_params(model): # Make sure the optimizer state in the state dict is the same shape as the parameter it corresponds to. 
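
For a full (non-sharded) state dict of an FSDP model, only rank 0 materializes the gathered state, and that state is keyed by parameter FQN rather than by index; a hedged sketch of a rank-aware check, reusing the fixtures assumed by the test above (`model` FSDP-wrapped, `optimizer` stepped):

from composer.checkpoint import get_optim_state_dict
from composer.utils import dist

full_osd = get_optim_state_dict(model, optimizer, sharded_state_dict=False)
if dist.get_global_rank() == 0:
    for fqn, param_state in full_osd['state'].items():
        assert isinstance(fqn, str)          # FQN keys, not integer indices
        assert 'exp_avg' in param_state      # Adam moment buffers survive the gather
# Per the get_optim_state_dict docstring, the other ranks return an empty state dict.
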
@@ -417,9 +464,13 @@ def test_get_optim_dict_full_for_sharded_model(world_size, tensor_type, use_comp def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_composer_model: bool): if tensor_type == 'dtensor' and version.parse(torch.__version__) < version.parse('2.2.0'): pytest.skip('DTensor is only supported in PyTorch >= 2.2.0') - - - model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True, use_fsdp=True, tensor_type=tensor_type) + + model, optimizer = _init_model_and_optimizer( + use_composer_model=use_composer_model, + take_step=True, + use_fsdp=True, + tensor_type=tensor_type, + ) model_state_dict = get_model_state_dict(model, sharded_state_dict=True) optim_state_dict = get_optim_state_dict(model, optimizer, sharded_state_dict=True) @@ -428,4 +479,4 @@ def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_c for fqn, param_state in optim_state_dict['state'].items(): model_param_shape = fqn_to_shape_map[fqn] assert model_param_shape == param_state['exp_avg'].shape - assert model_param_shape == param_state['exp_avg_sq'].shape \ No newline at end of file + assert model_param_shape == param_state['exp_avg_sq'].shape From 59ad794168a86ba73664b1310e726824a0e7bebc Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Wed, 29 May 2024 23:05:41 +0000 Subject: [PATCH 08/15] fix tests --- tests/checkpoint/test_state_dict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 02a948e53c..590a377d0b 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -391,22 +391,22 @@ def test_get_optim_state_dict_ignore(use_composer_model: bool): fqns = [param_fqn for param_fqn, _ in model.named_parameters()] ignore_keys = ['module.0*'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = [*fqns] + expected_optim_state_keys = list(range(len(fqns))) for fqn in fqns: for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): - expected_optim_state_keys.remove(fqn) + expected_optim_state_keys.remove(fqns.index(fqn)) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) ignore_keys = ['module.2.weight'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = [*fqns] + expected_optim_state_keys = list(range(len(fqns))) for fqn in fqns: for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): - expected_optim_state_keys.remove(fqn) + expected_optim_state_keys.remove(fqns.index(fqn)) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) From 2d05e7113b58417361490df117a4570e2a1b67d7 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Thu, 30 May 2024 00:00:57 +0000 Subject: [PATCH 09/15] fix? 
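A note for context on the optimizer state keys these test fixes keep switching between: a plain torch optimizer keys optimizer.state_dict()['state'] by parameter index, whereas, as the comments in state_dict.py note, sharded (FSDP) models key the state by parameter FQN. The sketch below is illustrative only and not part of any patch — the toy nn.Sequential model and its sizes are hypothetical stand-ins for the test fixtures — and shows the integer keying plus how an index maps back to a parameter FQN through model.named_parameters():

# Minimal sketch (assumes a toy, non-FSDP model; not part of the patch series).
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one step so per-parameter Adam state (exp_avg, exp_avg_sq) exists.
model(torch.randn(5, 8)).sum().backward()
optimizer.step()

osd = optimizer.state_dict()
print(list(osd['state'].keys()))  # integer indices: [0, 1, 2, 3]

# Mapping index -> FQN is valid because the optimizer was built from
# model.parameters(), which yields parameters in named_parameters() order.
index_to_fqn = {i: fqn for i, (fqn, _) in enumerate(model.named_parameters())}
print(index_to_fqn)  # {0: '0.weight', 1: '0.bias', 2: '2.weight', 3: '2.bias'}
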
--- tests/checkpoint/test_state_dict.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 590a377d0b..efe075d3bc 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -323,13 +323,7 @@ def _init_optimizer( @pytest.mark.gpu @pytest.mark.parametrize('use_composer_model', [True, False]) def test_get_optim_state_dict_unsharded_model(use_composer_model: bool): - model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=False) - - # Before ever taking a step it should be empty. - optim_state_dict = get_optim_state_dict(model, optimizer) - assert optim_state_dict['state'] == optimizer.state == {} - - optimizer.step() + model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) optim_state_dict = get_optim_state_dict(model, optimizer) # Dict mapping parameter index to optimizer state for that parameter. @@ -369,7 +363,7 @@ def test_get_optim_state_dict_include(use_composer_model: bool): for fqn in fqns: for include_key in include_keys: if fnmatch.fnmatch(fqn, include_key): - expected_optim_state_keys.append(fqns.index(fqn)) + expected_optim_state_keys.append(fqn) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) @@ -379,7 +373,7 @@ def test_get_optim_state_dict_include(use_composer_model: bool): for fqn in fqns: for include_key in include_keys: if fnmatch.fnmatch(fqn, include_key): - expected_optim_state_keys.append(fqns.index(fqn)) + expected_optim_state_keys.append(fqn) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) @@ -391,22 +385,22 @@ def test_get_optim_state_dict_ignore(use_composer_model: bool): fqns = [param_fqn for param_fqn, _ in model.named_parameters()] ignore_keys = ['module.0*'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = list(range(len(fqns))) + expected_optim_state_keys = [*fqns] for fqn in fqns: for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): - expected_optim_state_keys.remove(fqns.index(fqn)) + expected_optim_state_keys.remove(fqn) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) ignore_keys = ['module.2.weight'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = list(range(len(fqns))) + expected_optim_state_keys = [*fqns] for fqn in fqns: for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): - expected_optim_state_keys.remove(fqns.index(fqn)) + expected_optim_state_keys.remove(fqn) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) From 6131e387bbdf65a62b1b713e20023bf8e04bc620 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Thu, 30 May 2024 16:03:33 -0700 Subject: [PATCH 10/15] remove comma --- composer/checkpoint/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/checkpoint/__init__.py b/composer/checkpoint/__init__.py index a22c24e9df..d553b6afe3 100644 --- a/composer/checkpoint/__init__.py +++ b/composer/checkpoint/__init__.py @@ -3,7 +3,7 @@ """Module for checkpointing API.""" -from composer.checkpoint.state_dict import get_model_state_dict, get_optim_state_dict, get_metadata_state_dict, +from composer.checkpoint.state_dict import get_model_state_dict, get_optim_state_dict, 
get_metadata_state_dict __all__ = [ 'get_model_state_dict', From 7faf4a16eafc09e9fa5a515f221d153990717978 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Fri, 31 May 2024 00:25:40 +0000 Subject: [PATCH 11/15] fix tests --- composer/checkpoint/state_dict.py | 45 +++++++++++++++++++---------- tests/checkpoint/test_state_dict.py | 23 +++++++++++---- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 1a41d235c6..ab2813c505 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -6,7 +6,7 @@ import fnmatch import logging import sys -from typing import Any, Optional, Sequence, Union, Dict +from typing import Any, Dict, Optional, Sequence, Union import torch from packaging import version @@ -289,7 +289,15 @@ def _remove_keys_from_optim_state_dict( # to the optimizer state ( e.g. 'step', 'exp_avg', 'exp_avg_sq') for that parameter. # For sharded models the param_key is just the fqn for the underlying model parameter, # but for unsharded models the param_key is an index (0,1,2,..., len(model.parameters())-1) - if _is_model_fsdp(model): + param_keys = list(optim_state_dict['state'].keys()) + optim_keyed_by_ind = type(list(param_keys)[0]) == int + if optim_keyed_by_ind: + for param_ind, (param_fqn, _) in zip(param_keys, model.named_parameters()): + for ignore_key in ignore_keys: + if fnmatch.fnmatch(param_fqn, ignore_key): + optim_state_dict['state'].pop(param_ind) + continue + else: for param_fqn in optim_state_dict['state'].keys(): for ignore_key in ignore_keys: if fnmatch.fnmatch(param_fqn, ignore_key): @@ -299,13 +307,6 @@ def _remove_keys_from_optim_state_dict( # The param index ordering is determined by passing model.parameters() # to the optimizer. The underlying generator for model.parameters() is model.named_parameters() # so we need to use model.named_parameters() instead of model.state_dict().keys() to match fqn to ind correctly. - else: - param_inds = list(optim_state_dict['state'].keys()) - for param_ind, (param_fqn, _) in zip(param_inds, model.named_parameters()): - for ignore_key in ignore_keys: - if fnmatch.fnmatch(param_fqn, ignore_key): - optim_state_dict['state'].pop(param_ind) - continue return optim_state_dict @@ -317,15 +318,27 @@ def _extract_keys_from_optim_state_dict( ): if isinstance(include_keys, str): include_keys = [include_keys] - param_inds = list(optim_state_dict['state'].keys()) - # See comment in _remove_keys_from_optim_state_dict. - for param_ind, (param_fqn, _) in zip(param_inds, model.named_parameters()): - for include_key in include_keys: - if not fnmatch.fnmatch(param_fqn, include_key): - optim_state_dict['state'].pop(param_ind) - continue + + param_keys = list(optim_state_dict['state'].keys()) + optim_keyed_by_ind = type(list(param_keys)[0]) == int + + if optim_keyed_by_ind: + # See comment in _remove_keys_from_optim_state_dict. 
+ for param_ind, (param_fqn, _) in zip(param_keys, model.named_parameters()): + for include_key in include_keys: + if not fnmatch.fnmatch(param_fqn, include_key): + optim_state_dict['state'].pop(param_ind) + continue + else: + for param_fqn in optim_state_dict['state'].keys(): + for ignore_key in include_keys: + if fnmatch.fnmatch(param_fqn, ignore_key): + optim_state_dict['state'].pop(param_fqn) + continue return optim_state_dict + + def get_metadata_state_dict( model: Optional[Union[ComposerModel, nn.Module]] = None, sharded_state_dict: Optional[bool] = None, diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index c8024931d1..bbbf5c63cd 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -360,11 +360,14 @@ def test_get_optim_state_dict_include(use_composer_model: bool): fqns = [param_fqn for param_fqn, _ in model.named_parameters()] include_keys = ['module.0.weight'] optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) + param_keys = list(optim_state_dict['state'].keys()) + optim_keyed_by_ind = type(list(param_keys)[0]) == int expected_optim_state_keys = [] for fqn in fqns: for include_key in include_keys: if fnmatch.fnmatch(fqn, include_key): - expected_optim_state_keys.append(fqn) + key = fqns.index(fqn) if optim_keyed_by_ind else fqn + expected_optim_state_keys.append(key) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) @@ -374,7 +377,8 @@ def test_get_optim_state_dict_include(use_composer_model: bool): for fqn in fqns: for include_key in include_keys: if fnmatch.fnmatch(fqn, include_key): - expected_optim_state_keys.append(fqn) + key = fqns.index(fqn) if optim_keyed_by_ind else fqn + expected_optim_state_keys.append(key) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) @@ -386,22 +390,27 @@ def test_get_optim_state_dict_ignore(use_composer_model: bool): fqns = [param_fqn for param_fqn, _ in model.named_parameters()] ignore_keys = ['module.0*'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = [*fqns] + param_keys = list(optim_state_dict['state'].keys()) + optim_keyed_by_ind = type(list(param_keys)[0]) == int + + expected_optim_state_keys = list(range(len(fqns))) if optim_keyed_by_ind else [*fqns] for fqn in fqns: for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): - expected_optim_state_keys.remove(fqn) + key = fqns.index(fqn) if optim_keyed_by_ind else fqn + expected_optim_state_keys.remove(key) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) ignore_keys = ['module.2.weight'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = [*fqns] + expected_optim_state_keys = list(range(len(fqns))) if optim_keyed_by_ind else [*fqns] for fqn in fqns: for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): - expected_optim_state_keys.remove(fqn) + key = fqns.index(fqn) if optim_keyed_by_ind else fqn + expected_optim_state_keys.remove(key) continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) @@ -475,6 +484,8 @@ def test_get_optim_dict_sharded_for_sharded_model(world_size, tensor_type, use_c model_param_shape = fqn_to_shape_map[fqn] assert model_param_shape == param_state['exp_avg'].shape assert model_param_shape == param_state['exp_avg_sq'].shape + + @pytest.mark.gpu @world_size(1, 2) def 
test_get_metadata_empty_call(world_size): From 68f8d1bd6c4070abb531f501b9b8dfb0571099f4 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Thu, 30 May 2024 17:31:01 -0700 Subject: [PATCH 12/15] pre-commit --- composer/checkpoint/__init__.py | 2 +- tests/checkpoint/test_state_dict.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/composer/checkpoint/__init__.py b/composer/checkpoint/__init__.py index d553b6afe3..84d6c9f4cf 100644 --- a/composer/checkpoint/__init__.py +++ b/composer/checkpoint/__init__.py @@ -3,7 +3,7 @@ """Module for checkpointing API.""" -from composer.checkpoint.state_dict import get_model_state_dict, get_optim_state_dict, get_metadata_state_dict +from composer.checkpoint.state_dict import get_metadata_state_dict, get_model_state_dict, get_optim_state_dict __all__ = [ 'get_model_state_dict', diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index bbbf5c63cd..1197483f5a 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -398,7 +398,7 @@ def test_get_optim_state_dict_ignore(use_composer_model: bool): for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): key = fqns.index(fqn) if optim_keyed_by_ind else fqn - expected_optim_state_keys.remove(key) + expected_optim_state_keys.remove(key) # pyright:ignore continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) @@ -410,7 +410,7 @@ def test_get_optim_state_dict_ignore(use_composer_model: bool): for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): key = fqns.index(fqn) if optim_keyed_by_ind else fqn - expected_optim_state_keys.remove(key) + expected_optim_state_keys.remove(key) # pyright:ignore continue assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) From d191953eec05c44064197e25ef4081ec9ff46149 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Fri, 31 May 2024 10:54:33 -0700 Subject: [PATCH 13/15] try this for tests? --- composer/checkpoint/state_dict.py | 34 ++++------------ tests/checkpoint/test_state_dict.py | 63 ++++++++++++----------------- 2 files changed, 33 insertions(+), 64 deletions(-) diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index ab2813c505..6cf853f8d8 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -292,21 +292,11 @@ def _remove_keys_from_optim_state_dict( param_keys = list(optim_state_dict['state'].keys()) optim_keyed_by_ind = type(list(param_keys)[0]) == int if optim_keyed_by_ind: - for param_ind, (param_fqn, _) in zip(param_keys, model.named_parameters()): + for param_key, (param_fqn, _) in zip(param_keys, model.named_parameters()): for ignore_key in ignore_keys: if fnmatch.fnmatch(param_fqn, ignore_key): - optim_state_dict['state'].pop(param_ind) + optim_state_dict['state'].pop(param_key) continue - else: - for param_fqn in optim_state_dict['state'].keys(): - for ignore_key in ignore_keys: - if fnmatch.fnmatch(param_fqn, ignore_key): - optim_state_dict['state'].pop(param_fqn) - continue - - # The param index ordering is determined by passing model.parameters() - # to the optimizer. The underlying generator for model.parameters() is model.named_parameters() - # so we need to use model.named_parameters() instead of model.state_dict().keys() to match fqn to ind correctly. 
return optim_state_dict @@ -320,21 +310,13 @@ def _extract_keys_from_optim_state_dict( include_keys = [include_keys] param_keys = list(optim_state_dict['state'].keys()) - optim_keyed_by_ind = type(list(param_keys)[0]) == int - if optim_keyed_by_ind: - # See comment in _remove_keys_from_optim_state_dict. - for param_ind, (param_fqn, _) in zip(param_keys, model.named_parameters()): - for include_key in include_keys: - if not fnmatch.fnmatch(param_fqn, include_key): - optim_state_dict['state'].pop(param_ind) - continue - else: - for param_fqn in optim_state_dict['state'].keys(): - for ignore_key in include_keys: - if fnmatch.fnmatch(param_fqn, ignore_key): - optim_state_dict['state'].pop(param_fqn) - continue + # See comment in _remove_keys_from_optim_state_dict. + for param_key, (param_fqn, _) in zip(param_keys, model.named_parameters()): + for include_key in include_keys: + if not fnmatch.fnmatch(param_fqn, include_key): + optim_state_dict['state'].pop(param_key) + continue return optim_state_dict diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 1197483f5a..b9d6759131 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -340,8 +340,12 @@ def test_get_optim_state_dict_unsharded_model(use_composer_model: bool): # Because model is unsharded the optimizer state should have keys corresponding to the index of the model's parameters. # e.g. if the model has 3 parameters, the optimizer state dict keys would be (0,1,2). params = list(model.parameters()) - for param_ind, param_state in osd_state.items(): - param = params[param_ind] + param_dict = dict(list(model.named_parameters())) + for param_key, param_state in osd_state.items(): + if isinstance(param_key, str): + param = param_dict[param_key] + else: + param = params[param_key] assert param.shape == param_state['exp_avg'].shape assert param.shape == param_state['exp_avg_sq'].shape @@ -361,26 +365,16 @@ def test_get_optim_state_dict_include(use_composer_model: bool): include_keys = ['module.0.weight'] optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) param_keys = list(optim_state_dict['state'].keys()) - optim_keyed_by_ind = type(list(param_keys)[0]) == int - expected_optim_state_keys = [] - for fqn in fqns: + expected_optim_keys = [] + for optim_key, fqn in zip(param_keys, fqns): for include_key in include_keys: if fnmatch.fnmatch(fqn, include_key): - key = fqns.index(fqn) if optim_keyed_by_ind else fqn - expected_optim_state_keys.append(key) + if isinstance(optim_key, str): + expected_optim_keys.append(optim_key) + else: + expected_optim_keys.append(fqn) continue - assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) - - include_keys = ['module.2*'] - optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) - expected_optim_state_keys = [] - for fqn in fqns: - for include_key in include_keys: - if fnmatch.fnmatch(fqn, include_key): - key = fqns.index(fqn) if optim_keyed_by_ind else fqn - expected_optim_state_keys.append(key) - continue - assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) + assert set(optim_state_dict['state'].keys()) == set(expected_optim_keys) @pytest.mark.gpu @@ -388,32 +382,25 @@ def test_get_optim_state_dict_include(use_composer_model: bool): def test_get_optim_state_dict_ignore(use_composer_model: bool): model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) fqns = [param_fqn for 
param_fqn, _ in model.named_parameters()] + ignore_keys = ['module.0*'] optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) param_keys = list(optim_state_dict['state'].keys()) - optim_keyed_by_ind = type(list(param_keys)[0]) == int - - expected_optim_state_keys = list(range(len(fqns))) if optim_keyed_by_ind else [*fqns] - for fqn in fqns: - for ignore_key in ignore_keys: - if fnmatch.fnmatch(fqn, ignore_key): - key = fqns.index(fqn) if optim_keyed_by_ind else fqn - expected_optim_state_keys.remove(key) # pyright:ignore - continue - - assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) - - ignore_keys = ['module.2.weight'] - optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - expected_optim_state_keys = list(range(len(fqns))) if optim_keyed_by_ind else [*fqns] - for fqn in fqns: + expected_optim_state_keys_inds = list(range(len(fqns))) + expected_optim_state_keys_str = [*fqns] + expected_optim_keys = [] + for optim_key, fqn in zip(param_keys, fqns): for ignore_key in ignore_keys: if fnmatch.fnmatch(fqn, ignore_key): - key = fqns.index(fqn) if optim_keyed_by_ind else fqn - expected_optim_state_keys.remove(key) # pyright:ignore + if isinstance(optim_key, str): + expected_optim_state_keys_str.remove(optim_key) # pyright:ignore + expected_optim_keys = expected_optim_state_keys_str + else: + expected_optim_state_keys_inds.remove(optim_key) + expected_optim_keys = expected_optim_state_keys_inds continue - assert set(optim_state_dict['state'].keys()) == set(expected_optim_state_keys) + assert set(optim_state_dict['state'].keys()) == set(expected_optim_keys) @pytest.mark.gpu From 99fef1f40241b9f8423993ea28ab774ba2dfbe79 Mon Sep 17 00:00:00 2001 From: Evan Racah Date: Fri, 31 May 2024 18:18:45 +0000 Subject: [PATCH 14/15] remove ignore and include tests --- composer/checkpoint/state_dict.py | 48 ++--------------------------- tests/checkpoint/test_state_dict.py | 46 --------------------------- 2 files changed, 2 insertions(+), 92 deletions(-) diff --git a/composer/checkpoint/state_dict.py b/composer/checkpoint/state_dict.py index 6cf853f8d8..511fba5a9f 100644 --- a/composer/checkpoint/state_dict.py +++ b/composer/checkpoint/state_dict.py @@ -266,9 +266,9 @@ def get_optim_state_dict( if target_state_dict_on_this_rank: if ignore_keys is not None: - optim_state_dict = _remove_keys_from_optim_state_dict(optim_state_dict, model, ignore_keys) + raise NotImplementedError('Ignoring keys in the optimizer state dict is not supported yet.') if include_keys is not None: - optim_state_dict = _extract_keys_from_optim_state_dict(optim_state_dict, model, include_keys) + raise NotImplementedError('Ignoring keys in the optimizer state dict is not supported yet.') # param_key := index (0,1,2,..., len(model.parameters())-1) for unsharded models. # param_key := fqn for sharded models. @@ -277,50 +277,6 @@ def get_optim_state_dict( return optim_state_dict -def _remove_keys_from_optim_state_dict( - optim_state_dict: Dict[str, Any], - model: Union[ComposerModel, nn.Module], - ignore_keys: Union[str, Sequence[str]], -): - if isinstance(ignore_keys, str): - ignore_keys = [ignore_keys] - - # optim_state_dict['state'] is a dictionary mapping the param_key - # to the optimizer state ( e.g. 'step', 'exp_avg', 'exp_avg_sq') for that parameter. 
- # For sharded models the param_key is just the fqn for the underlying model parameter, - # but for unsharded models the param_key is an index (0,1,2,..., len(model.parameters())-1) - param_keys = list(optim_state_dict['state'].keys()) - optim_keyed_by_ind = type(list(param_keys)[0]) == int - if optim_keyed_by_ind: - for param_key, (param_fqn, _) in zip(param_keys, model.named_parameters()): - for ignore_key in ignore_keys: - if fnmatch.fnmatch(param_fqn, ignore_key): - optim_state_dict['state'].pop(param_key) - continue - - return optim_state_dict - - -def _extract_keys_from_optim_state_dict( - optim_state_dict: Dict[str, Any], - model: Union[ComposerModel, nn.Module], - include_keys: Union[str, Sequence[str]], -): - if isinstance(include_keys, str): - include_keys = [include_keys] - - param_keys = list(optim_state_dict['state'].keys()) - - # See comment in _remove_keys_from_optim_state_dict. - for param_key, (param_fqn, _) in zip(param_keys, model.named_parameters()): - for include_key in include_keys: - if not fnmatch.fnmatch(param_fqn, include_key): - optim_state_dict['state'].pop(param_key) - continue - - return optim_state_dict - - def get_metadata_state_dict( model: Optional[Union[ComposerModel, nn.Module]] = None, sharded_state_dict: Optional[bool] = None, diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index b9d6759131..6c462803ad 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -357,52 +357,6 @@ def test_get_optim_state_dict_unsharded_model(use_composer_model: bool): deep_compare(osd_group, opt_group, ignore_keys=['params']) -@pytest.mark.gpu -@pytest.mark.parametrize('use_composer_model', [True, False]) -def test_get_optim_state_dict_include(use_composer_model: bool): - model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) - fqns = [param_fqn for param_fqn, _ in model.named_parameters()] - include_keys = ['module.0.weight'] - optim_state_dict = get_optim_state_dict(model, optimizer, include_keys=include_keys) - param_keys = list(optim_state_dict['state'].keys()) - expected_optim_keys = [] - for optim_key, fqn in zip(param_keys, fqns): - for include_key in include_keys: - if fnmatch.fnmatch(fqn, include_key): - if isinstance(optim_key, str): - expected_optim_keys.append(optim_key) - else: - expected_optim_keys.append(fqn) - continue - assert set(optim_state_dict['state'].keys()) == set(expected_optim_keys) - - -@pytest.mark.gpu -@pytest.mark.parametrize('use_composer_model', [True, False]) -def test_get_optim_state_dict_ignore(use_composer_model: bool): - model, optimizer = _init_model_and_optimizer(use_composer_model=use_composer_model, take_step=True) - fqns = [param_fqn for param_fqn, _ in model.named_parameters()] - - ignore_keys = ['module.0*'] - optim_state_dict = get_optim_state_dict(model, optimizer, ignore_keys=ignore_keys) - param_keys = list(optim_state_dict['state'].keys()) - expected_optim_state_keys_inds = list(range(len(fqns))) - expected_optim_state_keys_str = [*fqns] - expected_optim_keys = [] - for optim_key, fqn in zip(param_keys, fqns): - for ignore_key in ignore_keys: - if fnmatch.fnmatch(fqn, ignore_key): - if isinstance(optim_key, str): - expected_optim_state_keys_str.remove(optim_key) # pyright:ignore - expected_optim_keys = expected_optim_state_keys_str - else: - expected_optim_state_keys_inds.remove(optim_key) - expected_optim_keys = expected_optim_state_keys_inds - continue - - assert set(optim_state_dict['state'].keys()) == 
set(expected_optim_keys)
-
-
 @pytest.mark.gpu
 @pytest.mark.parametrize(
     'precision',

From 4848777c87e1f9f2a4844570528f9c7db266357a Mon Sep 17 00:00:00 2001
From: Evan Racah
Date: Fri, 31 May 2024 11:22:36 -0700
Subject: [PATCH 15/15] pre-commit

---
 tests/checkpoint/test_state_dict.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py
index 6c462803ad..9618756b83 100644
--- a/tests/checkpoint/test_state_dict.py
+++ b/tests/checkpoint/test_state_dict.py
@@ -1,7 +1,6 @@
 # Copyright 2024 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-import fnmatch
 from typing import Any, Dict
 
 import pytest
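
After the final patch, composer.checkpoint exposes get_model_state_dict, get_optim_state_dict, and get_metadata_state_dict. The usage sketch below is illustrative only: the toy model, optimizer hyperparameters, and single-process (non-FSDP, non-distributed) setup are assumptions, while the get_optim_state_dict keyword arguments follow the signature shown in the diffs above.

# Usage sketch (assumes a single process with no FSDP wrapping, so the plain
# model.state_dict() / optimizer.state_dict() code path shown above is taken).
import torch
from torch import nn

from composer.checkpoint import get_model_state_dict, get_optim_state_dict

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 3))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One optimizer step so per-parameter state exists before extraction.
model(torch.randn(5, 8)).sum().backward()
optimizer.step()

# Full (non-sharded) state dicts; the optimizer state is cast to fp32.
model_sd = get_model_state_dict(model, sharded_state_dict=False)
optim_sd = get_optim_state_dict(model, optimizer, sharded_state_dict=False, precision='fp32')

# For an unsharded model the optimizer state is keyed by parameter index and
# each entry carries the optimizer's per-parameter tensors (exp_avg, exp_avg_sq).
for param_key, param_state in optim_sd['state'].items():
    print(param_key, param_state['exp_avg'].shape, param_state['exp_avg_sq'].shape)

For FSDP-wrapped models the same calls apply: sharded_state_dict=True makes every rank return the state for its own shards (keyed by parameter FQN), while sharded_state_dict=False gathers the full state dict onto rank 0, which is what the sharded tests above assert.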