From 20294bdd7e13d1fb4f3989eecdcd5ad788f1caea Mon Sep 17 00:00:00 2001
From: VictorSanh
Date: Sat, 3 Nov 2018 10:10:01 -0400
Subject: [PATCH] Create DataParallel model if several GPUs

---
 extract_features_pytorch.py | 3 +++
 run_classifier_pytorch.py   | 3 +++
 run_squad_pytorch.py        | 3 +++
 3 files changed, 9 insertions(+)

diff --git a/extract_features_pytorch.py b/extract_features_pytorch.py
index 53a91ae48f7..0c7b6b8bd99 100644
--- a/extract_features_pytorch.py
+++ b/extract_features_pytorch.py
@@ -249,6 +249,9 @@ def main():
     if args.init_checkpoint is not None:
         model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
+
+    if n_gpu > 1:
+        model = nn.DataParallel(model)
 
     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
     all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
diff --git a/run_classifier_pytorch.py b/run_classifier_pytorch.py
index 5d283d3415a..a9766b9df8b 100644
--- a/run_classifier_pytorch.py
+++ b/run_classifier_pytorch.py
@@ -482,6 +482,9 @@ def main():
     if args.init_checkpoint is not None:
         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
+
+    if n_gpu > 1:
+        model = torch.nn.DataParallel(model)
 
     optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
                           {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
diff --git a/run_squad_pytorch.py b/run_squad_pytorch.py
index 626759a0857..467931f68a5 100644
--- a/run_squad_pytorch.py
+++ b/run_squad_pytorch.py
@@ -795,6 +795,9 @@ def main():
     if args.init_checkpoint is not None:
         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
+
+    if n_gpu > 1:
+        model = torch.nn.DataParallel(model)
 
     optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
                           {'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
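
Note: below is a minimal, self-contained sketch of the pattern these hunks apply: the model is wrapped in torch.nn.DataParallel only when more than one GPU is visible, after it has been moved to the device. The toy SmallNet module and the random input batch are hypothetical placeholders for illustration; they are not part of the patch or the repository.

import torch
from torch import nn

class SmallNet(nn.Module):
    """Hypothetical stand-in for the BERT models touched by the patch."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 2)

    def forward(self, x):
        return self.linear(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

model = SmallNet()
model.to(device)

# Mirror the patch: wrap only when several GPUs are available, so CPU and
# single-GPU runs keep the plain module.
if n_gpu > 1:
    model = torch.nn.DataParallel(model)

# DataParallel scatters the batch across the GPUs and gathers the outputs
# back on device 0, so the call site stays unchanged.
outputs = model(torch.randn(8, 16).to(device))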