From a4a17e0eaf5afeaf24a2b706ac3d9eb511fd4780 Mon Sep 17 00:00:00 2001 From: Erik Scholz Date: Mon, 25 Sep 2023 13:48:30 +0200 Subject: [PATCH] nix : add cuda, use a symlinked toolkit for cmake (#3202) --- flake.nix | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/flake.nix b/flake.nix index 7723357afe419..433d3d942ce40 100644 --- a/flake.nix +++ b/flake.nix @@ -35,6 +35,20 @@ ); pkgs = import nixpkgs { inherit system; }; nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ]; + cudatoolkit_joined = with pkgs; symlinkJoin { + # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit + # see https://github.com/NixOS/nixpkgs/issues/224291 + # copied from jaxlib + name = "${cudaPackages.cudatoolkit.name}-merged"; + paths = [ + cudaPackages.cudatoolkit.lib + cudaPackages.cudatoolkit.out + ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [ + # for some reason some of the required libs are in the targets/x86_64-linux + # directory; not sure why but this works around it + "${cudaPackages.cudatoolkit}/targets/${system}" + ]; + }; llama-python = pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]); postPatch = '' @@ -70,6 +84,13 @@ "-DLLAMA_CLBLAST=ON" ]; }; + packages.cuda = pkgs.stdenv.mkDerivation { + inherit name src meta postPatch nativeBuildInputs postInstall; + buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ]; + cmakeFlags = cmakeFlags ++ [ + "-DLLAMA_CUBLAS=ON" + ]; + }; packages.rocm = pkgs.stdenv.mkDerivation { inherit name src meta postPatch nativeBuildInputs postInstall; buildInputs = with pkgs; buildInputs ++ [ hip hipblas rocblas ];