
Commit 33a28b5

Fix failing sharded_checkpoint tests that fail when pytorch 1.13 is not installed (#1988)
eracah authored Feb 23, 2023
1 parent be40689 commit 33a28b5
Showing 1 changed file with 8 additions and 6 deletions.
tests/trainer/test_sharded_checkpoint.py — 8 additions & 6 deletions
@@ -90,8 +90,9 @@ def _compare_model_params_between_state_dicts(state_dict1, state_dict2):
 
 @pytest.mark.gpu
 @world_size(2)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
+                    reason='requires PyTorch 1.13 or higher')
 def test_fsdp_full_state_dict_save(world_size, tmp_path: pathlib.Path):
-
     save_folder = tmp_path
     save_filename = 'rank{rank}.pt'
     num_features = 3
@@ -176,8 +177,9 @@ def test_fsdp_full_state_dict_save(world_size, tmp_path: pathlib.Path):
 
 @pytest.mark.gpu
 @world_size(2)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
+                    reason='requires PyTorch 1.13 or higher')
 def test_fsdp_full_state_dict_load(world_size, tmp_path: pathlib.Path):
-
     save_folder = tmp_path
     save_filename = 'rank{rank}.pt'
     trainer1 = get_trainer(save_folder=str(save_folder), save_filename=save_filename, fsdp_state_dict_type='full')
@@ -197,10 +199,9 @@ def test_fsdp_full_state_dict_load(world_size, tmp_path: pathlib.Path):
 @pytest.mark.gpu
 @world_size(2)
 @pytest.mark.parametrize('state_dict_type', ['local', 'sharded'])
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
+                    reason='requires PyTorch 1.13 or higher')
 def test_fsdp_partitioned_state_dict_save(world_size, tmp_path: pathlib.Path, state_dict_type: str):
-
-    if version.parse(torch.__version__) < version.parse('1.13.0'):
-        pytest.skip()
     pytest.importorskip('torch.distributed.fsdp.fully_sharded_data_parallel')
     from torch.distributed.fsdp.fully_sharded_data_parallel import ShardedTensor
     save_folder = tmp_path
@@ -304,8 +305,9 @@ def test_fsdp_partitioned_state_dict_save(world_size, tmp_path: pathlib.Path, state_dict_type: str):
 @pytest.mark.gpu
 @world_size(2)
 @pytest.mark.parametrize('state_dict_type', ['local', 'sharded'])
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
+                    reason='requires PyTorch 1.13 or higher')
 def test_fsdp_partitioned_state_dict_load(world_size, tmp_path: pathlib.Path, state_dict_type: str):
-
     save_folder = tmp_path
     save_filename = 'rank{rank}.pt'
     trainer1 = get_trainer(save_folder=str(save_folder),
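For readers skimming the diff, here is a minimal, self-contained sketch of the version-gating pattern the commit adds. It assumes pytest, torch, and packaging are installed and that `version` is imported from the packaging package (as is conventional); the test name and body are illustrative only and are not copied from the repository.

# Minimal sketch of gating a test on the installed PyTorch version.
# Assumes `version` comes from the `packaging` package; the test name and
# body are illustrative, not taken from composer.
import pytest
import torch
from packaging import version


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
                    reason='requires PyTorch 1.13 or higher')
def test_sharded_tensor_is_importable():
    # The skipif condition is evaluated at collection time, so on PyTorch < 1.13
    # this body (and its import) never runs; the test is reported as skipped
    # rather than erroring on the import below.
    from torch.distributed.fsdp.fully_sharded_data_parallel import ShardedTensor
    assert ShardedTensor is not None

Compared with an inline pytest.skip() at the top of the test body, the decorator keeps the guard out of the test itself and surfaces the skip reason in the test report, which appears to be the intent of this change.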
