diff --git a/setup.py b/setup.py index b0fe954e98..640d0ead42 100644 --- a/setup.py +++ b/setup.py @@ -304,7 +304,25 @@ def __init__(self, *args, **kwargs) -> None: self.metadata_json = "cpp_lib.json" super().__init__(*args, **kwargs) + @staticmethod + def _join_cuda_home(*paths) -> str: + """ + Hackfix to support custom `nvcc` binary (eg ccache) + TODO: Remove once we use PT 2.1.0 (https://github.com/pytorch/pytorch/pull/96987) + """ + if paths == ("bin", "nvcc") and "PYTORCH_NVCC" in os.environ: + return os.environ["PYTORCH_NVCC"] + if CUDA_HOME is None: + raise EnvironmentError( + "CUDA_HOME environment variable is not set. " + "Please set it to your CUDA install root." + ) + return os.path.join(CUDA_HOME, *paths) + def build_extensions(self) -> None: + torch.utils.cpp_extension._join_cuda_home = ( + BuildExtensionWithMetadata._join_cuda_home + ) super().build_extensions() with open( os.path.join(self.build_lib, self.pkg_name, self.metadata_json), "w+"