diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d7fb0fc3500c..263ae5d2f988 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3769,7 +3769,7 @@ def _fix_key(key):
             ):
                 set_module_tensor_to_device(model, key, "cpu", value)
             else:
-                hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict)
+                hf_quantizer.create_quantized_param(model, value, key, "cpu", state_dict, unexpected_keys)
 
     # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
     if _fast_init:
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index 236c1dc642b8..b98eebba1834 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from packaging import version
 
@@ -143,7 +143,7 @@ def create_quantized_param(
         param_name: str,
         target_device: "torch.device",
         state_dict: Dict[str, Any],
-        unexpected_keys: List[str],
+        unexpected_keys: Optional[List[str]] = None,
     ):
         """
         combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()
@@ -198,7 +198,8 @@ def create_quantized_param(
                 for k, v in state_dict.items():
                     if param_name + "." in k:
                         quantized_stats[k] = v
-                        unexpected_keys.remove(k)
+                        if unexpected_keys is not None and k in unexpected_keys:
+                            unexpected_keys.remove(k)
 
                 new_value = bnb.nn.Params4bit.from_prequantized(
                     data=param_value,
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index c74e735e48ed..f4249b69d094 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from packaging import version
 
@@ -162,7 +162,7 @@ def create_quantized_param(
         param_name: str,
         target_device: "torch.device",
         state_dict: Dict[str, Any],
-        unexpected_keys: List[str],
+        unexpected_keys: Optional[List[str]] = None,
     ):
         """
         combines logic from _load_state_dict_into_meta_model and .integrations.bitsandbytes.py::set_module_quantized_tensor_to_device()
@@ -207,7 +207,8 @@ def create_quantized_param(
             module._parameters[tensor_name] = new_value
             if fp16_statistics is not None:
                 setattr(module.weight, "SCB", fp16_statistics.to(target_device))
-                unexpected_keys.remove(fp16_statistics_key)
+                if unexpected_keys is not None:
+                    unexpected_keys.remove(fp16_statistics_key)
 
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         model.is_loaded_in_8bit = True