Skip to content

Commit

Permalink
use a symlinked toolkit instead to make cmake working
Browse files Browse the repository at this point in the history
  • Loading branch information
Green-Sky committed Sep 18, 2023
1 parent 92e4d29 commit b86c4bd
Showing 1 changed file with 18 additions and 25 deletions.
43 changes: 18 additions & 25 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@
);
pkgs = import nixpkgs { inherit system; };
nativeBuildInputs = with pkgs; [ cmake ninja pkg-config ];
cudatoolkit_joined = with pkgs; symlinkJoin {
# HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit
# see https://github.com/NixOS/nixpkgs/issues/224291
# copied from jaxlib
name = "${cudaPackages.cudatoolkit.name}-merged";
paths = [
cudaPackages.cudatoolkit.lib
cudaPackages.cudatoolkit.out
] ++ lib.optionals (lib.versionOlder cudaPackages.cudatoolkit.version "11") [
# for some reason some of the required libs are in the targets/x86_64-linux
# directory; not sure why but this works around it
"${cudaPackages.cudatoolkit}/targets/${system}"
];
};
llama-python =
pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
postPatch = ''
Expand Down Expand Up @@ -71,31 +85,10 @@
};
packages.cuda = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall;
buildInputs = with pkgs; buildInputs ++ [ cudaPackages.cudatoolkit ];

# HACK(Green-Sky): nix currently has issues with cmake findcudatoolkit, so we force gnu make
# see https://github.com/NixOS/nixpkgs/issues/224291
dontUseCmakeConfigure = true;
dontUseNinjaBuild = true;

buildFlags = [ "LLAMA_CUBLAS=1" ];
installPhase = ''
runHook preInstall
mkdir -p $out/bin
# TODO(Green-Sky): add install target to Makefile, or wait for cmake support
mv main $out/bin/
mv server $out/bin/
mv speculative $out/bin/
mv perplexity $out/bin/
mv embedding $out/bin/
mv quantize $out/bin/
mv llama-bench $out/bin/
mv train-text-from-scratch $out/bin/
runHook postInstall
'';
buildInputs = with pkgs; buildInputs ++ [ cudatoolkit_joined ];
cmakeFlags = cmakeFlags ++ [
"-DLLAMA_CUBLAS=ON"
];
};
packages.rocm = pkgs.stdenv.mkDerivation {
inherit name src meta postPatch nativeBuildInputs postInstall;
Expand Down

0 comments on commit b86c4bd

Please sign in to comment.