-
Notifications
You must be signed in to change notification settings - Fork 848
/
Copy pathdlrm_main.py
729 lines (666 loc) · 25.2 KB
/
dlrm_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import itertools
import os
import sys
from dataclasses import dataclass, field
from enum import Enum
from typing import Iterator, List, Optional
import torch
import torchmetrics as metrics
from pyre_extensions import none_throws
from torch import distributed as dist
from torch.utils.data import DataLoader
from torchrec import EmbeddingBagCollection
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.distributed import TrainPipelineSparseDist
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.model_parallel import (
DistributedModelParallel,
get_default_sharders,
)
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.storage_reservations import (
HeuristicalStorageReservation,
)
from torchrec.models.dlrm import DLRM, DLRM_DCN, DLRM_Projection, DLRMTrain
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
from torchrec.optim.optimizers import in_backward_optimizer_filter
from tqdm import tqdm
# OSS import
try:
# pyre-ignore[21]
# @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm/data:dlrm_dataloader
from data.dlrm_dataloader import get_dataloader
# pyre-ignore[21]
# @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm:lr_scheduler
from lr_scheduler import LRPolicyScheduler
# pyre-ignore[21]
# @manual=//ai_codesign/benchmarks/dlrm/torchrec_dlrm:multi_hot
from multi_hot import Multihot, RestartableMap
except ImportError:
pass
# internal import
try:
from .data.dlrm_dataloader import get_dataloader # noqa F811
from .lr_scheduler import LRPolicyScheduler # noqa F811
from .multi_hot import Multihot, RestartableMap # noqa F811
except ImportError:
pass
# Number of stages in TrainPipelineSparseDist.
TRAIN_PIPELINE_STAGES: int = 3
class InteractionType(Enum):
    """Feature-interaction variant selected with ``--interaction_type``.

    ``str()`` of a member is its raw value (e.g. ``"dcn"``), which is what
    users type on the command line and what argparse shows for the choices.
    """

    ORIGINAL = "original"
    DCN = "dcn"
    PROJECTION = "projection"

    def __str__(self) -> str:
        return str(self.value)
def parse_args(argv: List[str]) -> argparse.Namespace:
    """
    Parses the command line for the DLRM trainer.

    List-valued options (layer sizes, per-feature embedding counts, multi-hot
    sizes) are returned here as raw comma-separated strings; ``main`` converts
    them to lists of ints after parsing.

    Args:
        argv (List[str]): command line arguments, excluding the program name.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="torchrec dlrm example trainer")
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="number of epochs to train",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="batch size to use for training",
    )
    parser.add_argument(
        "--drop_last_training_batch",
        dest="drop_last_training_batch",
        action="store_true",
        help="Drop the last non-full training batch",
    )
    parser.add_argument(
        "--test_batch_size",
        type=int,
        default=None,
        help="batch size to use for validation and testing",
    )
    parser.add_argument(
        "--limit_train_batches",
        type=int,
        default=None,
        help="number of train batches",
    )
    parser.add_argument(
        "--limit_val_batches",
        type=int,
        default=None,
        help="number of validation batches",
    )
    parser.add_argument(
        "--limit_test_batches",
        type=int,
        default=None,
        help="number of test batches",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        choices=["criteo_1t", "criteo_kaggle"],
        default="criteo_1t",
        # The help text must name the actual accepted choices ("criteo_1t",
        # not "criteo_1tb", which argparse would reject).
        help="dataset for experiment, currently supports criteo_1t, criteo_kaggle",
    )
    parser.add_argument(
        "--num_embeddings",
        type=int,
        default=100_000,
        help="max_ind_size. The number of embeddings in each embedding table. Defaults"
        " to 100_000 if num_embeddings_per_feature is not supplied.",
    )
    parser.add_argument(
        "--num_embeddings_per_feature",
        type=str,
        default=None,
        help="Comma separated max_ind_size per sparse feature. The number of embeddings"
        " in each embedding table. 26 values are expected for the Criteo dataset.",
    )
    parser.add_argument(
        "--dense_arch_layer_sizes",
        type=str,
        default="512,256,64",
        help="Comma separated layer sizes for dense arch.",
    )
    parser.add_argument(
        "--over_arch_layer_sizes",
        type=str,
        default="512,512,256,1",
        help="Comma separated layer sizes for over arch.",
    )
    parser.add_argument(
        "--embedding_dim",
        type=int,
        default=64,
        help="Size of each embedding.",
    )
    parser.add_argument(
        "--interaction_branch1_layer_sizes",
        type=str,
        default="2048,2048",
        help="Comma separated layer sizes for interaction branch1 (only on dlrm with projection).",
    )
    parser.add_argument(
        "--interaction_branch2_layer_sizes",
        type=str,
        default="2048,2048",
        help="Comma separated layer sizes for interaction branch2 (only on dlrm with projection).",
    )
    parser.add_argument(
        "--dcn_num_layers",
        type=int,
        default=3,
        help="Number of DCN layers in interaction layer (only on dlrm with DCN).",
    )
    parser.add_argument(
        "--dcn_low_rank_dim",
        type=int,
        default=512,
        help="Low rank dimension for DCN in interaction layer (only on dlrm with DCN).",
    )
    parser.add_argument(
        "--undersampling_rate",
        type=float,
        help="Desired proportion of zero-labeled samples to retain (i.e. undersampling zero-labeled rows)."
        " Ex. 0.3 indicates only 30pct of the rows with label 0 will be kept."
        " All rows with label 1 will be kept. Value should be between 0 and 1."
        " When not supplied, no undersampling occurs.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--pin_memory",
        dest="pin_memory",
        action="store_true",
        help="Use pinned memory when loading data.",
    )
    parser.add_argument(
        "--mmap_mode",
        dest="mmap_mode",
        action="store_true",
        help="--mmap_mode mmaps the dataset."
        " That is, the dataset is kept on disk but is accessed as if it were in memory."
        " --mmap_mode is intended mostly for faster debugging. Use --mmap_mode to bypass"
        " preloading the dataset when preloading takes too long or when there is "
        " insufficient memory available to load the full dataset.",
    )
    parser.add_argument(
        "--in_memory_binary_criteo_path",
        type=str,
        default=None,
        help="Directory path containing the Criteo dataset npy files.",
    )
    parser.add_argument(
        "--synthetic_multi_hot_criteo_path",
        type=str,
        default=None,
        help="Directory path containing the MLPerf v2 synthetic multi-hot dataset npz files.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=15.0,
        help="Learning rate.",
    )
    parser.add_argument(
        "--eps",
        type=float,
        default=1e-8,
        help="Epsilon for Adagrad optimizer.",
    )
    parser.add_argument(
        "--shuffle_batches",
        dest="shuffle_batches",
        action="store_true",
        help="Shuffle each batch during training.",
    )
    parser.add_argument(
        "--shuffle_training_set",
        dest="shuffle_training_set",
        action="store_true",
        help="Shuffle the training set in memory. This will override mmap_mode",
    )
    parser.add_argument(
        "--validation_freq_within_epoch",
        type=int,
        default=None,
        help="Frequency at which validation will be run within an epoch.",
    )
    # Overrides the store_true defaults above, so these flags are None (not
    # False) when not passed on the command line.
    # NOTE(review): "drop_last" does not match any argument's dest (the flag
    # above is --drop_last_training_batch); it only adds an extra
    # `drop_last=None` attribute. Left unchanged in case the dataloader reads
    # `args.drop_last` — confirm intent.
    parser.set_defaults(
        pin_memory=None,
        mmap_mode=None,
        drop_last=None,
        shuffle_batches=None,
        shuffle_training_set=None,
    )
    parser.add_argument(
        "--adagrad",
        dest="adagrad",
        action="store_true",
        help="Flag to determine if adagrad optimizer should be used.",
    )
    parser.add_argument(
        "--interaction_type",
        type=InteractionType,
        choices=list(InteractionType),
        default=InteractionType.ORIGINAL,
        help="Determine the interaction type to be used (original, dcn, or projection)"
        " default is original DLRM with pairwise dot product",
    )
    parser.add_argument(
        "--collect_multi_hot_freqs_stats",
        dest="collect_multi_hot_freqs_stats",
        action="store_true",
        help="Flag to determine whether to collect stats on freq of embedding access.",
    )
    parser.add_argument(
        "--multi_hot_sizes",
        type=str,
        default=None,
        help="Comma separated multihot size per sparse feature. 26 values are expected for the Criteo dataset.",
    )
    parser.add_argument(
        "--multi_hot_distribution_type",
        type=str,
        choices=["uniform", "pareto"],
        default=None,
        help="Multi-hot distribution options.",
    )
    parser.add_argument("--lr_warmup_steps", type=int, default=0)
    parser.add_argument("--lr_decay_start", type=int, default=0)
    parser.add_argument("--lr_decay_steps", type=int, default=0)
    parser.add_argument(
        "--print_lr",
        action="store_true",
        help="Print learning rate every iteration.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help="Enable TensorFloat-32 mode for matrix multiplications on A100 (or newer) GPUs.",
    )
    parser.add_argument(
        "--print_sharding_plan",
        action="store_true",
        help="Print the sharding plan used for each embedding table.",
    )
    return parser.parse_args(argv)
def _evaluate(
    limit_batches: Optional[int],
    pipeline: TrainPipelineSparseDist,
    eval_dataloader: DataLoader,
    stage: str,
) -> float:
    """
    Evaluates model. Computes and prints AUROC. Helper function for train_val_test.

    Switches the pipeline's model to eval mode and drives it over (at most
    ``limit_batches`` batches of) ``eval_dataloader`` under ``torch.no_grad()``.
    The model is left in eval mode; callers that keep training must call
    ``.train()`` again.

    Args:
        limit_batches (Optional[int]): Limits the dataloader to the first `limit_batches` batches.
            ``None`` evaluates the whole dataloader.
        pipeline (TrainPipelineSparseDist): data pipeline.
        eval_dataloader (DataLoader): Dataloader for either the validation set or test set.
        stage (str): "val" or "test". Only used in printed messages.

    Returns:
        float: auroc result (the value of ``auroc.compute()`` on this rank).
    """
    pipeline._model.eval()
    device = pipeline._device
    iterator = itertools.islice(iter(eval_dataloader), limit_batches)

    auroc = metrics.AUROC(compute_on_step=False, num_classes=2).to(device)

    is_rank_zero = dist.get_rank() == 0
    pbar = None
    if is_rank_zero:
        # The bar is advanced manually, so it only needs a total. Cap it at
        # `limit_batches` so it matches the number of batches actually run.
        num_batches = len(eval_dataloader)
        if limit_batches is not None:
            num_batches = min(num_batches, limit_batches)
        pbar = tqdm(
            total=num_batches,
            desc=f"Evaluating {stage} set",
            disable=False,
        )

    try:
        with torch.no_grad():
            while True:
                try:
                    # The pipeline raises StopIteration once the (sliced)
                    # iterator and its in-flight batches are exhausted.
                    _loss, logits, labels = pipeline.progress(iterator)
                except StopIteration:
                    break
                preds = torch.sigmoid(logits)
                auroc(preds, labels)
                if pbar is not None:
                    pbar.update(1)
    finally:
        if pbar is not None:
            pbar.close()

    auroc_result = auroc.compute().item()
    # Presumably `auroc.target` holds the per-batch targets accumulated on
    # this rank; the summed lengths are reduced onto rank 0 for the total count.
    num_samples = torch.tensor(sum(map(len, auroc.target)), device=device)
    dist.reduce(num_samples, 0, op=dist.ReduceOp.SUM)
    if is_rank_zero:
        print(f"AUROC over {stage} set: {auroc_result}.")
        # .item() so the count prints as a plain integer, not tensor(..., device=...).
        print(f"Number of {stage} samples: {num_samples.item()}")
    return auroc_result
def batched(it: Iterator, n: int) -> Iterator[Iterator]:
    """
    Lazily splits ``it`` into consecutive chunks of at most ``n`` items.

    Unlike ``itertools.batched``, each chunk is itself a lazy iterator that
    pulls straight from the shared underlying iterator (so e.g. a training
    pipeline can consume it batch by batch). A chunk must therefore be fully
    consumed before the next one is requested; otherwise its leftover items
    spill into the following chunk.

    Args:
        it (Iterator): source of items. Any iterable is accepted; re-iterable
            inputs such as lists are converted to a single iterator first.
        n (int): maximum chunk size; must be >= 1.

    Yields:
        Iterator: the next chunk of up to ``n`` items.

    Raises:
        ValueError: if ``n < 1`` (raised on the first ``next()``, as this is a generator).
    """
    if n < 1:
        raise ValueError(f"batched() requires n >= 1, got {n}")
    # iter() is a no-op for iterators, but without it a list would be
    # restarted from the beginning by both the `for` loop and islice().
    it = iter(it)
    for first in it:
        yield itertools.chain((first,), itertools.islice(it, n - 1))
def _train(
    pipeline: TrainPipelineSparseDist,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    epoch: int,
    lr_scheduler,
    print_lr: bool,
    validation_freq: Optional[int],
    limit_train_batches: Optional[int],
    limit_val_batches: Optional[int],
) -> None:
    """
    Trains model for 1 epoch. Helper function for train_val_test.

    The (possibly limited) training iterator is cut into chunks of
    ``validation_freq`` batches; when ``validation_freq`` is not set, the whole
    epoch is a single chunk. After each full chunk, a validation pass is run and
    the model is put back into train mode. A trailing partial chunk does not
    trigger validation.

    Args:
        pipeline (TrainPipelineSparseDist): data pipeline.
        train_dataloader (DataLoader): Training set's dataloader.
        val_dataloader (DataLoader): Validation set's dataloader.
        epoch (int): The number of complete passes through the training set so far.
        lr_scheduler (LRPolicyScheduler): Learning rate scheduler; stepped once per training step.
        print_lr (bool): Whether to print the learning rate every training step (rank 0 only).
        validation_freq (Optional[int]): The number of training steps between validation runs within an epoch.
        limit_train_batches (Optional[int]): Limits the training set to the first `limit_train_batches` batches.
        limit_val_batches (Optional[int]): Limits the validation set to the first `limit_val_batches` batches.

    Returns:
        None.
    """
    pipeline._model.train()
    iterator = itertools.islice(iter(train_dataloader), limit_train_batches)

    is_rank_zero = dist.get_rank() == 0
    pbar = None
    if is_rank_zero:
        # Manually advanced bar; cap the total at the batch limit, if any.
        num_batches = len(train_dataloader)
        if limit_train_batches is not None:
            num_batches = min(num_batches, limit_train_batches)
        pbar = tqdm(
            total=num_batches,
            desc=f"Epoch {epoch}",
            disable=False,
        )

    # Chunk size: validate every `validation_freq` steps when requested;
    # otherwise the whole (possibly limited) epoch is one chunk.
    if validation_freq:
        n = validation_freq
    elif limit_train_batches:
        n = limit_train_batches
    else:
        n = len(train_dataloader)
    # Guard against an empty dataloader (n == 0), which batched() rejects;
    # the iterator is then empty anyway, so a chunk size of 1 is harmless.
    n = max(n, 1)

    # `start_it` is the cumulative number of completed steps this epoch. When
    # a chunk is exhausted, `it` equals the step count so far.
    start_it = 0
    try:
        for batched_iterator in batched(iterator, n):
            for it in itertools.count(start_it):
                try:
                    if is_rank_zero and print_lr:
                        for i, g in enumerate(pipeline._optimizer.param_groups):
                            print(f"lr: {it} {i} {g['lr']:.6f}")
                    pipeline.progress(batched_iterator)
                    lr_scheduler.step()
                    if pbar is not None:
                        pbar.update(1)
                except StopIteration:
                    if is_rank_zero:
                        print("Total number of iterations:", it)
                    start_it = it
                    break
            # Only full chunks land on a multiple of validation_freq.
            if validation_freq and start_it % validation_freq == 0:
                _evaluate(limit_val_batches, pipeline, val_dataloader, "val")
                pipeline._model.train()
    finally:
        if pbar is not None:
            pbar.close()
@dataclass
class TrainValTestResults:
    """Metrics gathered by ``train_val_test``.

    Attributes:
        val_aurocs: validation AUROC appended after each training epoch, in order.
        test_auroc: test-set AUROC; ``None`` until the test pass has run.
    """

    # A fresh list per instance, never shared between results objects.
    val_aurocs: List[float] = field(default_factory=lambda: [])
    test_auroc: Optional[float] = field(default=None)
def train_val_test(
    args: argparse.Namespace,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    test_dataloader: DataLoader,
    lr_scheduler: LRPolicyScheduler,
) -> TrainValTestResults:
    """
    Runs the whole train/validation/test schedule.

    For each of ``args.epochs`` epochs, runs one training pass (with optional
    mid-epoch validation) followed by a validation pass whose AUROC is
    recorded. After the final epoch, one test pass is run. Every pass shares a
    single TrainPipelineSparseDist built with ``execute_all_batches=True``.

    Args:
        args (argparse.Namespace): parsed command line args.
        model (torch.nn.Module): model to train.
        optimizer (torch.optim.Optimizer): optimizer to use.
        device (torch.device): device to use.
        train_dataloader (DataLoader): Training set's dataloader.
        val_dataloader (DataLoader): Validation set's dataloader.
        test_dataloader (DataLoader): Test set's dataloader.
        lr_scheduler (LRPolicyScheduler): Learning rate scheduler.

    Returns:
        TrainValTestResults: per-epoch validation AUROCs and the final test AUROC.
    """
    pipeline = TrainPipelineSparseDist(
        model, optimizer, device, execute_all_batches=True
    )
    results = TrainValTestResults()

    for epoch_idx in range(args.epochs):
        _train(
            pipeline=pipeline,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            epoch=epoch_idx,
            lr_scheduler=lr_scheduler,
            print_lr=args.print_lr,
            validation_freq=args.validation_freq_within_epoch,
            limit_train_batches=args.limit_train_batches,
            limit_val_batches=args.limit_val_batches,
        )
        results.val_aurocs.append(
            _evaluate(args.limit_val_batches, pipeline, val_dataloader, "val")
        )

    results.test_auroc = _evaluate(
        args.limit_test_batches, pipeline, test_dataloader, "test"
    )
    return results
def main(argv: List[str]) -> None:
    """
    Trains, validates, and tests a Deep Learning Recommendation Model (DLRM)
    (https://arxiv.org/abs/1906.00091). The DLRM model contains both data parallel
    components (e.g. multi-layer perceptrons & interaction arch) and model parallel
    components (e.g. embedding tables). The DLRM model is pipelined so that dataloading,
    data-parallel to model-parallel comms, and forward/backward are overlapped. Can be
    run with either a random dataloader or an in-memory Criteo 1 TB click logs dataset
    (https://ailab.criteo.com/download-criteo-1tb-click-logs-dataset/).

    The process must have the ``LOCAL_RANK`` environment variable set (it is read
    unconditionally below). This function also initializes the default
    ``torch.distributed`` process group.

    Args:
        argv (List[str]): command line args.

    Returns:
        None.
    """
    args = parse_args(argv)
    # Convert every comma-separated integer string (e.g. "512,256,64") into a
    # list of ints; values that are not such strings are left untouched.
    for name, val in vars(args).items():
        try:
            vars(args)[name] = list(map(int, val.split(",")))
        except (ValueError, AttributeError):
            pass

    torch.backends.cuda.matmul.allow_tf32 = args.allow_tf32

    if args.multi_hot_sizes is not None:
        assert (
            args.num_embeddings_per_feature is not None
            and len(args.multi_hot_sizes) == len(args.num_embeddings_per_feature)
        ) or (
            args.num_embeddings_per_feature is None
            and len(args.multi_hot_sizes) == len(DEFAULT_CAT_NAMES)
        ), "--multi_hot_sizes must be a comma delimited list the same size as the number of embedding tables."
    assert (
        args.in_memory_binary_criteo_path is None
        or args.synthetic_multi_hot_criteo_path is None
    ), "--in_memory_binary_criteo_path and --synthetic_multi_hot_criteo_path are mutually exclusive CLI arguments."
    assert (
        args.multi_hot_sizes is None or args.synthetic_multi_hot_criteo_path is None
    ), "--multi_hot_sizes is used to convert 1-hot to multi-hot. It's inapplicable with --synthetic_multi_hot_criteo_path."
    assert (
        args.multi_hot_distribution_type is None
        or args.synthetic_multi_hot_criteo_path is None
    ), "--multi_hot_distribution_type is used to convert 1-hot to multi-hot. It's inapplicable with --synthetic_multi_hot_criteo_path."

    rank = int(os.environ["LOCAL_RANK"])
    device: torch.device
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        backend = "nccl"
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
        backend = "gloo"

    if rank == 0:
        print(
            "PARAMS: (lr, batch_size, warmup_steps, decay_start, decay_steps): "
            f"{(args.learning_rate, args.batch_size, args.lr_warmup_steps, args.lr_decay_start, args.lr_decay_steps)}"
        )
    dist.init_process_group(backend=backend)

    # Per-feature sizes take precedence over the single --num_embeddings value.
    if args.num_embeddings_per_feature is not None:
        args.num_embeddings = None

    # Sets default limits for random dataloader iterations when left unspecified
    # (i.e. when neither dataset path was given).
    if (
        args.in_memory_binary_criteo_path is None
        and args.synthetic_multi_hot_criteo_path is None
    ):
        for split in ["train", "val", "test"]:
            attr = f"limit_{split}_batches"
            if getattr(args, attr) is None:
                setattr(args, attr, 10)

    train_dataloader = get_dataloader(args, backend, "train")
    val_dataloader = get_dataloader(args, backend, "val")
    test_dataloader = get_dataloader(args, backend, "test")

    # One embedding table per categorical feature.
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=args.embedding_dim,
            num_embeddings=(
                none_throws(args.num_embeddings_per_feature)[feature_idx]
                if args.num_embeddings is None
                else args.num_embeddings
            ),
            feature_names=[feature_name],
        )
        for feature_idx, feature_name in enumerate(DEFAULT_CAT_NAMES)
    ]

    # Embedding tables are created on the "meta" device here; they are
    # materialized later by DistributedModelParallel according to the plan.
    if args.interaction_type == InteractionType.ORIGINAL:
        dlrm_model = DLRM(
            embedding_bag_collection=EmbeddingBagCollection(
                tables=eb_configs, device=torch.device("meta")
            ),
            dense_in_features=len(DEFAULT_INT_NAMES),
            dense_arch_layer_sizes=args.dense_arch_layer_sizes,
            over_arch_layer_sizes=args.over_arch_layer_sizes,
            dense_device=device,
        )
    elif args.interaction_type == InteractionType.DCN:
        dlrm_model = DLRM_DCN(
            embedding_bag_collection=EmbeddingBagCollection(
                tables=eb_configs, device=torch.device("meta")
            ),
            dense_in_features=len(DEFAULT_INT_NAMES),
            dense_arch_layer_sizes=args.dense_arch_layer_sizes,
            over_arch_layer_sizes=args.over_arch_layer_sizes,
            dcn_num_layers=args.dcn_num_layers,
            dcn_low_rank_dim=args.dcn_low_rank_dim,
            dense_device=device,
        )
    elif args.interaction_type == InteractionType.PROJECTION:
        dlrm_model = DLRM_Projection(
            embedding_bag_collection=EmbeddingBagCollection(
                tables=eb_configs, device=torch.device("meta")
            ),
            dense_in_features=len(DEFAULT_INT_NAMES),
            dense_arch_layer_sizes=args.dense_arch_layer_sizes,
            over_arch_layer_sizes=args.over_arch_layer_sizes,
            interaction_branch1_layer_sizes=args.interaction_branch1_layer_sizes,
            interaction_branch2_layer_sizes=args.interaction_branch2_layer_sizes,
            dense_device=device,
        )
    else:
        raise ValueError(
            "Unknown interaction option set. Should be original, dcn, or projection."
        )

    train_model = DLRMTrain(dlrm_model)
    embedding_optimizer = torch.optim.Adagrad if args.adagrad else torch.optim.SGD
    # This will apply the Adagrad optimizer in the backward pass for the embeddings (sparse_arch). This means that
    # the optimizer update will be applied in the backward pass, in this case through a fused op.
    # TorchRec will use the FBGEMM implementation of EXACT_ADAGRAD. For GPU devices, a fused CUDA kernel is invoked. For CPU, FBGEMM_GPU invokes CPU kernels
    # https://github.com/pytorch/FBGEMM/blob/2cb8b0dff3e67f9a009c4299defbd6b99cc12b8f/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py#L676-L678
    # Note that lr_decay, weight_decay and initial_accumulator_value for Adagrad optimizer in FBGEMM v0.3.2
    # cannot be specified below. This equivalently means that all these parameters are hardcoded to zero.
    optimizer_kwargs = {"lr": args.learning_rate}
    if args.adagrad:
        optimizer_kwargs["eps"] = args.eps
    apply_optimizer_in_backward(
        embedding_optimizer,
        train_model.model.sparse_arch.parameters(),
        optimizer_kwargs,
    )

    planner = EmbeddingShardingPlanner(
        topology=Topology(
            local_world_size=get_local_size(),
            world_size=dist.get_world_size(),
            compute_device=device.type,
        ),
        batch_size=args.batch_size,
        # If experience OOM, increase the percentage. see
        # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation
        storage_reservation=HeuristicalStorageReservation(percentage=0.05),
    )
    plan = planner.collective_plan(
        train_model, get_default_sharders(), dist.GroupMember.WORLD
    )

    model = DistributedModelParallel(
        module=train_model,
        device=device,
        plan=plan,
    )
    if rank == 0 and args.print_sharding_plan:
        for collectionkey, plans in model._plan.plan.items():
            print(collectionkey)
            # Distinct loop name so the outer `plan` is not rebound.
            for table_name, param_sharding in plans.items():
                print(table_name, "\n", param_sharding, "\n")

    def optimizer_with_params():
        # Dense (non-embedding) parameters use the same optimizer family and
        # hyperparameters as the in-backward embedding optimizer above.
        if args.adagrad:
            return lambda params: torch.optim.Adagrad(
                params, lr=args.learning_rate, eps=args.eps
            )
        else:
            return lambda params: torch.optim.SGD(params, lr=args.learning_rate)

    # NOTE(review): in_backward_optimizer_filter presumably drops the params that
    # already received an in-backward optimizer above — confirm against torchrec.
    dense_optimizer = KeyedOptimizerWrapper(
        dict(in_backward_optimizer_filter(model.named_parameters())),
        optimizer_with_params(),
    )
    optimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])
    lr_scheduler = LRPolicyScheduler(
        optimizer, args.lr_warmup_steps, args.lr_decay_start, args.lr_decay_steps
    )

    # Only created when 1-hot -> multi-hot conversion is requested; stays None
    # otherwise so the stats-saving step below can tell the difference.
    multihot = None
    if args.multi_hot_sizes is not None:
        multihot = Multihot(
            args.multi_hot_sizes,
            args.num_embeddings_per_feature,
            args.batch_size,
            collect_freqs_stats=args.collect_multi_hot_freqs_stats,
            dist_type=args.multi_hot_distribution_type,
        )
        multihot.pause_stats_collection_during_val_and_test(model)
        train_dataloader = RestartableMap(
            multihot.convert_to_multi_hot, train_dataloader
        )
        val_dataloader = RestartableMap(multihot.convert_to_multi_hot, val_dataloader)
        test_dataloader = RestartableMap(multihot.convert_to_multi_hot, test_dataloader)

    train_val_test(
        args,
        model,
        optimizer,
        device,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        lr_scheduler,
    )

    if args.collect_multi_hot_freqs_stats:
        if multihot is not None:
            multihot.save_freqs_stats()
        elif rank == 0:
            # Previously this path raised UnboundLocalError after training.
            print(
                "Skipping multi-hot frequency stats: "
                "--collect_multi_hot_freqs_stats requires --multi_hot_sizes."
            )
def invoke_main() -> None:
    """Script entry point: hands the CLI arguments (without the program name) to ``main``."""
    cli_args = sys.argv[1:]
    main(cli_args)


if __name__ == "__main__":
    invoke_main()  # pragma: no cover