Port resnet data loading optimizations to SPMD test script (#5386)
jonb377 authored Aug 1, 2023
1 parent 2676a20 commit 424d8c8
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions test/spmd/test_train_spmd_imagenet.py
@@ -27,6 +27,21 @@
     '--test_only_at_end': {
         'action': 'store_true',
     },
+    '--persistent_workers': {
+        'action': 'store_true',
+    },
+    '--prefetch_factor': {
+        'type': int,
+    },
+    '--loader_prefetch_size': {
+        'type': int,
+    },
+    '--device_prefetch_size': {
+        'type': int,
+    },
+    '--host_to_device_transfer_threads': {
+        'type': int,
+    },
     '--sharding': {
         'choices': ['batch', 'spatial', 'conv', 'linear'],
         'nargs': '+',
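Note: each entry above pairs a flag name with `argparse` options. The harness's actual parsing helper lives outside this diff, so the following is a minimal sketch, assuming the dict is wired into `argparse` via `parser.add_argument(flag, **options)`:

```python
# Sketch (assumption): how a MODEL_OPTS-style dict maps onto argparse.
import argparse

MODEL_OPTS = {
    '--persistent_workers': {'action': 'store_true'},
    '--prefetch_factor': {'type': int},
    '--loader_prefetch_size': {'type': int},
    '--device_prefetch_size': {'type': int},
    '--host_to_device_transfer_threads': {'type': int},
}

parser = argparse.ArgumentParser()
for flag, options in MODEL_OPTS.items():
  parser.add_argument(flag, **options)

FLAGS = parser.parse_args(['--persistent_workers', '--prefetch_factor', '32'])
print(FLAGS.persistent_workers, FLAGS.prefetch_factor)  # True 32
```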
@@ -78,7 +93,14 @@
     momentum=0.9,
     lr=0.1,
     target_accuracy=0.0,
+    persistent_workers=False,
+    prefetch_factor=16,
+    loader_prefetch_size=8,
+    device_prefetch_size=4,
+    num_workers=8,
+    host_to_device_transfer_threads=1,
 )
 
 MODEL_SPECIFIC_DEFAULTS = {
     # Override some of the args in DEFAULT_KWARGS, or add them to the dict
     # if they don't exist.
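Note: per the comment above, model-specific entries override the shared defaults. A minimal sketch of that merge, assuming a plain dict update; the `resnet50` override and `model_name` variable are illustrative, not from this diff:

```python
# Sketch (assumption): model-specific defaults layered over DEFAULT_KWARGS.
DEFAULT_KWARGS = dict(
    num_workers=8,
    prefetch_factor=16,
    persistent_workers=False,
)
MODEL_SPECIFIC_DEFAULTS = {
    'resnet50': {'prefetch_factor': 32},  # hypothetical override
}

model_name = 'resnet50'
defaults = {**DEFAULT_KWARGS, **MODEL_SPECIFIC_DEFAULTS.get(model_name, {})}
print(defaults['prefetch_factor'])  # 32
```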
@@ -163,19 +185,22 @@ def train_imagenet():
           normalize,
       ]))
 
   # For single-host SPMD, no data sampler is needed.
   train_loader = torch.utils.data.DataLoader(
       train_dataset,
       batch_size=FLAGS.batch_size,
       sampler=None,
       drop_last=FLAGS.drop_last,
-      shuffle=True)
+      shuffle=True,
+      num_workers=FLAGS.num_workers,
+      persistent_workers=FLAGS.persistent_workers,
+      prefetch_factor=FLAGS.prefetch_factor)
   test_loader = torch.utils.data.DataLoader(
       test_dataset,
       batch_size=FLAGS.test_set_batch_size,
       sampler=None,
       drop_last=FLAGS.drop_last,
-      shuffle=False)
+      shuffle=True,
+      num_workers=FLAGS.num_workers,
+      persistent_workers=FLAGS.persistent_workers,
+      prefetch_factor=FLAGS.prefetch_factor)
 
   torch.manual_seed(42)
@@ -251,7 +276,10 @@ def train_imagenet():
   train_loader = pl.MpDeviceLoader(
       train_loader,
       device,
-      input_sharding=xs.ShardingSpec(input_mesh, (0, 1, 2, 3)))
+      input_sharding=xs.ShardingSpec(input_mesh, (0, 1, 2, 3)),
+      loader_prefetch_size=FLAGS.loader_prefetch_size,
+      device_prefetch_size=FLAGS.device_prefetch_size,
+      host_to_device_transfer_threads=FLAGS.host_to_device_transfer_threads)
 
   writer = None
   if xm.is_master_ordinal():
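Note: `MpDeviceLoader` forwards these keyword arguments to torch_xla's `ParallelLoader`, which pipelines host-side batch loading, host-to-device transfer, and device-side prefetch. A hedged sketch of the resulting usage, assuming a `train_loader` DataLoader like the one built above; the loop body is elided:

```python
# Sketch (hedged): the three knobs control the ParallelLoader pipeline:
# loader_prefetch_size bounds the host-side queue of batches pulled from the
# DataLoader, host_to_device_transfer_threads sets how many threads copy
# batches to the device, and device_prefetch_size bounds the queue of batches
# already resident on the device.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
device_loader = pl.MpDeviceLoader(
    train_loader,  # the torch DataLoader built earlier
    device,
    loader_prefetch_size=8,
    device_prefetch_size=4,
    host_to_device_transfer_threads=1)

for step, (data, target) in enumerate(device_loader):
  ...  # data and target already live on the XLA device
```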
