diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 73328e3af59e..0554fa9d7d8c 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -785,8 +785,8 @@ def _load_state_dict_into_meta_model( tensor_device = None if device_map is not None and device_map.get("", None) is not None: tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""] - - device_map_regex = "|".join(sorted(device_map.keys(), reverse=True)) + if device_map is not None: + device_map_regex = "|".join(sorted(device_map.keys(), reverse=True)) # we need this later to initialize tensor parallelism if device_mesh is not None: