Merge branch 'dev' into release/v0.22.0
dakinggg committed May 1, 2024
2 parents 0ccad56 + 5eddaf3 commit fe7964f
Showing 5 changed files with 41 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/daily.yaml
@@ -127,7 +127,7 @@ jobs:
composer_package_name: ${{ matrix.composer_package_name }}
container: ${{ matrix.container }}
git_repo: mosaicml/composer
- mcloud-timeout: 3600
+ mcloud-timeout: 5400
name: ${{ matrix.name }}
pip_deps: "[all]"
pytest-command: ${{ matrix.pytest_command }}
20 changes: 14 additions & 6 deletions composer/utils/checkpoint.py
@@ -620,13 +620,21 @@ def load_sharded_checkpoint(

# We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph.
with torch.no_grad():
- # 1. Load model and metadata first
+ # 1. Load metadata first for backwards compatibility check
+ # We need to check if "optimizers" is at the root of the state dict to determine
+ # how to load the optimizer state.
+ metadata = storage_reader.read_metadata()
+ # Retrieve all top-level keys of the metadata.
+ top_level_keys = [v[0] for v in metadata.planner_data.values()]
+ optimizers_at_root = 'optimizers' in top_level_keys

+ # 2. Load model and metadata
if load_weights_only:
state_dict: Dict[str, Any] = {'state': {'model': state.get_model_state_dict()}}
else:
cur_state_dict = state.state_dict()
- # For older versions of torch, we load optimizer separately.
- if version.parse(torch.__version__) < version.parse('2.2.3'):
+ # If 'optimizers' is at root-level, we load it separately.
+ if optimizers_at_root:
cur_state_dict.pop('optimizers')
num_rng_ranks = _get_num_ranks_that_saved_rng(storage_reader.read_metadata())
state_dict: Dict[str, Any] = {
@@ -661,9 +669,9 @@ def load_sharded_checkpoint(
algorithm_passes=algorithm_passes,
)

- # 2. Optionally load optimizer
- # if we are using later than 2.2.3 then optimizer will already be loaded
- if version.parse(torch.__version__) < version.parse('2.2.3') and not load_weights_only:
+ # 3. Optionally load optimizer
+ # If 'optimizers' was not at root-level, then it will already be loaded
+ if optimizers_at_root and not load_weights_only:
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=state.state_dict()['model'],
optimizer_key='optimizers',
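The key change above replaces a torch-version check with an inspection of the checkpoint itself. A minimal sketch of that check, assuming a sharded checkpoint written by torch.distributed.checkpoint with the default flattening planner (which records the original key paths in metadata.planner_data); the helper name and directory argument are illustrative, not Composer's actual API:

from torch.distributed.checkpoint import FileSystemReader

def optimizers_saved_at_root(checkpoint_dir: str) -> bool:
    # Older sharded checkpoints store 'optimizers' as a top-level key of the
    # saved state dict and therefore need a separate optimizer-loading pass.
    reader = FileSystemReader(checkpoint_dir)
    metadata = reader.read_metadata()
    # Each planner_data value is the tuple of keys that addressed the entry in
    # the original nested state dict; its first element is the top-level key.
    top_level_keys = {path[0] for path in metadata.planner_data.values()}
    return 'optimizers' in top_level_keys

Keying the behavior off the checkpoint contents rather than the installed torch version lets one torch build load checkpoints saved in either layout.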
35 changes: 16 additions & 19 deletions tests/algorithms/algorithm_settings.py
@@ -97,23 +97,6 @@
'kwargs': {},
}

- simple_resnet_settings = {
- 'model': (
- composer_resnet,
- {
- 'model_name': 'resnet18',
- 'num_classes': 2,
- },
- ),
- 'dataset': (
- RandomImageDataset,
- {
- 'shape': (3, 224, 224),
- },
- ),
- 'kwargs': {},
- }

_settings: Dict[Type[Algorithm], Optional[Dict[str, Any]]] = {
GradientClipping: {
'model': SimpleConvModel,
@@ -158,7 +141,15 @@
'half_life': '1ba',
},
},
- Factorize: simple_resnet_settings,
+ Factorize: {
+ 'model': SimpleConvModel,
+ 'dataset': RandomImageDataset,
+ 'kwargs': {
+ 'min_channels': 4,
+ 'min_features': 4,
+ 'latent_features': 2,
+ },
+ },
GatedLinearUnits: simple_bert_settings,
GhostBatchNorm: {
'model': (
@@ -205,7 +196,13 @@
'max_seq_length': 16,
},
},
- SqueezeExcite: simple_resnet_settings,
+ SqueezeExcite: {
+ 'model': SimpleConvModel,
+ 'dataset': RandomImageDataset,
+ 'kwargs': {
+ 'min_channels': 4,
+ },
+ },
StochasticDepth: {
'model': (
composer_resnet,
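For reference, each entry in the _settings mapping ties an Algorithm class to the model, dataset, and constructor kwargs used for its smoke test; a model or dataset value is either a bare class or a (callable, kwargs) tuple, as the composer_resnet entries show. A rough sketch of how such an entry could be consumed; the helper names below are hypothetical, not the module's actual API:

from typing import Any, Dict, Tuple, Type, Union

SettingsValue = Union[Type, Tuple[Any, Dict[str, Any]]]

def _construct(spec: SettingsValue):
    # A value is either a class instantiated with defaults (e.g. SimpleConvModel)
    # or a (callable, kwargs) tuple (e.g. (composer_resnet, {'model_name': ...})).
    if isinstance(spec, tuple):
        factory, kwargs = spec
        return factory(**kwargs)
    return spec()

def build_test_objects(alg_cls, settings):
    # Hypothetical helper: build the model, dataset, and algorithm instance
    # for one algorithm's training smoke test from its settings entry.
    entry = settings[alg_cls]
    model = _construct(entry['model'])
    dataset = _construct(entry['dataset'])
    algorithm = alg_cls(**entry['kwargs'])
    return model, dataset, algorithm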
1 change: 1 addition & 0 deletions tests/algorithms/test_algorithms_train.py
@@ -12,6 +12,7 @@

@pytest.mark.gpu
@pytest.mark.parametrize('alg_cls', get_algs_with_marks())
+ @pytest.mark.filterwarnings(r'ignore:.*Plan failed with a cudnnException.*:UserWarning') # Torch 2.3 regression
def test_algorithm_trains(alg_cls: Type[Algorithm]):
alg_kwargs = get_alg_kwargs(alg_cls)
model = get_alg_model(alg_cls)
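The added marker uses pytest's "action:message_regex:category" filter syntax to keep the cuDNN-related UserWarning that Torch 2.3 started emitting from failing these GPU tests when warnings are escalated to errors. A small self-contained illustration of the same pattern; the test body below is only for demonstration:

import warnings

import pytest

@pytest.mark.filterwarnings(r'ignore:.*Plan failed with a cudnnException.*:UserWarning')
def test_warning_is_suppressed():
    # With a strict warning policy (e.g. filterwarnings = error in the pytest
    # config), this warning would otherwise turn the test into a failure.
    warnings.warn('Plan failed with a cudnnException (demo)', UserWarning)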
12 changes: 9 additions & 3 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -527,8 +527,6 @@ def test_fsdp_load_old_checkpoint(
'state': trainer2.state.state_dict(),
'rng': get_rng_state(),
}
- if version.parse(torch.__version__) < version.parse('2.2.3'):
- state_dict['state'].pop('optimizers')

object_store = S3ObjectStore(bucket=f'{s3_bucket}')
storage_reader = DistCPObjectStoreReader(
@@ -538,14 +536,22 @@
device_mesh=None,
)

+ # Load metadata first, and check if 'optimizers' is a top-level key. Pop if it is.
+ metadata = storage_reader.read_metadata()
+ # Retrieve all top-level keys of the metadata.
+ top_level_keys = [v[0] for v in metadata.planner_data.values()]
+ optimizers_at_root = 'optimizers' in top_level_keys
+ if optimizers_at_root:
+ state_dict['state'].pop('optimizers')

process_group = None
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
planner=None,
process_group=process_group,
)
- if version.parse(torch.__version__) < version.parse('2.2.3'):
+ if optimizers_at_root:
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model_state_dict = state_dict['state']['model']
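When the metadata check finds 'optimizers' at the root (the old layout), the test pulls the optimizer state in a second pass after the model weights, mirroring the call in composer/utils/checkpoint.py above. A hedged sketch of that second pass, assuming state_dict and storage_reader are set up as earlier in the test; the final conversion step is only described in a comment because its exact call differs across torch releases:

from typing import Any, Dict

from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

def load_root_level_optimizer_state(state_dict: Dict[str, Any], storage_reader) -> Dict[str, Any]:
    # state_dict['state']['model'] must already hold the loaded (sharded) model
    # weights; storage_reader is the same reader used for the first load.
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict['state']['model'],
        optimizer_key='optimizers',
        storage_reader=storage_reader,
    )
    # optim_state['optimizers'] holds the flattened optimizer state. Before calling
    # optimizer.load_state_dict, it still has to be mapped onto the FSDP-wrapped
    # model (e.g. via FSDP.optim_state_dict_to_load; that helper's argument order
    # changed between torch releases, so the call is left out of this sketch).
    return optim_state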
