From 5eb0dbf315d14b5f7b38ac2ff3d8379beca7df9b Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Mon, 27 Mar 2023 13:43:07 +0000 Subject: [PATCH] Support ccache for nvcc ghstack-source-id: 399daef2b13459e2846bd779e48029fadc000654 Pull Request resolved: https://github.com/fairinternal/xformers/pull/515 __original_commit__ = fairinternal/xformers@249619bff993f4992200a1d21b63757b573d35a7 --- setup.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/setup.py b/setup.py index b0fe954e98..640d0ead42 100644 --- a/setup.py +++ b/setup.py @@ -304,7 +304,25 @@ def __init__(self, *args, **kwargs) -> None: self.metadata_json = "cpp_lib.json" super().__init__(*args, **kwargs) + @staticmethod + def _join_cuda_home(*paths) -> str: + """ + Hackfix to support custom `nvcc` binary (eg ccache) + TODO: Remove once we use PT 2.1.0 (https://github.com/pytorch/pytorch/pull/96987) + """ + if paths == ("bin", "nvcc") and "PYTORCH_NVCC" in os.environ: + return os.environ["PYTORCH_NVCC"] + if CUDA_HOME is None: + raise EnvironmentError( + "CUDA_HOME environment variable is not set. " + "Please set it to your CUDA install root." + ) + return os.path.join(CUDA_HOME, *paths) + def build_extensions(self) -> None: + torch.utils.cpp_extension._join_cuda_home = ( + BuildExtensionWithMetadata._join_cuda_home + ) super().build_extensions() with open( os.path.join(self.build_lib, self.pkg_name, self.metadata_json), "w+"