diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 65f207bf0..ab67fbe74 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -170,7 +170,7 @@ def load( # TODO fix me...unify device + device_map auto logic if not device and not device_map or device_map == "auto": - device = get_best_device() + device = get_best_device(backend=backend) if is_quantized: return cls.from_quantized(