From 4d57c2c84a539d2094ee241d367c2c50ee6af80f Mon Sep 17 00:00:00 2001
From: hlky
Date: Sun, 2 Mar 2025 07:33:36 +0000
Subject: [PATCH] Fix _load_state_dict_into_meta_model with device_map=None
 (#36488)

* Fix _load_state_dict_into_meta_model with device_map=None

* Update src/transformers/modeling_utils.py

---
 src/transformers/modeling_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 73328e3af59e..0554fa9d7d8c 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -785,8 +785,8 @@ def _load_state_dict_into_meta_model(
     tensor_device = None
     if device_map is not None and device_map.get("", None) is not None:
         tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
-
-    device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
+    if device_map is not None:
+        device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
 
     # we need this later to initialize tensor parallelism
     if device_mesh is not None:
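
Note on the hunk above: before the guard, the regex construction called device_map.keys() unconditionally, so passing device_map=None raised AttributeError in plain Python. The following standalone sketch mirrors the guarded logic in isolation; the helper name build_device_map_regex is hypothetical and not part of transformers.

    import torch

    def build_device_map_regex(device_map):
        # Hypothetical helper mirroring the patched logic: only touch
        # device_map when it is actually provided.
        tensor_device = None
        if device_map is not None and device_map.get("", None) is not None:
            tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]

        device_map_regex = None
        if device_map is not None:
            # Reverse-sorting the keys places a more specific module prefix
            # ahead of any shorter prefix of it in the alternation.
            device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
        return tensor_device, device_map_regex

    # Without the guard, the first call would reach `.keys()` on None and
    # raise AttributeError; with it, regex construction is simply skipped.
    print(build_device_map_regex(None))                           # (None, None)
    print(build_device_map_regex({"": torch.device("cuda", 0)}))  # (0, '')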