Skip to content
This repository was archived by the owner on Jun 24, 2024. It is now read-only.

fix(metal): patch ggml-metal.m with shader code #328

Merged
merged 2 commits into from
Jun 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions crates/ggml/sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::env;
use std::path::PathBuf;
use std::path::{Path, PathBuf};

// By default, this crate will attempt to compile ggml with the features of your host system if
// the host and target are the same. If they are not, it will turn off auto-feature-detection,
Expand All @@ -19,13 +19,14 @@ fn main() {
let compiler = build.get_compiler();

// Enable accelerators
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is not defined"));
if cfg_cublas() && !cfg!(target_os = "macos") {
enable_cublas(build);
enable_cublas(build, &out_dir);
} else if cfg_clblast() {
enable_clblast(build);
} else if cfg!(target_os = "macos") {
if cfg_metal() {
enable_metal(build);
enable_metal(build, &out_dir);
} else {
println!("cargo:rustc-link-lib=framework=Accelerate");

Expand Down Expand Up @@ -171,13 +172,62 @@ fn enable_clblast(build: &mut cc::Build) {
build.flag(&format!(r"-I{opencl_include_path}"));
}

fn enable_metal(build: &mut cc::Build) {
fn enable_metal(build: &mut cc::Build, out_dir: &Path) {
const GGML_METAL_METAL_PATH: &str = "llama-cpp/ggml-metal.metal";
const GGML_METAL_PATH: &str = "llama-cpp/ggml-metal.m";

println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=Metal");
println!("cargo:rustc-link-lib=framework=MetalKit");
println!("cargo:rustc-link-lib=framework=MetalPerformanceShaders");
println!("cargo:rerun-if-changed={GGML_METAL_METAL_PATH}");
println!("cargo:rerun-if-changed={GGML_METAL_PATH}");

// HACK: patch ggml-metal.m so that the contents of ggml-metal.metal are embedded
// in it as a string literal, so that the shader source does not have to be read
// from disk at runtime
let ggml_metal_path = {
let ggml_metal_metal = std::fs::read_to_string(GGML_METAL_METAL_PATH)
.expect("Could not read ggml-metal.metal")
.replace('\\', "\\\\")
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\"', "\\\"");

let ggml_metal =
std::fs::read_to_string(GGML_METAL_PATH).expect("Could not read ggml-metal.m");

let needle = r#"NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];"#;
if !ggml_metal.contains(needle) {
panic!("ggml-metal.m does not contain the needle to be replaced; the patching logic needs to be reinvestigated. Contact a `llm` developer!");
}

// Replace the runtime read of the file with a compile-time string
let ggml_metal = ggml_metal.replace(
needle,
&format!(r#"NSString * src = @"{ggml_metal_metal}";"#),
);

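// Replace the direct `fprintf(stderr, ...)` calls with the already-existing `metal_printf`
// macro. The macro's own definition is swapped for a placeholder first and restored
// afterwards, so that the definition itself is not rewritten.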
let ggml_metal = ggml_metal
.replace(
r#"#define metal_printf(...) fprintf(stderr, __VA_ARGS__)"#,
"METAL_PRINTF_DEFINITION",
)
.replace("fprintf(stderr,", "metal_printf(")
.replace(
"METAL_PRINTF_DEFINITION",
r#"#define metal_printf(...) fprintf(stderr, __VA_ARGS__)"#,
);

build.file("llama-cpp/ggml-metal.m");
let patched_ggml_metal_path = out_dir.join("ggml-metal.m");
std::fs::write(&patched_ggml_metal_path, ggml_metal)
.expect("Could not write temporary patched ggml-metal.m");

patched_ggml_metal_path
};

build.file(ggml_metal_path);
build.flag("-DGGML_USE_METAL");

#[cfg(not(debug_assertions))]
Expand Down Expand Up @@ -223,10 +273,8 @@ fn cuda_lib_path() -> String {
.to_string()
}

fn enable_cublas(build: &mut cc::Build) {
let out_dir = env::var("OUT_DIR").expect("OUT_DIR is not defined");

let object_file = PathBuf::from(out_dir)
fn enable_cublas(build: &mut cc::Build, out_dir: &Path) {
let object_file = out_dir
.join("llama-cpp")
.join("ggml-cuda.o")
.to_str()
Expand Down