diff --git a/flake.nix b/flake.nix index c72d8e4e7197d..21adc43148750 100644 --- a/flake.nix +++ b/flake.nix @@ -35,6 +35,20 @@ ); pkgs = import nixpkgs { inherit system; }; nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ]; + cudatoolkit_joined = with pkgs; symlinkJoin { + # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit + # see https://github.com/NixOS/nixpkgs/issues/224291 + # copied from jaxlib + name = "${cudaPackages.cudatoolkit.name}-merged"; + paths = [ + cudaPackages.cudatoolkit.lib + cudaPackages.cudatoolkit.out + ] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [ + # for some reason some of the required libs are in the targets/x86_64-linux + # directory; not sure why but this works around it + "${cudaPackages.cudatoolkit}/targets/${system}" + ]; + }; llama-python = pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]); postPatch = '' @@ -71,31 +85,10 @@ }; packages.cuda = pkgs.stdenv.mkDerivation { inherit name src meta postPatch nativeBuildInputs postInstall; - buildInputs = with pkgs; buildInputs ++ [ cudaPackages.cudatoolkit ]; - - # HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit, so we force gnu make - # see https://github.com/NixOS/nixpkgs/issues/224291 - dontUseCmakeConfigure = true; - dontUseNinjaBuild = true; - - buildFlags = [ "LLAMA_CUBLAS=1" ]; - installPhase = '' - runHook preInstall - - mkdir -p $out/bin - - # TODO(Green-Sky): add install target to Makefile, or wait for cmake support - mv main $out/bin/ - mv server $out/bin/ - mv speculative $out/bin/ - mv perplexity $out/bin/ - mv embedding $out/bin/ - mv quantize $out/bin/ - mv llama-bench $out/bin/ - mv train-text-from-scratch $out/bin/ - - runHook postInstall - ''; + buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ]; + cmakeFlags = cmakeFlags ++ [ + "-DLLAMA_CUBLAS=ON" + ]; }; packages.rocm = pkgs.stdenv.mkDerivation { inherit name src meta postPatch nativeBuildInputs postInstall;