diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index a412cbe7b3..e32123f43f 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -66,5 +66,6 @@ target_sources( if(IOS) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp) else() - target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp) + target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp) endif() diff --git a/mlx/backend/common/compiled_cpu.cpp b/mlx/backend/common/compiled_cpu.cpp index eb08d070d3..650f038c85 100644 --- a/mlx/backend/common/compiled_cpu.cpp +++ b/mlx/backend/common/compiled_cpu.cpp @@ -9,6 +9,7 @@ #include "mlx/backend/common/compiled.h" #include "mlx/backend/common/compiled_preamble.h" +#include "mlx/backend/common/jit_compiler.h" #include "mlx/device.h" #include "mlx/graph_utils.h" @@ -44,11 +45,8 @@ namespace detail { bool compile_available_for_device(const Device& device) { return true; } -} // namespace detail -std::string get_temp_file(const std::string& name) { - return std::filesystem::temp_directory_path().append(name).string(); -} +} // namespace detail // Return a pointer to a compiled function void* compile( @@ -88,9 +86,10 @@ void* compile( kernel_file_name = kernel_name; } - std::ostringstream shared_lib_name; - shared_lib_name << "lib" << kernel_file_name << ".so"; - auto shared_lib_path = get_temp_file(shared_lib_name.str()); + auto output_dir = std::filesystem::temp_directory_path(); + + std::string shared_lib_name = "lib" + kernel_file_name + ".so"; + auto shared_lib_path = (output_dir / shared_lib_name).string(); bool lib_exists = false; { std::ifstream f(shared_lib_path.c_str()); @@ -99,19 +98,16 @@ void* compile( if (!lib_exists) { // Open source file and write source code to it - std::ostringstream source_file_name; - source_file_name << kernel_file_name << ".cpp"; - auto source_file_path = get_temp_file(source_file_name.str()); + std::string source_file_name = kernel_file_name + ".cpp"; + auto source_file_path = (output_dir / source_file_name).string(); std::ofstream source_file(source_file_path); source_file << source_code; source_file.close(); - std::ostringstream build_command; - build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '" - << source_file_path << "' -o '" << shared_lib_path << "'"; - std::string build_command_str = build_command.str(); - auto return_code = system(build_command_str.c_str()); + std::string command = JitCompiler::build_command( + output_dir, source_file_name, shared_lib_name); + auto return_code = system(command.c_str()); if (return_code) { std::ostringstream msg; msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name @@ -156,6 +152,11 @@ inline void build_kernel( NodeNamer namer; +#ifdef _MSC_VER + // Export the symbol + os << "__declspec(dllexport) "; +#endif + // Start the kernel os << "void " << kernel_name << "(void** args) {" << std::endl; diff --git a/mlx/backend/common/jit_compiler.cpp b/mlx/backend/common/jit_compiler.cpp new file mode 100644 index 0000000000..27fb9e7234 --- /dev/null +++ b/mlx/backend/common/jit_compiler.cpp @@ -0,0 +1,128 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/common/jit_compiler.h" + +#include +#include + +#include + +namespace mlx::core { + +#ifdef _MSC_VER + +namespace { + +// Split string into array. +std::vector str_split(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +// Run a command and get its output. +std::string exec(const std::string& cmd) { + std::unique_ptr pipe( + _popen(cmd.c_str(), "r"), _pclose); + if (!pipe) { + throw std::runtime_error("popen() failed."); + } + char buffer[128]; + std::string ret; + while (fgets(buffer, sizeof(buffer), pipe.get())) { + ret += buffer; + } + // Trim trailing spaces. + ret.erase( + std::find_if( + ret.rbegin(), + ret.rend(), + [](unsigned char ch) { return !std::isspace(ch); }) + .base(), + ret.end()); + return ret; +} + +// Get path information about MSVC. +struct VisualStudioInfo { + VisualStudioInfo() { +#ifdef _M_ARM64 + arch = "arm64"; +#else + arch = "x64"; +#endif + // Get path of Visual Studio. + std::string vs_path = exec(fmt::format( + "\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" + " -property installationPath", + std::getenv("ProgramFiles(x86)"))); + if (vs_path.empty()) { + throw std::runtime_error("Can not find Visual Studio."); + } + // Read the envs from vcvarsall. + std::string envs = exec(fmt::format( + "\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set", + vs_path, + arch)); + for (const std::string& line : str_split(envs, '\n')) { + // Each line is in the format "ENV_NAME=values". + auto pos = line.find_first_of('='); + if (pos == std::string::npos || pos == 0 || pos == line.size() - 1) + continue; + std::string name = line.substr(0, pos); + std::string value = line.substr(pos + 1); + if (name == "LIB") { + libpaths = str_split(value, ';'); + } else if (name == "VCToolsInstallDir") { + cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch); + } + } + } + std::string arch; + std::string cl_exe; + std::vector libpaths; +}; + +const VisualStudioInfo& GetVisualStudioInfo() { + static VisualStudioInfo info; + return info; +} + +} // namespace + +#endif // _MSC_VER + +std::string JitCompiler::build_command( + const std::filesystem::path& dir, + const std::string& source_file_name, + const std::string& shared_lib_name) { +#ifdef _MSC_VER + const VisualStudioInfo& info = GetVisualStudioInfo(); + std::string libpaths; + for (const std::string& lib : info.libpaths) { + libpaths += fmt::format(" /libpath:\"{0}\"", lib); + } + return fmt::format( + "\"" + "cd /D \"{0}\" && " + "\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" " + "/link /out:\"{3}\" {4} >nul" + "\"", + dir.string(), + info.cl_exe, + source_file_name, + shared_lib_name, + libpaths); +#else + return fmt::format( + "g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'", + (dir / source_file_name).string(), + (dir / shared_lib_name).string()); +#endif +} + +} // namespace mlx::core diff --git a/mlx/backend/common/jit_compiler.h b/mlx/backend/common/jit_compiler.h new file mode 100644 index 0000000000..b0bf8c0dee --- /dev/null +++ b/mlx/backend/common/jit_compiler.h @@ -0,0 +1,17 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include + +namespace mlx::core { + +class JitCompiler { + public: + // Build a shell command that compiles a source code file to a shared library. + static std::string build_command( + const std::filesystem::path& dir, + const std::string& source_file_name, + const std::string& shared_lib_name); +}; + +} // namespace mlx::core diff --git a/mlx/backend/common/make_compiled_preamble.ps1 b/mlx/backend/common/make_compiled_preamble.ps1 index 0b2248b674..18d0574535 100644 --- a/mlx/backend/common/make_compiled_preamble.ps1 +++ b/mlx/backend/common/make_compiled_preamble.ps1 @@ -13,7 +13,7 @@ $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/comp # Otherwise there will be too much empty lines making the result unreadable. $CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' } # Concatenate to string. -$CONTENT = $CONTENT -join '`n' +$CONTENT = $CONTENT -join "`n" # Append extra content. $CONTENT = @"