GPT2 torch.compile fix (#1434)
dsmertin authored and regisss committed Oct 18, 2024
1 parent 190a29a commit 529bda8
Showing 1 changed file with 10 additions and 7 deletions.
examples/text-generation/utils.py: 17 changes (10 additions, 7 deletions)
@@ -177,7 +177,7 @@ def patch_scoped_linear_all_reduce(model):
 
 
 def get_torch_compiled_model(model):
-    if model.config.model_type in ["gpt_bigcode", "mpt", "bloom"]:
+    if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
         model.transformer = torch.compile(
             model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
         )
@@ -245,12 +245,14 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs
         )
     elif args.load_quantized_model_with_inc:
-        #TODO: This will be removed in v1.19 Synapse release
-        #Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release.
+        # TODO: This will be removed in v1.19 Synapse release
+        # Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release.
         import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl
+
         nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight
+
         from neural_compressor.torch.quantization import load
 
         model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
     elif args.local_quantized_inc_model_path:
         org_model = AutoModelForCausalLM.from_pretrained(
@@ -667,9 +669,10 @@ def initialize_model(args, logger):
     logger.info(f"Model initialization took {(init_end - init_start):.3f}s")
     return model, assistant_model, tokenizer, generation_config
 
-#TODO:This will be removed from Synapse v1.19 release.
-#This is to override _load_remaining_pretrained_weight for Transformer 4.45 release.
-def local_load_remaining_pretrained_weight(self,model):
+
+# TODO:This will be removed from Synapse v1.19 release.
+# This is to override _load_remaining_pretrained_weight for Transformer 4.45 release.
+def local_load_remaining_pretrained_weight(self, model):
     from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
 
     resolved_archive_file = self.kwargs.pop("resolved_archive_file", None)
@@ -687,7 +690,7 @@ def local_load_remaining_pretrained_weight(self,model):
     for shard_file in resolved_archive_file:
         state_dict = load_state_dict(shard_file)
 
-        params_dict={
+        params_dict = {
             "model": model,
             "state_dict": state_dict,
             "start_prefix": "",
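For context on the first hunk: a minimal sketch of what the change enables, assuming an HPU environment where the "hpu_backend" torch.compile backend is registered (e.g. by importing optimum-habana/habana_frameworks). The checkpoint name below is illustrative, not part of the commit.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint

# GPT-2 keeps its decoder stack in `model.transformer`, so, like
# gpt_bigcode/mpt/bloom, only that submodule is compiled rather than
# the whole model.
if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]:
    model.transformer = torch.compile(
        model.transformer, backend="hpu_backend", options={"keep_input_mutations": True}
    )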

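The second hunk's workaround is a plain monkey-patch, sketched below. The module and class names are taken from the diff itself; running this requires neural_compressor installed, and the override body is elided here (its full definition is in the third hunk).

import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl

def local_load_remaining_pretrained_weight(self, model):
    # Replacement body elided; see the diff above for the full override.
    ...

# Functions are ordinary class attributes in Python, so this rebinding
# makes every subsequent WOQModelLoader instance dispatch to the override.
nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight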