torch.compile for attributes instead of model type
dsmertin committed Oct 23, 2024
1 parent 03fa6dd commit aa268c0
Showing 1 changed file with 22 additions and 10 deletions.
examples/text-generation/utils.py (22 additions, 10 deletions)
@@ -176,15 +176,27 @@ def patch_scoped_linear_all_reduce(model):
         patch_scoped_linear_all_reduce(module)


-def get_torch_compiled_model(model):
-    if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
+def get_torch_compiled_model(model, logger):
+    # for gpt_bigcode, mpt, bloom, gpt2 model_type
+    if hasattr(model, 'transformer'):
         model.transformer = torch.compile(
             model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
         )
-    elif model.config.model_type in ["gpt_neox"]:
-        model.gpt_neox = torch.compile(model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True})
+    # for gpt_neox
+    elif hasattr(model, 'gpt_neox'):
+        model.gpt_neox = torch.compile(
+            model.gpt_neox, backend="hpu_backend", options={"keep_input_mutations": True}
+        )
+    # for llama, mistral, mixtral, qwen2
+    elif hasattr(model, 'model'):
+        model.model = torch.compile(
+            model.model, backend="hpu_backend", options={"keep_input_mutations": True}
+        )
     else:
-        model.model = torch.compile(model.model, backend="hpu_backend", options={"keep_input_mutations": True})
+        logger.warning(
+            "in case of low performance, please explicitly specify a module you want to wrap with `torch.compile`"
+        )
+        model = torch.compile(model, backend="hpu_backend", options={"keep_input_mutations": True})
     return model


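The rewritten helper above no longer keys off model.config.model_type; it probes the attributes that the supported architectures actually expose (transformer for gpt_bigcode/mpt/bloom/gpt2, gpt_neox for GPT-NeoX, model for llama/mistral/mixtral/qwen2) and, if none is found, warns and compiles the whole model. Below is a minimal, self-contained sketch of that dispatch pattern, not the Habana code itself: it uses torch.compile's default inductor backend instead of hpu_backend so it runs without Gaudi hardware, and the toy classes and the compile_submodule helper are illustrative names only.

import logging

import torch
import torch.nn as nn

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("text-generation")


class ToyLlamaStyle(nn.Module):
    # Llama/Mistral/Mixtral/Qwen2-style causal LMs expose their decoder stack as `self.model`.
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(8, 8)
        self.lm_head = nn.Linear(8, 8)

    def forward(self, x):
        return self.lm_head(self.model(x))


class ToyUnknownArch(nn.Module):
    # No `transformer`, `gpt_neox`, or `model` attribute, so the fallback branch fires.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)

    def forward(self, x):
        return self.backbone(x)


def compile_submodule(model, backend="inductor"):
    # Same attribute-based dispatch order as the patched get_torch_compiled_model.
    for attr in ("transformer", "gpt_neox", "model"):
        if hasattr(model, attr):
            setattr(model, attr, torch.compile(getattr(model, attr), backend=backend))
            return model
    logger.warning("No known submodule found; compiling the whole model instead.")
    return torch.compile(model, backend=backend)


llama_like = compile_submodule(ToyLlamaStyle())  # compiles only the decoder stack
unknown = compile_submodule(ToyUnknownArch())    # warns, then wraps the whole model

The fallback mirrors the warning added in the commit: compiling the whole model still works, but the message suggests explicitly picking a submodule when performance is poor.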
@@ -304,9 +316,9 @@ def setup_model(args, model_dtype, model_kwargs, logger):
         model.base_model.model = wrap_in_hpu_graph(model.base_model.model)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
     # if args.assistant_model is not None:
-    #     assistant_model = get_torch_compiled_model(assistant_model)
+    #     assistant_model = get_torch_compiled_model(assistant_model, logger)
     return model, assistant_model


@@ -371,7 +383,7 @@ def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_di
         model = wrap_in_hpu_graph(model)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)

     return model, args.assistant_model

@@ -445,9 +457,9 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):
         model = setup_quantization(model, args)

     if args.torch_compile:
-        model = get_torch_compiled_model(model)
+        model = get_torch_compiled_model(model, logger)
     # if args.assistant_model is not None:
-    #     assistant_model = get_torch_compiled_model(assistant_model)
+    #     assistant_model = get_torch_compiled_model(assistant_model, logger)
     return model, assistant_model

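The remaining hunks only thread the logger into the existing call sites so the fallback branch can emit its warning. A small hypothetical sketch of that call pattern follows; args, the stub helper, and the model placeholder are illustrative, not code from the repository.

import argparse
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("text-generation")


def get_torch_compiled_model(model, logger):
    # Stub standing in for the helper patched in examples/text-generation/utils.py above.
    return model


parser = argparse.ArgumentParser()
parser.add_argument("--torch_compile", action="store_true")
args = parser.parse_args(["--torch_compile"])

model = object()  # placeholder for the loaded checkpoint
if args.torch_compile:
    # Every call site now passes the logger alongside the model.
    model = get_torch_compiled_model(model, logger)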

