From cf61e8e439f12e1cbe58b79db18efc30e29044f7 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Mon, 16 Dec 2024 18:47:58 +0800 Subject: [PATCH] fix device_map={"":None} --- gptqmodel/models/loader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index b903b096a..2554439a1 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -394,8 +394,7 @@ def skip(*args, **kwargs): if not isinstance(device_map, dict): if device is not None: - device = torch.device(device) - device_map = {"": device.index if device.type in [DEVICE.CUDA, DEVICE.XPU] else device.type} + device_map = {"": 0 if device in [DEVICE.CUDA, DEVICE.XPU, DEVICE.MPS] else DEVICE.CPU} else: device_map = accelerate.infer_auto_device_map( model,