Skip to content

Commit

Permalink
Revert "Enable ROCm to use tunable GEMM" (#13160)
Browse files Browse the repository at this point in the history
Reverts #12853 due to a CI pipeline problem.
  • Loading branch information
cloudhan authored and linnealovespie committed Sep 30, 2022
1 parent 3565922 commit 1f18f65
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 213 deletions.
3 changes: 0 additions & 3 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1418,9 +1418,6 @@ if (onnxruntime_USE_ROCM)
#endif()
endif()

include(composable_kernel)
target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_composable_kernel_includes device_gemm_instance)

if(UNIX)
set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections")
target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync_cpp)
Expand Down
22 changes: 10 additions & 12 deletions onnxruntime/core/providers/rocm/math/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,13 @@
// Licensed under the MIT License.

#include "core/providers/rocm/math/gemm.h"

#include "core/providers/cpu/math/gemm_helper.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/fpgeneric.h"
#include "core/providers/rocm/tunable/gemm.h"


namespace onnxruntime {
namespace rocm {

using tunable::blas::BlasOp;

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
Expand Down Expand Up @@ -127,21 +122,24 @@ Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
}
}

return tunable::blas::column_major::Gemm(
false, Stream(),
HipT alpha = ToHipType<T>::FromFloat(alpha_);
HipT beta = ToHipType<T>::FromFloat(beta_);
// Gemm, note that HIP assumes col-major, so Y(N,M) = alpha * op(W) x op(X) + beta * Y
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
RocblasHandle(),
trans_B_ ? BlasOp::Trans : BlasOp::NonTrans,
trans_A_ ? BlasOp::Trans : BlasOp::NonTrans,
trans_B_ ? rocblas_operation_transpose : rocblas_operation_none,
trans_A_ ? rocblas_operation_transpose : rocblas_operation_none,
N, M, K,
alpha_,
&alpha,
reinterpret_cast<const HipT*>(W->Data<T>()),
(trans_B_ ? K : N),
reinterpret_cast<const HipT*>(X->Data<T>()),
(trans_A_ ? M : K),
// ideally we need to set the output buffer contents to 0 if bias is missing,
// but passing 0 for beta is cheaper and it will ignore any junk in the output buffer
B != nullptr ? beta_ : 0.0f,
out_data, N);
B != nullptr ? &beta : &zero,
out_data, N));
return Status::OK();
}

} // namespace rocm
Expand Down
16 changes: 6 additions & 10 deletions onnxruntime/core/providers/rocm/math/matmul_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include "core/providers/rocm/rocm_allocator.h"
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/rocm/tunable/gemm.h"

namespace onnxruntime {
namespace rocm {
Expand Down Expand Up @@ -84,16 +83,13 @@ Status MatMulImpl(const RocmKernel* op, MatMulComputeHelper& helper,
int64_t stride_A, stride_B, stride_C, batch_count;

if (helper.OutputOffsets().size() == 1) {
using tunable::blas::BlasOp;
BlasOp transA = transa ? BlasOp::Trans : BlasOp::NonTrans;
BlasOp transB = transb ? BlasOp::Trans : BlasOp::NonTrans;
return tunable::blas::column_major::Gemm(
false, op->Stream(),
op->RocblasHandle(), transB, transA, static_cast<int64_t>(helper.N()),
static_cast<int64_t>(helper.M()), static_cast<int64_t>(helper.K()), t_alpha,
ROCBLAS_RETURN_IF_ERROR(rocblasGemmHelper(
op->RocblasHandle(), transB, transA, static_cast<int>(helper.N()),
static_cast<int>(helper.M()), static_cast<int>(helper.K()), &alpha,
reinterpret_cast<const HipT*>(right_x_data), ldb,
reinterpret_cast<const HipT*>(left_x_data), lda, t_zero,
reinterpret_cast<HipT*>(output_y_data), ldc);
reinterpret_cast<const HipT*>(left_x_data), lda, &zero,
reinterpret_cast<HipT*>(output_y_data), ldc));
return Status::OK();
} else if (CanUseStridedBatchedGemm(left_shape, right_shape,
transa, transb, trans_batch_a, trans_batch_b,
stride_A, stride_B, stride_C, batch_count)) {
Expand Down
119 changes: 0 additions & 119 deletions onnxruntime/core/providers/rocm/tunable/gemm.cu

This file was deleted.

58 changes: 0 additions & 58 deletions onnxruntime/core/providers/rocm/tunable/gemm.h

This file was deleted.

11 changes: 0 additions & 11 deletions onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ struct DataTypeAdaptor<half> {
using type = ck::half_t;
};

template <>
struct DataTypeAdaptor<BFloat16> {
using type = ck::bhalf16_t;
};

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

Expand All @@ -55,12 +50,6 @@ auto GetCKGemmTypeStringAndOps() {
auto type_string = impl->GetTypeString();
auto invoker = impl->MakeInvokerPointer();
auto ck_gemm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GemmParams<T>* params) -> Status {
auto one = ToHipType<T>::FromFloat(1.0f);
auto zero = ToHipType<T>::FromFloat(0.0f);
TUNABLE_OP_RETURN_UNSUPPOTED_ARGUMENT_IF(
params->alpha != one || params->beta != zero,
impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature());

auto nop = Nop{};
auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
params->m, params->n, params->k,
Expand Down

0 comments on commit 1f18f65

Please sign in to comment.