diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index 9cc91022cd..db8abf34be 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -177,13 +177,23 @@ def patch_scoped_linear_all_reduce(model):
         patch_scoped_linear_all_reduce(module)


-def get_torch_compiled_model(model):
-    if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
-        model.transformer = torch.compile(model.transformer, backend="hpu_backend")
-    elif model.config.model_type in ["gpt_neox"]:
-        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend")
+def get_torch_compiled_model(model, logger):
+    # for gpt_bigcode, mpt, bloom, gpt2 model_type
+    if hasattr(model, "transformer"):
+        model.transformer = torch.compile(
+            model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
+        )
+    # for gpt_neox
+    elif hasattr(model, "gpt_neox"):
+        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True})
+    # for llama, mistral, mixtral, qwen2
+    elif hasattr(model, "model"):
+        model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
     else:
-        model.model = torch.compile(model.model, backend="hpu_backend")
+        logger.warning(
+            "In low performance case, please explicitly specify a module you want to wrap with `torch.compile`"
+        )
+        model = torch.compile(model, backend="hpu_backend", options={"keep_input_mutations": True})
     return model


@@ -306,9 +316,9 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             model.base_model.model = wrap_in_hpu_graph(model.base_model.model)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
         # if args.assistant_model is not None:
-        #     assistant_model = get_torch_compiled_model(assistant_model)
+        #     assistant_model = get_torch_compiled_model(assistant_model, logger)
     return model, assistant_model


@@ -373,7 +383,7 @@ def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_di
         model = wrap_in_hpu_graph(model)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)

     return model, args.assistant_model


@@ -447,9 +457,9 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
         model = setup_quantization(model, args)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
         # if args.assistant_model is not None:
-        #     assistant_model = get_torch_compiled_model(assistant_model)
+        #     assistant_model = get_torch_compiled_model(assistant_model, logger)
     return model, assistant_model
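
For readers without Gaudi hardware, here is a minimal, hypothetical sketch of the attribute-based dispatch this patch introduces. The names `compile_decoder_only` and `ToyCausalLM` and the `eager` backend are illustrative stand-ins, not part of the patch; the real helper is `get_torch_compiled_model` and targets `backend="hpu_backend"` with `options={"keep_input_mutations": True}`, which requires the Intel Gaudi PyTorch bridge.

# --- Minimal sketch (not part of the patch); assumes plain PyTorch on CPU ---
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def compile_decoder_only(model, logger, backend="eager"):
    # Mirror the patch: select the decoder stack by attribute instead of model_type.
    if hasattr(model, "transformer"):   # gpt_bigcode, mpt, bloom, gpt2
        model.transformer = torch.compile(model.transformer, backend=backend)
    elif hasattr(model, "gpt_neox"):    # gpt_neox
        model.gpt_neox = torch.compile(model.gpt_neox, backend=backend)
    elif hasattr(model, "model"):       # llama, mistral, mixtral, qwen2
        model.model = torch.compile(model.model, backend=backend)
    else:
        # Fallback mirrors the patch's behavior: warn, then compile the whole model.
        logger.warning("No known decoder attribute found; compiling the full model")
        model = torch.compile(model, backend=backend)
    return model


class ToyCausalLM(nn.Module):
    """Toy stand-in for a *ForCausalLM model whose decoder lives in `.model`."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(8, 8)   # pretend decoder stack
        self.lm_head = nn.Linear(8, 4)

    def forward(self, x):
        return self.lm_head(self.model(x))


toy = compile_decoder_only(ToyCausalLM(), logger)
print(type(toy.model).__name__)        # OptimizedModule: only the decoder was wrapped
print(toy(torch.randn(2, 8)).shape)    # torch.Size([2, 4])

Compiling only the decoder submodule keeps the lightweight generation glue (sampling, cache handling, `lm_head`, etc.) out of the compiled graph, and the `hasattr` checks let one code path cover all the listed architectures without enumerating `model_type` strings.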