Text-generation, model set-up: torch.compile for attributes instead of models' types (huggingface#1452)
dsmertin authored and Liangyx2 committed Jan 20, 2025
1 parent 27a5da5 commit a082ce3
Showing 1 changed file with 21 additions and 11 deletions.
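The gist of the change: instead of matching model.config.model_type against a hard-coded list of strings, the helper now checks which backbone attribute the model actually exposes. A minimal sketch of the difference for one model family; the transformers library and the gpt2 checkpoint are used purely for illustration and are not prescribed by the commit:

# Illustrative comparison of the two dispatch styles (gpt2 chosen only as an example).
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Before this commit: dispatch on the config's model_type string.
old_style = model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]

# After this commit: dispatch on the attribute the model actually exposes, which also
# covers architectures that never made it into the hard-coded list.
new_style = hasattr(model, "transformer")

assert old_style == new_style  # both True for gpt2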
32 changes: 21 additions & 11 deletions examples/text-generation/utils.py
@@ -177,13 +177,23 @@ def patch_scoped_linear_all_reduce(model):
         patch_scoped_linear_all_reduce(module)
 
 
-def get_torch_compiled_model(model):
-    if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
-        model.transformer = torch.compile(model.transformer, backend="hpu_backend")
-    elif model.config.model_type in ["gpt_neox"]:
-        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend")
+def get_torch_compiled_model(model, logger):
+    # for gpt_bigcode, mpt, bloom, gpt2 model_type
+    if hasattr(model, "transformer"):
+        model.transformer = torch.compile(
+            model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
+        )
+    # for gpt_neox
+    elif hasattr(model, "gpt_neox"):
+        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True})
+    # for llama, mistral, mixtral, qwen2
+    elif hasattr(model, "model"):
+        model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
     else:
-        model.model = torch.compile(model.model, backend="hpu_backend")
+        logger.warning(
+            "In low performance case, please explicitly specify a module you want to wrap with `torch.compile`"
+        )
+        model = torch.compile(model, backend="hpu_backend", options={"keep_input_mutations": True})
     return model


@@ -306,9 +316,9 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             model.base_model.model = wrap_in_hpu_graph(model.base_model.model)
 
     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
         # if args.assistant_model is not None:
-        #     assistant_model = get_torch_compiled_model(assistant_model)
+        #     assistant_model = get_torch_compiled_model(assistant_model, logger)
     return model, assistant_model


@@ -373,7 +383,7 @@ def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_di
         model = wrap_in_hpu_graph(model)
 
     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
 
     return model, args.assistant_model

@@ -447,9 +457,9 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
         model = setup_quantization(model, args)
 
     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
         # if args.assistant_model is not None:
-        #     assistant_model = get_torch_compiled_model(assistant_model)
+        #     assistant_model = get_torch_compiled_model(assistant_model, logger)
     return model, assistant_model


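One behavioural consequence of the new fallback branch: a model that exposes none of transformer, gpt_neox, or model is no longer an unhandled case; the whole object is compiled and a warning is logged. A sketch of that path, assuming it is run from examples/text-generation on a Gaudi machine where habana_frameworks.torch is installed so that the hpu_backend used by the helper is available; the TinyHead module and logger name here are hypothetical, not part of the commit:

import logging

import torch.nn as nn

import habana_frameworks.torch  # noqa: F401  (makes the "hpu_backend" available to torch.compile)

from utils import get_torch_compiled_model  # examples/text-generation/utils.py


class TinyHead(nn.Module):
    # Exposes neither `transformer`, `gpt_neox`, nor `model`, so the helper
    # falls through to the new else branch.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


logger = logging.getLogger(__name__)
compiled = get_torch_compiled_model(TinyHead(), logger)  # logs the warning, compiles the whole module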
