Skip to content

Commit

Permalink
Fix error on 4-bit checkpoint load with run_lm_eval on Transformers 4.45.2 (#1439)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiminha authored and regisss committed Oct 18, 2024
1 parent f2d7eaa commit f98688d
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,14 @@ def setup_model(args, model_dtype, model_kwargs, logger):
args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs
)
elif args.load_quantized_model_with_inc:
#TODO: This will be removed in v1.19 Synapse release
#Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release.
# TODO: This will be removed in v1.19 Synapse release
# Override neural_compressor _load_remaining_pretrained_weight for the Transformer 4.45 release.
import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl

nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight

from neural_compressor.torch.quantization import load

model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
elif args.local_quantized_inc_model_path:
org_model = AutoModelForCausalLM.from_pretrained(
Expand Down Expand Up @@ -667,9 +669,10 @@ def initialize_model(args, logger):
logger.info(f"Model initialization took {(init_end - init_start):.3f}s")
return model, assistant_model, tokenizer, generation_config

#TODO:This will be removed from Synapse v1.19 release.
#This is to override _load_remaining_pretrained_weight for Transformer 4.45 release.
def local_load_remaining_pretrained_weight(self,model):

# TODO: This will be removed from Synapse v1.19 release.
# This is to override _load_remaining_pretrained_weight for Transformer 4.45 release.
def local_load_remaining_pretrained_weight(self, model):
from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict

resolved_archive_file = self.kwargs.pop("resolved_archive_file", None)
Expand All @@ -687,11 +690,11 @@ def local_load_remaining_pretrained_weight(self,model):
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)

params_dict={
params_dict = {
"model": model,
"state_dict": state_dict,
"start_prefix": "",
"expected_keys": list(state_dict.keys()),
"expected_keys": self.loaded_state_dict_keys,
"device_map": {"": self.device},
"offload_folder": offload_folder,
"state_dict_folder": tempfile.mkdtemp() if offload_state_dict else None,
Expand Down

0 comments on commit f98688d

Please sign in to comment.