diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 698a67ec5a..3f76170f90 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -178,13 +178,11 @@ def patch_scoped_linear_all_reduce(model):
 
 def get_torch_compiled_model(model):
     if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
-        model.transformer = torch.compile(
-            model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
-        )
+        model.transformer = torch.compile(model.transformer, backend="hpu_backend")
     elif model.config.model_type in ["gpt_neox"]:
-        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True})
+        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend")
     else:
-        model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
+        model.model = torch.compile(model.model, backend="hpu_backend")
     return model
 
 