Skip to content

Commit

Permalink
Move command building to a separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Dec 21, 2024
1 parent 63e7267 commit 8f2822c
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 113 deletions.
3 changes: 2 additions & 1 deletion mlx/backend/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ target_sources(
if(IOS)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_nocpu.cpp)
else()
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp)
target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/compiled_cpu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/jit_compiler.cpp)
endif()
117 changes: 5 additions & 112 deletions mlx/backend/common/compiled_cpu.cpp
Original file line number Diff line number Diff line change
@@ -1,109 +1,20 @@
// Copyright © 2023-2024 Apple Inc.

#include <dlfcn.h>
#include <filesystem>
#include <fstream>
#include <list>
#include <mutex>
#include <shared_mutex>
#include <sstream>

#include <dlfcn.h>
#include <fmt/format.h>

#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/compiled_preamble.h"
#include "mlx/backend/common/jit_compiler.h"
#include "mlx/device.h"
#include "mlx/graph_utils.h"

namespace mlx::core {

#ifdef _MSC_VER

namespace {

// Split string into array.
std::vector<std::string> str_split(const std::string& str, char delimiter) {
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(str);
while (std::getline(tokenStream, token, delimiter)) {
tokens.push_back(token);
}
return tokens;
}

// Run a command and get its output.
std::string exec(std::string cmd) {
std::unique_ptr<FILE, decltype(&_pclose)> pipe(
_popen(cmd.c_str(), "r"), _pclose);
if (!pipe) {
throw std::runtime_error("popen() failed.");
}
char buffer[128];
std::string ret;
while (fgets(buffer, sizeof(buffer), pipe.get())) {
ret += buffer;
}
// Trim trailing spaces.
ret.erase(
std::find_if(
ret.rbegin(),
ret.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
ret.end());
return ret;
}

// Get path information about MSVC.
struct VisualStudioInfo {
VisualStudioInfo() {
#ifdef _M_ARM64
arch = "arm64";
#else
arch = "x64";
#endif
// Get path of Visual Studio.
std::string vs_path = exec(fmt::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath",
std::getenv("ProgramFiles(x86)")));
if (vs_path.empty()) {
throw std::runtime_error("Can not find Visual Studio.");
}
// Read the envs from vcvarsall.
std::string envs = exec(fmt::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
vs_path,
arch));
for (const std::string& line : str_split(envs, '\n')) {
// Each line is in the format "ENV_NAME=values".
auto pos = line.find_first_of('=');
if (pos == std::string::npos || pos == 0 || pos == line.size() - 1)
continue;
std::string name = line.substr(0, pos);
std::string value = line.substr(pos + 1);
if (name == "LIB") {
libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir") {
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
}
}
}
std::string arch;
std::string cl_exe;
std::vector<std::string> libpaths;
};

const VisualStudioInfo& GetVisualStudioInfo() {
static VisualStudioInfo info;
return info;
}

} // namespace

#endif // _MSC_VER

struct CompilerCache {
struct DLib {
DLib(const std::string& libname) {
Expand Down Expand Up @@ -198,27 +109,9 @@ void* compile(
source_file << source_code;
source_file.close();

#ifdef _MSC_VER
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths;
for (const std::string& lib : info.libpaths) {
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
}
std::string build_command = fmt::format(
"\""
"\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}"
"\"",
info.cl_exe,
source_file_path,
shared_lib_path,
libpaths);
#else
std::string build_command = fmt::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'",
source_file_path,
shared_lib_path);
#endif
auto return_code = system(build_command.c_str());
std::string command =
JitCompiler::build_command(source_file_path, shared_lib_path);
auto return_code = system(command.c_str());
if (return_code) {
std::ostringstream msg;
msg << "[Compile::eval_cpu] Failed to compile function " << kernel_name
Expand Down
124 changes: 124 additions & 0 deletions mlx/backend/common/jit_compiler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/common/jit_compiler.h"

#include <sstream>

#include <fmt/format.h>

namespace mlx::core {

#ifdef _MSC_VER

namespace {

// Split string into array.
std::vector<std::string> str_split(const std::string& str, char delimiter) {
std::vector<std::string> tokens;
std::string token;
std::istringstream tokenStream(str);
while (std::getline(tokenStream, token, delimiter)) {
tokens.push_back(token);
}
return tokens;
}

// Run a command and get its output.
std::string exec(const std::string& cmd) {
std::unique_ptr<FILE, decltype(&_pclose)> pipe(
_popen(cmd.c_str(), "r"), _pclose);
if (!pipe) {
throw std::runtime_error("popen() failed.");
}
char buffer[128];
std::string ret;
while (fgets(buffer, sizeof(buffer), pipe.get())) {
ret += buffer;
}
// Trim trailing spaces.
ret.erase(
std::find_if(
ret.rbegin(),
ret.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
ret.end());
return ret;
}

// Get path information about MSVC.
struct VisualStudioInfo {
VisualStudioInfo() {
#ifdef _M_ARM64
arch = "arm64";
#else
arch = "x64";
#endif
// Get path of Visual Studio.
std::string vs_path = exec(fmt::format(
"\"{0}\\Microsoft Visual Studio\\Installer\\vswhere.exe\""
" -property installationPath",
std::getenv("ProgramFiles(x86)")));
if (vs_path.empty()) {
throw std::runtime_error("Can not find Visual Studio.");
}
// Read the envs from vcvarsall.
std::string envs = exec(fmt::format(
"\"{0}\\VC\\Auxiliary\\Build\\vcvarsall.bat\" {1} >NUL && set",
vs_path,
arch));
for (const std::string& line : str_split(envs, '\n')) {
// Each line is in the format "ENV_NAME=values".
auto pos = line.find_first_of('=');
if (pos == std::string::npos || pos == 0 || pos == line.size() - 1)
continue;
std::string name = line.substr(0, pos);
std::string value = line.substr(pos + 1);
if (name == "LIB") {
libpaths = str_split(value, ';');
} else if (name == "VCToolsInstallDir") {
cl_exe = fmt::format("{0}\\bin\\Host{1}\\{1}\\cl.exe", value, arch);
}
}
}
std::string arch;
std::string cl_exe;
std::vector<std::string> libpaths;
};

const VisualStudioInfo& GetVisualStudioInfo() {
static VisualStudioInfo info;
return info;
}

} // namespace

#endif // _MSC_VER

std::string JitCompiler::build_command(
const std::string& source_file_path,
const std::string& shared_lib_path) {
#ifdef _MSC_VER
const VisualStudioInfo& info = GetVisualStudioInfo();
std::string libpaths;
for (const std::string& lib : info.libpaths) {
libpaths += fmt::format(" /libpath:\"{0}\"", lib);
}
std::string command = fmt::format(
"\""
"\"{0}\" /LD /EHsc /nologo /std:c++17 \"{1}\" /link /out:\"{2}\"{3}"
"\"",
info.cl_exe,
source_file_path,
shared_lib_path,
libpaths);
#else
std::string command = fmt::format(
"g++ -std=c++17 -O3 -Wall -fPIC -shared '{0}' -o '{1}'",
source_file_path,
shared_lib_path);
#endif
return command;
}

} // namespace mlx::core
17 changes: 17 additions & 0 deletions mlx/backend/common/jit_compiler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright © 2024 Apple Inc.
#pragma once

#include <string>

namespace mlx::core {

class JitCompiler {
public:
// Build a shell command that compiles |source_file_path| to a shared library
// at |shared_lib_path|.
static std::string build_command(
const std::string& source_file_path,
const std::string& shared_lib_path);
};

} // namespace mlx::core

0 comments on commit 8f2822c

Please sign in to comment.