Skip to content
This repository was archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #328 from rustformers/metal-patch-with-shader-code
Browse files Browse the repository at this point in the history
fix(metal): patch ggml-metal.m with shader code
  • Loading branch information
philpax authored Jun 23, 2023
2 parents 7927d0d + 6bf705c commit 47a41c9
Showing 1 changed file with 57 additions and 9 deletions.
66 changes: 57 additions & 9 deletions crates/ggml/sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::env;
use std::path::PathBuf;
use std::path::{Path, PathBuf};

// By default, this crate will attempt to compile ggml with the features of your host system if
// the host and target are the same. If they are not, it will turn off auto-feature-detection,
Expand All @@ -19,13 +19,14 @@ fn main() {
let compiler = build.get_compiler();

// Enable accelerators
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is not defined"));
if cfg_cublas() && !cfg!(target_os = "macos") {
enable_cublas(build);
enable_cublas(build, &out_dir);
} else if cfg_clblast() {
enable_clblast(build);
} else if cfg!(target_os = "macos") {
if cfg_metal() {
enable_metal(build);
enable_metal(build, &out_dir);
} else {
println!("cargo:rustc-link-lib=framework=Accelerate");

Expand Down Expand Up @@ -171,13 +172,62 @@ fn enable_clblast(build: &mut cc::Build) {
build.flag(&format!(r"-I{opencl_include_path}"));
}

fn enable_metal(build: &mut cc::Build) {
fn enable_metal(build: &mut cc::Build, out_dir: &Path) {
const GGML_METAL_METAL_PATH: &str = "llama-cpp/ggml-metal.metal";
const GGML_METAL_PATH: &str = "llama-cpp/ggml-metal.m";

println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=Metal");
println!("cargo:rustc-link-lib=framework=MetalKit");
println!("cargo:rustc-link-lib=framework=MetalPerformanceShaders");
println!("cargo:rerun-if-changed={GGML_METAL_METAL_PATH}");
println!("cargo:rerun-if-changed={GGML_METAL_PATH}");

// HACK: patch ggml-metal.m to embed the contents of ggml-metal.metal as a string
// literal, so that the shader source file is not needed at runtime
let ggml_metal_path = {
let ggml_metal_metal = std::fs::read_to_string(GGML_METAL_METAL_PATH)
.expect("Could not read ggml-metal.metal")
.replace('\\', "\\\\")
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\"', "\\\"");

let ggml_metal =
std::fs::read_to_string(GGML_METAL_PATH).expect("Could not read ggml-metal.m");

let needle = r#"NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];"#;
if !ggml_metal.contains(needle) {
panic!("ggml-metal.m does not contain the needle to be replaced; the patching logic needs to be reinvestigated. Contact a `llm` developer!");
}

// Replace the runtime read of the file with a compile-time string
let ggml_metal = ggml_metal.replace(
needle,
&format!(r#"NSString * src = @"{ggml_metal_metal}";"#),
);

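// Route the direct `fprintf(stderr, ...)` calls through the already-existing
// `metal_printf` macro. The macro's own definition is swapped for a placeholder
// first and restored afterwards, so that it is not rewritten to refer to itself.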
let ggml_metal = ggml_metal
.replace(
r#"#define metal_printf(...) fprintf(stderr, __VA_ARGS__)"#,
"METAL_PRINTF_DEFINITION",
)
.replace("fprintf(stderr,", "metal_printf(")
.replace(
"METAL_PRINTF_DEFINITION",
r#"#define metal_printf(...) fprintf(stderr, __VA_ARGS__)"#,
);

build.file("llama-cpp/ggml-metal.m");
let patched_ggml_metal_path = out_dir.join("ggml-metal.m");
std::fs::write(&patched_ggml_metal_path, ggml_metal)
.expect("Could not write temporary patched ggml-metal.m");

patched_ggml_metal_path
};

build.file(ggml_metal_path);
build.flag("-DGGML_USE_METAL");

#[cfg(not(debug_assertions))]
Expand Down Expand Up @@ -223,10 +273,8 @@ fn cuda_lib_path() -> String {
.to_string()
}

fn enable_cublas(build: &mut cc::Build) {
let out_dir = env::var("OUT_DIR").expect("OUT_DIR is not defined");

let object_file = PathBuf::from(out_dir)
fn enable_cublas(build: &mut cc::Build, out_dir: &Path) {
let object_file = out_dir
.join("llama-cpp")
.join("ggml-cuda.o")
.to_str()
Expand Down

0 comments on commit 47a41c9

Please sign in to comment.