Skip to content

Commit

Permalink
Remove keep_input_mutations (huggingface#1492)
Browse files Browse the repository at this point in the history
  • Loading branch information
astachowiczhabana authored and HolyFalafel committed Nov 26, 2024
1 parent 8e87f6f commit c23dc6c
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,11 @@ def patch_scoped_linear_all_reduce(model):

def get_torch_compiled_model(model):
if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
model.transformer = torch.compile(
model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
)
model.transformer = torch.compile(model.transformer, backend="hpu_backend")
elif model.config.model_type in ["gpt_neox"]:
model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True})
model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend")
else:
model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
model.model = torch.compile(model.model, backend="hpu_backend")
return model


Expand Down

0 comments on commit c23dc6c

Please sign in to comment.