From 529bda8367d1028a0624480e1e07f083778be63b Mon Sep 17 00:00:00 2001 From: Dmitry Date: Fri, 18 Oct 2024 10:42:49 +0200 Subject: [PATCH] GPT2 torch.compile fix (#1434) --- examples/text-generation/utils.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index c3a270c86a..2b0fe0d328 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -177,7 +177,7 @@ def patch_scoped_linear_all_reduce(model): def get_torch_compiled_model(model): - if model.config.model_type in ["gpt_bigcode", "mpt", "bloom"]: + if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]: model.transformer = torch.compile( model.transformer, backend="hpu_backend", options={"keep_input_mutations": True} ) @@ -245,12 +245,14 @@ def setup_model(args, model_dtype, model_kwargs, logger): args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs ) elif args.load_quantized_model_with_inc: - #TODO: This will be removed in v1.19 Synapse release - #Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release. + # TODO: This will be removed in v1.19 Synapse release + # Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release. import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl + nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight from neural_compressor.torch.quantization import load + model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs) elif args.local_quantized_inc_model_path: org_model = AutoModelForCausalLM.from_pretrained( @@ -667,9 +669,10 @@ def initialize_model(args, logger): logger.info(f"Model initialization took {(init_end - init_start):.3f}s") return model, assistant_model, tokenizer, generation_config -#TODO:This will be removed from Synapse v1.19 release. -#This is to override _load_remaining_pretrained_weight for Transformer 4.45 release. -def local_load_remaining_pretrained_weight(self,model): + +# TODO:This will be removed from Synapse v1.19 release. +# This is to override _load_remaining_pretrained_weight for Transformer 4.45 release. +def local_load_remaining_pretrained_weight(self, model): from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict resolved_archive_file = self.kwargs.pop("resolved_archive_file", None) @@ -687,7 +690,7 @@ def local_load_remaining_pretrained_weight(self,model): for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) - params_dict={ + params_dict = { "model": model, "state_dict": state_dict, "start_prefix": "",