Skip to content

Commit

Permalink
Make mx.compile work on Windows (#1697)
Browse files Browse the repository at this point in the history
* Invoke MSVC on Windows in mx.compile

* Export kernel symbol on MSVC

* Remove unused template

* Parse env pairs in a robust way

* No need of cassert

* Remove unnecessary helpers

* Fix right trim

* Move command building to a separate file

* Missing header

* Do not pollute cwd with cl.exe

* Simplify str concat

* Pass output dir

* Fix styling
  • Loading branch information
zcbenz authored Dec 24, 2024
1 parent 88f993d commit 935c8c4
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 17 deletions.
3 changes: 2 additions & 1 deletion mlx/backend/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ target_sources(
if(IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
endif()
31 changes: 16 additions & 15 deletions mlx/backend/common/compiled_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"

Expand Down Expand Up @@ -44,11 +45,8 @@ namespace detail {
bool compile_available_for_device(const Device& device) {
return true;
}
} // namespace detail

std::string get_temp_file(const std::string& name) {
return std::filesystem::temp_directory_path().append(name).string();
}
} // namespace detail

// Return a pointer to a compiled function
void* compile(
Expand Down Expand Up @@ -88,9 +86,10 @@ void* compile(
kernel_file_name = kernel_name;
}

std::ostringstream shared_lib_name;
shared_lib_name << "lib" << kernel_file_name << ".so";
auto shared_lib_path = get_temp_file(shared_lib_name.str());
auto output_dir = std::filesystem::temp_directory_path();

std::string shared_lib_name = "lib" + kernel_file_name + ".so";
auto shared_lib_path = (output_dir / shared_lib_name).string();
bool lib_exists = false;
{
std::ifstream f(shared_lib_path.c_str());
Expand All @@ -99,19 +98,16 @@ void* compile(

if (!lib_exists) {
// Open source file and write source code to it
std::ostringstream source_file_name;
source_file_name << kernel_file_name << ".cpp";
auto source_file_path = get_temp_file(source_file_name.str());
std::string source_file_name = kernel_file_name + ".cpp";
auto source_file_path = (output_dir / source_file_name).string();

std::ofstream source_file(source_file_path);
source_file << source_code;
source_file.close();

std::ostringstream build_command;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
std::string command = JitCompiler::build_command(
output_dir, source_file_name, shared_lib_name);
auto return_code = system(command.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
Expand Down Expand Up @@ -156,6 +152,11 @@ inline void build_kernel(

NodeNamer namer;

#ifdef _MSC_VER
// Export the symbol
os << "__declspec(dllexport) ";
#endif

// Start the kernel
os << "void " << kernel_name << "(void** args) {" << std::endl;

Expand Down
128 changes: 128 additions & 0 deletions mlx/backend/common/jit_compiler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/jit_compiler.h"

#include <sstream>
#include <vector>

#include <fmt/format.h>

namespace mlx::core {

#ifdef _MSC_VER

namespace {

// Split string into array.
std::vector<std::string> str_split(const std::string& str, char delimiter) {
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(str);
while (std::getline(tokenStream, token, delimiter)) {
tokens.push_back(token);
}
return tokens;
}

// Run a command and get its output.
std::string exec(const std::string& cmd) {
std::unique_ptr<FILE, decltype(&_pclose)> pipe(
_popen(cmd.c_str(), "r"), _pclose);
if (!pipe) {
throw std::runtime_error("popen() failed.");
}
char buffer[128];
std::string ret;
while (fgets(buffer, sizeof(buffer), pipe.get())) {
ret += buffer;
}
// Trim trailing spaces.
ret.erase(
std::find_if(
ret.rbegin(),
ret.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
ret.end());
return ret;
}

// Get path information about MSVC.
struct VisualStudioInfo {
VisualStudioInfo() {
#ifdef _M_ARM64
arch = "arm64";
#else
arch = "x64";
#endif
// Get path of Visual Studio.
std::string vs_path = exec(fmt::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath",
std::getenv("ProgramFiles(x86)")));
if (vs_path.empty()) {
throw std::runtime_error("Can not find Visual Studio.");
}
// Read the envs from vcvarsall.
std::string envs = exec(fmt::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
vs_path,
arch));
for (const std::string& line : str_split(envs, '\n')) {
// Each line is in the format "ENV_NAME=values".
auto pos = line.find_first_of('=');
if (pos == std::string::npos || pos == 0 || pos == line.size() - 1)
continue;
std::string name = line.substr(0, pos);
std::string value = line.substr(pos + 1);
if (name == "LIB") {
libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir") {
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
}
}
}
std::string arch;
std::string cl_exe;
std::vector<std::string> libpaths;
};

const VisualStudioInfo& GetVisualStudioInfo() {
static VisualStudioInfo info;
return info;
}

} // namespace

#endif // _MSC_VER

std::string JitCompiler::build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name) {
#ifdef _MSC_VER
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths;
for (const std::string& lib : info.libpaths) {
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
}
return fmt::format(
"\""
"cd /D \"{0}\" && "
"\"{1}\" /LD /EHsc /MD /Ox /nologo /std:c++17 \"{2}\" "
"/link /out:\"{3}\" {4} >nul"
"\"",
dir.string(),
info.cl_exe,
source_file_name,
shared_lib_name,
libpaths);
#else
return fmt::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'",
(dir / source_file_name).string(),
(dir / shared_lib_name).string());
#endif
}

} // namespace mlx::core
17 changes: 17 additions & 0 deletions mlx/backend/common/jit_compiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright © 2024 Apple Inc.
#pragma once

#include <filesystem>

namespace mlx::core {

class JitCompiler {
public:
// Build a shell command that compiles a source code file to a shared library.
static std::string build_command(
const std::filesystem::path& dir,
const std::string& source_file_name,
const std::string& shared_lib_name);
};

} // namespace mlx::core
2 changes: 1 addition & 1 deletion mlx/backend/common/make_compiled_preamble.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ $CONTENT = & $CL /std:c++17 /EP "/I$SRCDIR" /Tp "$SRCDIR/mlx/backend/common/comp
# Otherwise there will be too much empty lines making the result unreadable.
$CONTENT = $CONTENT | Where-Object { $_.Trim() -ne '' }
# Concatenate to string.
$CONTENT = $CONTENT -join '`n'
$CONTENT = $CONTENT -join "`n"

# Append extra content.
$CONTENT = @"
Expand Down

0 comments on commit 935c8c4

Please sign in to comment.