From b730db09edc60df7105294b155fbed5d8fa09e29 Mon Sep 17 00:00:00 2001 From: dsmertin Date: Thu, 17 Oct 2024 16:07:05 +0000 Subject: [PATCH] added gpt2 to the list for torch.compile with model.transformer --- examples/text-generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 66690c9b05..de8f27762c 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -177,7 +177,7 @@ def patch_scoped_linear_all_reduce(model): def get_torch_compiled_model(model): - if model.config.model_type in ["gpt_bigcode", "mpt", "bloom"]: + if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]: model.transformer = torch.compile( model.transformer, backend="hpu_backend", options={"keep_input_mutations": True} )