diff --git a/extract_features_pytorch.py b/extract_features_pytorch.py index 53a91ae48f7f..0c7b6b8bd99f 100644 --- a/extract_features_pytorch.py +++ b/extract_features_pytorch.py @@ -249,6 +249,9 @@ def main(): if args.init_checkpoint is not None: model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.to(device) + + if n_gpu > 1: + model = nn.DataParallel(model) all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py index 5d283d3415a5..a9766b9df8bd 100644 --- a/run_classifier_pytorch.py +++ b/run_classifier_pytorch.py @@ -482,6 +482,9 @@ def main(): if args.init_checkpoint is not None: model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.to(device) + + if n_gpu > 1: + model = torch.nn.DataParallel(model) optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01}, {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.} diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py index 626759a08575..467931f68a58 100644 --- a/run_squad_pytorch.py +++ b/run_squad_pytorch.py @@ -795,6 +795,9 @@ def main(): if args.init_checkpoint is not None: model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')) model.to(device) + + if n_gpu > 1: + model = torch.nn.DataParallel(model) optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01}, {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}