From fb96878e693e3e42592c95f476c64e63d1e2902b Mon Sep 17 00:00:00 2001 From: Cheng Zhao Date: Mon, 9 Dec 2024 11:58:25 +0000 Subject: [PATCH 1/2] Fix shared library not exporting symbols on Windows --- mlx/CMakeLists.txt | 5 +++++ mlx/utils.cpp | 9 ++++++--- mlx/utils.h | 2 +- python/src/array.cpp | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index fb91d90448..8ec82d1773 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -23,6 +23,11 @@ if(MSVC) target_compile_options(mlx PUBLIC /wd4068 /wd4244 /wd4267 /wd4804) endif() +if(WIN32) + # Export symbols by default to behave like macOS/linux. + set_target_properties(mlx PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE) +endif() + if(MLX_BUILD_CPU) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/common) else() diff --git a/mlx/utils.cpp b/mlx/utils.cpp index daa90fea69..b93faee6ec 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -56,7 +56,10 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) { os << val; } -PrintFormatter global_formatter; +PrintFormatter& GetGlobalFormatter() { + static PrintFormatter formatter; + return formatter; +} Dtype result_type(const std::vector& arrays) { Dtype t = bool_; @@ -171,7 +174,7 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { i = n - num_print - 1; index += s * (n - 2 * num_print - 1); } else if (is_last) { - global_formatter.print(os, a.data()[index]); + GetGlobalFormatter().print(os, a.data()[index]); } else { print_subarray(os, a, index, dim + 1); } @@ -187,7 +190,7 @@ void print_array(std::ostream& os, const array& a) { os << "array("; if (a.ndim() == 0) { auto data = a.data(); - global_formatter.print(os, data[0]); + GetGlobalFormatter().print(os, data[0]); } else { print_subarray(os, a, 0, 0); } diff --git a/mlx/utils.h b/mlx/utils.h index 108fdf2030..c23eb03298 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -51,7 +51,7 @@ struct PrintFormatter { bool capitalize_bool{false}; }; -extern PrintFormatter global_formatter; +PrintFormatter& GetGlobalFormatter(); /** The type from promoting the arrays' types with one another. */ inline Dtype result_type(const array& a, const array& b) { diff --git a/python/src/array.cpp b/python/src/array.cpp index e518f27651..8732a143e2 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -85,7 +85,7 @@ class ArrayPythonIterator { void init_array(nb::module_& m) { // Set Python print formatting options - mlx::core::global_formatter.capitalize_bool = true; + GetGlobalFormatter().capitalize_bool = true; // Types nb::class_( From bd1508e787f9c805019fae718c2ba81ba9ed0141 Mon Sep 17 00:00:00 2001 From: Cheng Date: Tue, 10 Dec 2024 10:47:50 +0900 Subject: [PATCH 2/2] Function name style --- mlx/utils.cpp | 6 +++--- mlx/utils.h | 2 +- python/src/array.cpp | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlx/utils.cpp b/mlx/utils.cpp index b93faee6ec..6d05ad5f8d 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -56,7 +56,7 @@ inline void PrintFormatter::print(std::ostream& os, complex64_t val) { os << val; } -PrintFormatter& GetGlobalFormatter() { +PrintFormatter& get_global_formatter() { static PrintFormatter formatter; return formatter; } @@ -174,7 +174,7 @@ void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { i = n - num_print - 1; index += s * (n - 2 * num_print - 1); } else if (is_last) { - GetGlobalFormatter().print(os, a.data()[index]); + get_global_formatter().print(os, a.data()[index]); } else { print_subarray(os, a, index, dim + 1); } @@ -190,7 +190,7 @@ void print_array(std::ostream& os, const array& a) { os << "array("; if (a.ndim() == 0) { auto data = a.data(); - GetGlobalFormatter().print(os, data[0]); + get_global_formatter().print(os, data[0]); } else { print_subarray(os, a, 0, 0); } diff --git a/mlx/utils.h b/mlx/utils.h index c23eb03298..04f59feaa0 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -51,7 +51,7 @@ struct PrintFormatter { bool capitalize_bool{false}; }; -PrintFormatter& GetGlobalFormatter(); +PrintFormatter& get_global_formatter(); /** The type from promoting the arrays' types with one another. */ inline Dtype result_type(const array& a, const array& b) { diff --git a/python/src/array.cpp b/python/src/array.cpp index 8732a143e2..017fb6e91a 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -85,7 +85,7 @@ class ArrayPythonIterator { void init_array(nb::module_& m) { // Set Python print formatting options - GetGlobalFormatter().capitalize_bool = true; + get_global_formatter().capitalize_bool = true; // Types nb::class_(