diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index 6396eaf..a6d880d 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -94,7 +94,13 @@ def get_device(device_spec: Union[str, int, List[int]]) -> torch.device: device_spec = initialize_distributed() if torch.cuda.is_available() and "cuda" in str(device_spec): return torch.device(device_spec), ddp_flag - return torch.device("cpu"), False + if not torch.cuda.is_available() and not torch.backends.mps.is_available(): + if device_spec != "cpu": + logger.warning(f"❎ Device spec: {device_spec} not support, Choosing CPU instead") + return torch.device("cpu"), False + + device = torch.device(device_spec) + return device, ddp_flag class PostProccess: