Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work around function std name collision in MSVC #1679

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlx/mlx.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "mlx/fft.h"
#include "mlx/io.h"
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "mlx/ops_public.h"
#include "mlx/random.h"
#include "mlx/stream.h"
#include "mlx/transforms.h"
Expand Down
10 changes: 5 additions & 5 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1704,17 +1704,17 @@ array var(
return var(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
}

array std(
array std_dev(
const array& a,
bool keepdims,
int ddof /* = 0*/,
StreamOrDevice s /* = {}*/) {
std::vector<int> axes(a.ndim());
std::iota(axes.begin(), axes.end(), 0);
return std(a, axes, keepdims, ddof, to_stream(s));
return std_dev(a, axes, keepdims, ddof, to_stream(s));
}

array std(
array std_dev(
const array& a,
const std::vector<int>& axes,
bool keepdims /* = false */,
Expand All @@ -1723,13 +1723,13 @@ array std(
return sqrt(var(a, axes, keepdims, ddof, s), s);
}

array std(
array std_dev(
const array& a,
int axis,
bool keepdims /* = false */,
int ddof /* = 0*/,
StreamOrDevice s /* = {} */) {
return std(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
return std_dev(a, std::vector<int>{axis}, keepdims, ddof, to_stream(s));
}

array prod(const array& a, bool keepdims, StreamOrDevice s /* = {}*/) {
Expand Down
14 changes: 9 additions & 5 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -543,14 +543,18 @@ array var(
StreamOrDevice s = {});

/** Computes the standard deviation of the elements of an array. */
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
inline array std(const array& a, StreamOrDevice s = {}) {
return std(a, false, 0, to_stream(s));
array std_dev(
const array& a,
bool keepdims,
int ddof = 0,
StreamOrDevice s = {});
inline array std_dev(const array& a, StreamOrDevice s = {}) {
return std_dev(a, false, 0, to_stream(s));
}

/** Computes the standard deviatoin of the elements of an array along the given
* axes */
array std(
array std_dev(
const array& a,
const std::vector<int>& axes,
bool keepdims = false,
Expand All @@ -559,7 +563,7 @@ array std(

/** Computes the standard deviation of the elements of an array along the given
* axis */
array std(
array std_dev(
const array& a,
int axis,
bool keepdims = false,
Expand Down
18 changes: 18 additions & 0 deletions mlx/ops_public.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include "mlx/ops.h"

namespace mlx::core {

// The "std" function has a name collision with "namespace std" in MSVC after
// "using namespace mlx::core", to work around it in our python bindings code
// we use "std_dev" as name instead and only expose the "std" function in
// public header.
template <typename... Args>
inline array std(Args&&... args) {
return std_dev(std::forward<Args>(args)...);
}

} // namespace mlx::core
2 changes: 1 addition & 1 deletion python/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,7 @@ void init_array(nb::module_& m) {
bool keepdims,
int ddof,
StreamOrDevice s) {
return mlx::core::std(
return std_dev(
a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
},
"axis"_a = nb::none(),
Expand Down
2 changes: 1 addition & 1 deletion python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2384,7 +2384,7 @@ void init_ops(nb::module_& m) {
bool keepdims,
int ddof,
StreamOrDevice s) {
return mlx::core::std(
return std_dev(
a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s);
},
nb::arg(),
Expand Down