Move dynamo.optimize to the end of model preparation (#1128)
ymwangg authored Feb 28, 2023
1 parent fdb1402 commit 639c1da
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/accelerate/accelerator.py
@@ -1102,10 +1102,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement=None):
         self._models.append(model)
         if device_placement:
             model = model.to(self.device)
-        if self.state.dynamo_backend != DynamoBackend.NO:
-            import torch._dynamo as dynamo
-
-            model = dynamo.optimize(self.state.dynamo_backend.value.lower())(model)
         if self.distributed_type == DistributedType.MULTI_GPU:
             if any(p.requires_grad for p in model.parameters()):
                 kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
@@ -1146,6 +1142,11 @@ def prepare_model(self, model: torch.nn.Module, device_placement=None):
             model.forward = convert_outputs_to_fp32(model.forward)
         if self.distributed_type == DistributedType.TPU and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
+        # torch.compile should be called last.
+        if self.state.dynamo_backend != DynamoBackend.NO:
+            import torch._dynamo as dynamo
+
+            model = dynamo.optimize(self.state.dynamo_backend.value.lower())(model)
         return model

     def _prepare_deepspeed(self, *args):
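Net effect: prepare_model now applies dynamo.optimize only after device placement, the MULTI_GPU (DDP) wrapper, mixed-precision forward patching, and the TPU MpModelWrapper, so graph capture sees the fully prepared module. A minimal sketch of how this path is reached from user code, assuming the dynamo_backend option accepted by Accelerator around this release (backend name and model are illustrative, not part of this commit):

import torch
from accelerate import Accelerator

# Assumed knob: request a dynamo backend when constructing the Accelerator
# (the same setting can also come from the accelerate config).
accelerator = Accelerator(dynamo_backend="inductor")

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# prepare() calls prepare_model(), which after this commit applies
# dynamo.optimize(...) as its final step, i.e. on top of device placement,
# DDP wrapping, and any mixed-precision forward hooks.
model, optimizer = accelerator.prepare(model, optimizer)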
