Skip to content

Commit

Permalink
Merge pull request huggingface#1 from huggingface/multi-gpu-support
Browse files Browse the repository at this point in the history
Create a DataParallel model if several GPUs are available
  • Loading branch information
VictorSanh authored Nov 3, 2018
2 parents 6165f84 + 20294bd commit f6ed6ac
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 0 deletions.
3 changes: 3 additions & 0 deletions extract_features_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def main():
if args.init_checkpoint is not None:
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)

if n_gpu > 1:
model = nn.DataParallel(model)

all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
Expand Down
3 changes: 3 additions & 0 deletions run_classifier_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ def main():
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)

if n_gpu > 1:
model = torch.nn.DataParallel(model)

optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
Expand Down
3 changes: 3 additions & 0 deletions run_squad_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,9 @@ def main():
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model.to(device)

if n_gpu > 1:
model = torch.nn.DataParallel(model)

optimizer = BERTAdam([{'params': [p for n, p in model.named_parameters() if n != 'bias'], 'l2': 0.01},
{'params': [p for n, p in model.named_parameters() if n == 'bias'], 'l2': 0.}
Expand Down

0 comments on commit f6ed6ac

Please sign in to comment.