diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py index 286e8ba499..5cfe1a9fac 100644 --- a/timm/utils/distributed.py +++ b/timm/utils/distributed.py @@ -109,8 +109,13 @@ def init_distributed_device_so( global_rank = 0 local_rank = 0 if dist_backend is None: - # FIXME sane defaults for other device backends? - dist_backend = 'nccl' if 'cuda' in device else 'gloo' + # FIXME: verify that ROCm transform nccl to rccl + dist_backends = { + "xpu": "ccl", + "hpu": "hccl", + "cuda": "nccl", + } + dist_backend = dist_backends.get(device, 'gloo') dist_url = dist_url or 'env://' # TBD, support horovod?