Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA/ROCm] Conditionally support ArgMax and ArgMin for opset 12 and above #22713

Merged
merged 8 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,12 @@ Do not modify directly.*
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||10|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)|
Expand Down
63 changes: 60 additions & 3 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,13 @@
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);

class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin);

// OpSet 13
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
Expand Down Expand Up @@ -1199,6 +1206,13 @@
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin);

// OpSet 14
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu);
Expand Down Expand Up @@ -1640,6 +1654,9 @@
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,

Check warning on line 1657 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1657: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,

Check warning on line 1658 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1658: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,

Check warning on line 1659 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1659: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1)>,
Expand Down Expand Up @@ -1822,9 +1839,6 @@
19, IsInf)>,

// opset 11
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
Expand Down Expand Up @@ -1916,6 +1930,13 @@
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,

Check warning on line 1933 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1933: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,

Check warning on line 1934 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1934: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,

Check warning on line 1935 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1935: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin)>,

Check warning on line 1936 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1936: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin)>,

Check warning on line 1937 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1937: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin)>,

Check warning on line 1938 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:1938: Lines should be <= 120 characters long [whitespace/line_length] [2]

// OpSet 13
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
Expand Down Expand Up @@ -2150,6 +2171,13 @@
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,

Check warning on line 2174 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:2174: Lines should be <= 120 characters long [whitespace/line_length] [2]
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin)>,

// OpSet 14
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu)>,
Expand Down Expand Up @@ -2566,6 +2594,32 @@
return false;
}

// Returns true when an ArgMax/ArgMin node must run on the CPU EP instead of CUDA.
//
// Opset 12 introduced the attribute "select_last_index". CuDNN doesn't support
// picking the last index in case of encountering duplicate max values.
// CuDNN's API doc doesn't mention what happens in case duplicates are encountered,
// but based on testing, the results seem to indicate a "stable" implementation
// (i.e.) relative ordering is preserved which is the expected behavior when the
// attribute takes on the default value (most common use-case for this operator).
// So only a non-zero select_last_index forces the CPU fallback.
static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) {
  // Nodes below opset 12 cannot carry "select_last_index" at all.
  if (node.SinceVersion() < 12) {
    return false;
  }

  const auto& node_attributes = node.GetAttributes();
  // NodeAttributes is an associative container keyed by attribute name,
  // so look the attribute up directly instead of scanning all attributes.
  const auto it = node_attributes.find("select_last_index");
  return it != node_attributes.end() && it->second.i() != 0;
}

std::unique_ptr<onnxruntime::IDataTransfer> CUDAExecutionProvider::GetDataTransfer() const {
  // Hand back a fresh GPU data-transfer helper; the caller takes ownership.
  auto data_transfer = std::make_unique<onnxruntime::GPUDataTransfer>();
  return data_transfer;
}
Expand Down Expand Up @@ -2615,6 +2669,9 @@
} else if ("ConvTranspose" == node.OpType()) {
not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred());
force_inside = !not_supported;
} else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) {
not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node);
force_inside = !not_supported;
} else if ("Cast" == node.OpType()) {
not_supported = CastNeedFallbackToCPU(node);
// cast is not compute heavy, and may be placed outside
Expand Down
28 changes: 16 additions & 12 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, end, \
begin, end, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
Expand All @@ -37,8 +37,13 @@ namespace cuda {
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)

#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)

// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
Expand Down Expand Up @@ -829,14 +834,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, CUDNN_REDUCE_TENSOR_NO

} // namespace ReductionOps

// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)

REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
Expand Down
20 changes: 18 additions & 2 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,15 @@ class ReduceKernel : public CudaKernel, public ReduceKernelBase<allow_multi_axes
template <typename T>
class ArgMax final : public ReduceKernel<false> {
public:
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {
  // Defensive check only: ArgMaxOrArgMinNeedFallbackToCPU() already keeps
  // ArgMax nodes with select_last_index == 1 off the CUDA EP, so this
  // constructor should never see a non-zero value in practice.
  int64_t select_last_index{};
  const bool has_attr = info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK();
  ORT_ENFORCE(!has_attr || select_last_index == 0,
              "select_last_index as 1 is not supported on CUDA");
}

Status ComputeInternal(OpKernelContext* ctx) const override {
return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MAX);
Expand All @@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel<false> {
template <typename T>
class ArgMin final : public ReduceKernel<false> {
public:
ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {
  // Defensive check only: ArgMaxOrArgMinNeedFallbackToCPU() already keeps
  // ArgMin nodes with select_last_index == 1 off the CUDA EP, so this
  // constructor should never see a non-zero value in practice.
  int64_t select_last_index{};
  const bool has_attr = info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK();
  ORT_ENFORCE(!has_attr || select_last_index == 0,
              "select_last_index as 1 is not supported on CUDA");
}

Status ComputeInternal(OpKernelContext* ctx) const override {
return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MIN);
Expand Down
28 changes: 16 additions & 12 deletions onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {

#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, end, \
begin, end, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
Expand All @@ -37,8 +37,13 @@ namespace rocm {
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)

#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)

// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
Expand Down Expand Up @@ -830,14 +835,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N

} // namespace ReductionOps

// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)

REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
Expand Down
Loading
Loading