From 9ae0251596891e874601f725e03d272fac660b01 Mon Sep 17 00:00:00 2001 From: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com> Date: Wed, 20 Nov 2024 22:46:33 +0100 Subject: [PATCH] Remove keep_input_mutations (#1492) --- examples/text-generation/utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index ed52e2ff6d..92968d72e5 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -178,13 +178,11 @@ def patch_scoped_linear_all_reduce(model): def get_torch_compiled_model(model): if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]: - model.transformer = torch.compile( - model.transformer, backend="hpu_backend", options={"keep_input_mutations": True} - ) + model.transformer = torch.compile(model.transformer, backend="hpu_backend") elif model.config.model_type in ["gpt_neox"]: - model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True}) + model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend") else: - model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True}) + model.model = torch.compile(model.model, backend="hpu_backend") return model