[ADD] add skip layernorm to kernel explorer for ROCm EP (#12816)
**Description**

Add skip layernorm to the kernel explorer so the ROCm EP implementation can be profiled.

Related PRs: #12803, #12817, #12821
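For reference, the operator being profiled fuses the residual add with layer normalization over the hidden dimension. A minimal scalar sketch of the computation (illustrative only; the actual GPU kernels appear in the diff below):

```cpp
#include <cmath>

// Reference semantics of SkipLayerNorm: out = LayerNorm(input + skip + bias) over each
// row of hidden_size elements, then scaled by gamma and shifted by the optional beta.
void SkipLayerNormRef(float* output, const float* input, const float* skip,
                      const float* gamma, const float* beta, const float* bias,
                      float epsilon, int hidden_size, int row_count) {
  for (int r = 0; r < row_count; ++r) {
    const float* in = input + r * hidden_size;
    const float* sk = skip + r * hidden_size;
    float* out = output + r * hidden_size;
    float mean = 0.f, mean_sq = 0.f;
    for (int i = 0; i < hidden_size; ++i) {
      const float val = in[i] + sk[i] + (bias != nullptr ? bias[i] : 0.f);
      out[i] = val;  // hold the pre-norm sum; normalized in the second pass
      mean += val / hidden_size;
      mean_sq += val * val / hidden_size;
    }
    const float inv_std = 1.f / std::sqrt(mean_sq - mean * mean + epsilon);
    for (int i = 0; i < hidden_size; ++i) {
      out[i] = (out[i] - mean) * inv_std * gamma[i] + (beta != nullptr ? beta[i] : 0.f);
    }
  }
}
```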

PeixuanZuo authored Sep 20, 2022
1 parent ffeba98 commit 189aef2
Showing 9 changed files with 407 additions and 127 deletions.
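Only four of the nine changed files are expanded below; the kernel explorer binding itself is among the collapsed ones, so nothing here should be read as its exact interface. As a rough, hypothetical sketch of how the new launcher can be timed for profiling (a standalone HIP harness; the buffer sizes, repeat count, and lack of warm-up and error handling are all illustrative assumptions, not code from this commit):

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <cstdio>

#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"

int main() {
  const int hidden_size = 768;
  const int element_count = hidden_size * 8 * 128;  // e.g. batch 8, sequence length 128

  half *output = nullptr, *input = nullptr, *skip = nullptr;
  half *gamma = nullptr, *beta = nullptr, *bias = nullptr;
  hipMalloc(reinterpret_cast<void**>(&output), element_count * sizeof(half));
  hipMalloc(reinterpret_cast<void**>(&input), element_count * sizeof(half));
  hipMalloc(reinterpret_cast<void**>(&skip), element_count * sizeof(half));
  hipMalloc(reinterpret_cast<void**>(&gamma), hidden_size * sizeof(half));
  hipMalloc(reinterpret_cast<void**>(&beta), hidden_size * sizeof(half));
  hipMalloc(reinterpret_cast<void**>(&bias), hidden_size * sizeof(half));

  hipEvent_t start, stop;
  hipEventCreate(&start);
  hipEventCreate(&stop);

  const int repeats = 100;
  hipEventRecord(start, nullptr);
  for (int i = 0; i < repeats; ++i) {
    // Status return ignored for brevity.
    onnxruntime::contrib::rocm::LaunchSkipLayerNormKernel<half>(
        nullptr, output, input, skip, gamma, beta, bias,
        1e-12f, hidden_size, element_count);
  }
  hipEventRecord(stop, nullptr);
  hipEventSynchronize(stop);

  float ms = 0.0f;
  hipEventElapsedTime(&ms, start, stop);
  printf("SkipLayerNorm<half> avg: %.3f us per launch\n", 1000.0f * ms / repeats);
  return 0;
}
```

Compiled with hipcc against the ROCm EP sources, varying hidden_size here exercises the different block_size/ILP branches of the dispatch shown further down.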
25 changes: 12 additions & 13 deletions onnxruntime/contrib_ops/rocm/bert/skip_layer_norm.cc
@@ -1,8 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/skip_layer_norm.h"

#include "core/providers/rocm/rocm_common.h"
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"

namespace onnxruntime {
@@ -93,21 +94,19 @@ Status SkipLayerNorm<T>::ComputeInternal(OpKernelContext* ctx) const {
int sequence_length = static_cast<int>(input_dims[1]);
int hidden_size = static_cast<int>(input_dims[2]);
int64_t element_count = input_dims[0] * sequence_length * hidden_size;
-size_t element_size = sizeof(T);
typedef typename ToHipType<T>::MappedType HipT;

return LaunchSkipLayerNormKernel<HipT>(
-Stream(),
-reinterpret_cast<HipT*>(output->MutableData<T>()),
-reinterpret_cast<const HipT*>(input->Data<T>()),
-reinterpret_cast<const HipT*>(skip->Data<T>()),
-reinterpret_cast<const HipT*>(gamma->Data<T>()),
-(beta != nullptr) ? reinterpret_cast<const HipT*>(beta->Data<T>()) : nullptr,
-(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
-epsilon_,
-hidden_size,
-static_cast<int>(element_count),
-element_size);
+Stream(),
+reinterpret_cast<HipT*>(output->MutableData<T>()),
+reinterpret_cast<const HipT*>(input->Data<T>()),
+reinterpret_cast<const HipT*>(skip->Data<T>()),
+reinterpret_cast<const HipT*>(gamma->Data<T>()),
+(beta != nullptr) ? reinterpret_cast<const HipT*>(beta->Data<T>()) : nullptr,
+(bias != nullptr) ? reinterpret_cast<const HipT*>(bias->Data<T>()) : nullptr,
+epsilon_,
+hidden_size,
+static_cast<int>(element_count));
}

} // namespace rocm
135 changes: 31 additions & 104 deletions onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.cu
@@ -28,149 +28,76 @@ limitations under the License.
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Licensed under the MIT License.

-#include <hip/hip_fp16.h>
-#include "contrib_ops/rocm/bert/layer_norm.cuh"
#include "contrib_ops/rocm/bert/skip_layer_norm_impl.h"
+
+#include <hip/hip_fp16.h>
+
+#include "contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace rocm {

-[Removed from this file: the maybe2half() helpers and the SkipLayerNormKernel / SkipLayerNormKernelSmall kernel definitions, which move verbatim into the new skip_layer_norm_impl_kernel.h shown further below.]

template <typename T>
Status LaunchSkipLayerNormKernel(
hipStream_t stream, T* output, const T* input, const T* skip, const T* gamma,
-const T* beta, const T* bias, float epsilon, const int ld, const int element_count,
-size_t element_size) {
+const T* beta, const T* bias, float epsilon, const int ld, const int element_count) {
// this must be true because n is the total size of the tensor
assert(element_count % ld == 0);
bool hasBias = (bias == nullptr) ? false : true;
if (0 == (ld % 4)) {
const int grid_size = element_count / ld;
if (ld <= 32) {
constexpr int block_size = 32;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 64) {
constexpr int block_size = 64 / 2;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 2>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 2><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 128) {
constexpr int block_size = 128 / 4;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 384) {
constexpr int block_size = 384 / 4;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 768) {
constexpr int block_size = 768 / 4;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 1024) {
constexpr int block_size = 1024 / 4;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 4>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 4><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else {
constexpr int block_size = 256;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
+SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
}
} else {
const int grid_size = element_count / ld;
if (ld <= 32) {
constexpr int block_size = 32;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 64) {
constexpr int block_size = 64;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld <= 128) {
constexpr int block_size = 128;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else if (ld == 384) {
constexpr int block_size = 384;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernelSmall<T, block_size, 1>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
+SkipLayerNormKernelSmall<T, block_size, 1><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output, hasBias);
} else {
constexpr int block_size = 256;
-hipLaunchKernelGGL(HIP_KERNEL_NAME(SkipLayerNormKernel<T, block_size>), grid_size, block_size,
-0, stream, ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
+SkipLayerNormKernel<T, block_size><<<grid_size, block_size, 0, stream>>>(
+ld, input, skip, beta, gamma, bias, maybe2half<T>(epsilon), output);
}
}
return HIP_CALL(hipPeekAtLastError());
@@ -179,12 +106,12 @@ Status LaunchSkipLayerNormKernel(
template Status LaunchSkipLayerNormKernel<float>(hipStream_t stream, float* output, const float* input,
const float* skip, const float* gamma, const float* beta,
const float* bias, float epsilon, const int ld,
-const int element_count, size_t element_size);
+const int element_count);

template Status LaunchSkipLayerNormKernel<half>(hipStream_t stream, half* output, const half* input,
const half* skip, const half* gamma, const half* beta,
const half* bias, float epsilon, const int ld,
-const int element_count, size_t element_size);
+const int element_count);

} // namespace rocm
} // namespace contrib
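A note on the dispatch in LaunchSkipLayerNormKernel above: grid_size equals the number of rows (element_count / ld), each block normalizes one row, and block_size together with the vector width ILP is chosen so that block_size * ILP covers ld. When ld is not a multiple of 4, the same pattern is used with ILP = 1 (block sizes 32/64/128, plus a special case for ld == 384), and everything else goes to the plain kernel. Summarized from the code above:

```cpp
// ld (hidden_size) a multiple of 4:
//   ld <=   32 -> SkipLayerNormKernelSmall, block_size  32, ILP 1
//   ld <=   64 -> SkipLayerNormKernelSmall, block_size  32, ILP 2
//   ld <=  128 -> SkipLayerNormKernelSmall, block_size  32, ILP 4
//   ld <=  384 -> SkipLayerNormKernelSmall, block_size  96, ILP 4
//   ld <=  768 -> SkipLayerNormKernelSmall, block_size 192, ILP 4
//   ld <= 1024 -> SkipLayerNormKernelSmall, block_size 256, ILP 4
//   otherwise  -> SkipLayerNormKernel,      block_size 256
// Example: BERT-base (hidden_size = 768) runs 192 threads per block,
// each thread handling 4 elements (192 * 4 = 768), one block per row.
```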
19 changes: 9 additions & 10 deletions onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl.h
@@ -11,16 +11,15 @@ namespace rocm {
template <typename T>
Status LaunchSkipLayerNormKernel(
hipStream_t stream,
-T* output,        // output tensor
-const T* input,   // input tensor
-const T* skip,    // skip tensor
-const T* gamma,   // Layer normalization gamma tensor
-const T* beta,    // Layer normalization beta tensor
-const T* bias,    // Layer normalization beta tensor
-float epsilon,    // Layer normalization epsilon
-int hidden_size,  // hidden size, it is the leading dimension (ld)
-int element_count,  // number of elements in input tensor
-size_t element_size
+T* output,        // output tensor
+const T* input,   // input tensor
+const T* skip,    // skip tensor
+const T* gamma,   // Layer normalization gamma tensor
+const T* beta,    // Layer normalization beta tensor
+const T* bias,    // Layer normalization beta tensor
+float epsilon,    // Layer normalization epsilon
+int hidden_size,  // hidden size, it is the leading dimension (ld)
+int element_count  // number of elements in input tensor
);

} // namespace rocm
89 changes: 89 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <hip/hip_fp16.h>
#include "contrib_ops/rocm/bert/layer_norm.cuh"

namespace onnxruntime {
namespace contrib {
namespace rocm {

template <typename T>
T maybe2half(float x);

template <>
float maybe2half(float x) {
return x;
}

template <>
half maybe2half(float x) {
return __float2half_rn(x);
}

template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias,
const T epsilon, T* output) {
const T reverse_ld = T(1.f / ld);
const int offset = blockIdx.x * ld;

KeyValuePairSum pair_sum;
// reduce x and x^2
hipcub::KeyValuePair<T, T> thread_data(0, 0);

for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i];
const T rldval = reverse_ld * val;
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val));
output[idx] = val;
}

LayerNorm<T, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output);
}

// Vectorized kernel
template <typename T, unsigned TPB, int ILP>
__global__ void SkipLayerNormKernelSmall(
const int ld, const T* input, const T* skip, const T* beta, const T* gamma,
const T* bias, const T epsilon, T* output, bool hasBias) {
const T rld = T(1.f / ld);
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld

using VecT = aligned_vector<T, ILP>;
T input_v[ILP], skip_v[ILP], bias_v[ILP];

hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f));

if (ILP * threadIdx.x < ld) {
VecT* input_val = reinterpret_cast<VecT*>(&input_v);
*input_val = *reinterpret_cast<const VecT*>(&input[idx]);

VecT* skip_val = reinterpret_cast<VecT*>(&skip_v);
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]);

if (hasBias) {
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v);
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]);
}

T rldval_sum = T(0.f);
T rldvalsq_sum = T(0.f);
#pragma unroll
for (int i = 0; i < ILP; i++) {
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i];
const T rldval = rld * input_v[i];
rldval_sum += rldval;
rldvalsq_sum += rldval * input_v[i];
}
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum);
}
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output);
}

} // namespace rocm
} // namespace contrib
} // namespace onnxruntime
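The vectorized SkipLayerNormKernelSmall above leans on an aligned_vector helper so that the ILP elements a thread owns can be moved with a single wide load/store. That helper is defined elsewhere in the provider headers and is not part of this commit; it is roughly the following idiom (an assumption, shown only to make the reinterpret_casts above readable):

```cpp
// Assumed shape of the aligned_vector helper used by the kernel above: a POD
// wrapper whose alignment lets a reinterpret_cast-ed assignment compile down
// to one (sizeof(T) * ILP)-byte vector memory transaction.
template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) aligned_vector {
  T val[ILP];
};
```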
[The remaining five changed files are not expanded here.]
