From 39e827941acb010f75f2eb5ae410ea3b0ac40218 Mon Sep 17 00:00:00 2001
From: dsmertin
Date: Thu, 17 Oct 2024 15:08:31 +0000
Subject: [PATCH 1/2] torch.compile for attributes instead of models type

---
 examples/text-generation/utils.py | 36 +++++++++++++++++++++----------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 450ef1d643..f969379bc5 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -177,13 +177,27 @@ def patch_scoped_linear_all_reduce(model):
         patch_scoped_linear_all_reduce(module)


-def get_torch_compiled_model(model):
-    if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
-        model.transformer = torch.compile(model.transformer, backend="hpu_backend")
-    elif model.config.model_type in ["gpt_neox"]:
-        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend")
+def get_torch_compiled_model(model, logger):
+    # for gpt_bigcode, mpt, bloom, gpt2 model_type
+    if hasattr(model, 'transformer'):
+        model.transformer = torch.compile(
+            model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
+        )
+    # for gpt_neox
+    elif hasattr(model, 'gpt_neox'):
+        model.gpt_neox = torch.compile(
+            model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True}
+        )
+    # for llama, mistral, mixtral, qwen2
+    elif hasattr(model, 'model'):
+        model.model = torch.compile(
+            model.model, backend="hpu_backend", options={"keep_input_mutations": True}
+        )
     else:
-        model.model = torch.compile(model.model, backend="hpu_backend")
+        logger.warning(
+            "In low performance case, please explicitly specify a module you want to wrap with `torch.compile`"
+        )
+        model = torch.compile(model, backend="hpu_backend", options={"keep_input_mutations": True})
     return model


@@ -305,9 +319,9 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             model.base_model.model = wrap_in_hpu_graph(model.base_model.model)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
         # if args.assistant_model is not None:
-        #     assistant_model = get_torch_compiled_model(assistant_model)
+        #     assistant_model = get_torch_compiled_model(assistant_model, logger)
     return model, assistant_model


@@ -372,7 +386,7 @@ def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_di
         model = wrap_in_hpu_graph(model)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)

     return model, args.assistant_model


@@ -446,9 +460,9 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
         model = setup_quantization(model, args)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
         # if args.assistant_model is not None:
-        #     assistant_model = get_torch_compiled_model(assistant_model)
+        #     assistant_model = get_torch_compiled_model(assistant_model, logger)
     return model, assistant_model


From 3a2b9e660c0b1b35398b4883eacc0d4b3e5147ef Mon Sep 17 00:00:00 2001
From: dsmertin
Date: Fri, 29 Nov 2024 08:39:21 +0000
Subject: [PATCH 2/2] added style

---
 examples/text-generation/utils.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index f969379bc5..1a8ac05176 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -179,20 +179,16 @@ def patch_scoped_linear_all_reduce(model):

 def get_torch_compiled_model(model, logger):
     # for gpt_bigcode, mpt, bloom, gpt2 model_type
-    if hasattr(model, 'transformer'):
+    if hasattr(model, "transformer"):
         model.transformer = torch.compile(
             model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
         )
     # for gpt_neox
-    elif hasattr(model, 'gpt_neox'):
-        model.gpt_neox = torch.compile(
-            model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True}
-        )
+    elif hasattr(model, "gpt_neox"):
+        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True})
     # for llama, mistral, mixtral, qwen2
-    elif hasattr(model, 'model'):
-        model.model = torch.compile(
-            model.model, backend="hpu_backend", options={"keep_input_mutations": True}
-        )
+    elif hasattr(model, "model"):
+        model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
     else:
         logger.warning(
             "In low performance case, please explicitly specify a module you want to wrap with `torch.compile`"
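
Note on the change above: the refactored get_torch_compiled_model() dispatches on the attribute the model actually exposes (transformer, gpt_neox or model) instead of branching on model.config.model_type, and falls back to compiling the whole model with a warning. A minimal standalone sketch of the same attribute-based dispatch follows; the function name compile_known_submodule and the use of torch.compile's default backend (rather than "hpu_backend" with keep_input_mutations, as in the patch) are illustrative assumptions, not part of the patch.

import logging

import torch
from torch import nn


def compile_known_submodule(model: nn.Module, logger: logging.Logger) -> nn.Module:
    """Compile whichever known submodule the model exposes, else the full model."""
    # Attribute names cover gpt2/bloom-style, gpt_neox and llama/mistral-style models.
    for attr in ("transformer", "gpt_neox", "model"):
        if hasattr(model, attr):
            # Illustrative: the patch itself passes backend="hpu_backend" and
            # options={"keep_input_mutations": True} here.
            setattr(model, attr, torch.compile(getattr(model, attr)))
            return model
    logger.warning("Performance may be low; explicitly specify a module to wrap with `torch.compile`")
    return torch.compile(model)

With a Hugging Face causal LM this would be called once after loading, e.g. model = compile_known_submodule(model, logging.getLogger(__name__)).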