Skip to content

Commit

Permalink
Support ccache for nvcc
Browse files Browse the repository at this point in the history
ghstack-source-id: 399daef2b13459e2846bd779e48029fadc000654
Pull Request resolved: https://github.com/fairinternal/xformers/pull/515

__original_commit__ = fairinternal/xformers@249619bff993f4992200a1d21b63757b573d35a7
  • Loading branch information
danthe3rd authored and xFormers Bot committed Mar 27, 2023
1 parent 95a715b commit 5eb0dbf
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,25 @@ def __init__(self, *args, **kwargs) -> None:
self.metadata_json = "cpp_lib.json"
super().__init__(*args, **kwargs)

@staticmethod
def _join_cuda_home(*paths) -> str:
"""
Hackfix to support custom `nvcc` binary (eg ccache)
TODO: Remove once we use PT 2.1.0 (https://github.com/pytorch/pytorch/pull/96987)
"""
if paths == ("bin", "nvcc") and "PYTORCH_NVCC" in os.environ:
return os.environ["PYTORCH_NVCC"]
if CUDA_HOME is None:
raise EnvironmentError(
"CUDA_HOME environment variable is not set. "
"Please set it to your CUDA install root."
)
return os.path.join(CUDA_HOME, *paths)

def build_extensions(self) -> None:
torch.utils.cpp_extension._join_cuda_home = (
BuildExtensionWithMetadata._join_cuda_home
)
super().build_extensions()
with open(
os.path.join(self.build_lib, self.pkg_name, self.metadata_json), "w+"
Expand Down

0 comments on commit 5eb0dbf

Please sign in to comment.