-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ADD] add skip layernorm to kernel explorer for ROCm EP (#12816)
**Description**: Describe your changes. Related PR: #12803 #12817 #12821 Add skip layernorm to kernel explorer for profiling. **Motivation and Context** - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here.
- Loading branch information
1 parent
ffeba98
commit 189aef2
Showing
9 changed files
with
407 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
onnxruntime/contrib_ops/rocm/bert/skip_layer_norm_impl_kernel.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include <hip/hip_fp16.h> | ||
#include "contrib_ops/rocm/bert/layer_norm.cuh" | ||
|
||
namespace onnxruntime { | ||
namespace contrib { | ||
namespace rocm { | ||
|
||
template <typename T> | ||
T maybe2half(float x); | ||
|
||
template <> | ||
float maybe2half(float x) { | ||
return x; | ||
} | ||
|
||
template <> | ||
half maybe2half(float x) { | ||
return __float2half_rn(x); | ||
} | ||
|
||
template <typename T, unsigned TPB> | ||
__global__ void SkipLayerNormKernel( | ||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, const T* bias, | ||
const T epsilon, T* output) { | ||
const T reverse_ld = T(1.f / ld); | ||
const int offset = blockIdx.x * ld; | ||
|
||
KeyValuePairSum pair_sum; | ||
// reduce x and x^2 | ||
hipcub::KeyValuePair<T, T> thread_data(0, 0); | ||
|
||
for (int i = threadIdx.x; i < ld; i += TPB) { | ||
const int idx = offset + i; | ||
const T val = (bias == nullptr) ? input[idx] + skip[idx] : input[idx] + skip[idx] + bias[i]; | ||
const T rldval = reverse_ld * val; | ||
thread_data = pair_sum(thread_data, hipcub::KeyValuePair<T, T>(rldval, rldval * val)); | ||
output[idx] = val; | ||
} | ||
|
||
LayerNorm<T, TPB>(thread_data, ld, offset, beta, gamma, epsilon, output); | ||
} | ||
|
||
// Vectorized kernel | ||
template <typename T, unsigned TPB, int ILP> | ||
__global__ void SkipLayerNormKernelSmall( | ||
const int ld, const T* input, const T* skip, const T* beta, const T* gamma, | ||
const T* bias, const T epsilon, T* output, bool hasBias) { | ||
const T rld = T(1.f / ld); | ||
const int idx = blockIdx.x * ld + threadIdx.x * ILP; // grid_size = n / ld | ||
|
||
using VecT = aligned_vector<T, ILP>; | ||
T input_v[ILP], skip_v[ILP], bias_v[ILP]; | ||
|
||
hipcub::KeyValuePair<T, T> thread_data(T(0.f), T(0.f)); | ||
|
||
if (ILP * threadIdx.x < ld) { | ||
VecT* input_val = reinterpret_cast<VecT*>(&input_v); | ||
*input_val = *reinterpret_cast<const VecT*>(&input[idx]); | ||
|
||
VecT* skip_val = reinterpret_cast<VecT*>(&skip_v); | ||
*skip_val = *reinterpret_cast<const VecT*>(&skip[idx]); | ||
|
||
if (hasBias) { | ||
VecT* bias_val = reinterpret_cast<VecT*>(&bias_v); | ||
*bias_val = *reinterpret_cast<const VecT*>(&bias[threadIdx.x * ILP]); | ||
} | ||
|
||
T rldval_sum = T(0.f); | ||
T rldvalsq_sum = T(0.f); | ||
#pragma unroll | ||
for (int i = 0; i < ILP; i++) { | ||
input_v[i] += hasBias ? skip_v[i] + bias_v[i] : skip_v[i]; | ||
const T rldval = rld * input_v[i]; | ||
rldval_sum += rldval; | ||
rldvalsq_sum += rldval * input_v[i]; | ||
} | ||
thread_data = hipcub::KeyValuePair<T, T>(rldval_sum, rldvalsq_sum); | ||
} | ||
LayerNormSmall<T, TPB, ILP>(input_v, thread_data, ld, idx, beta, gamma, epsilon, output); | ||
} | ||
|
||
} // namespace rocm | ||
} // namespace contrib | ||
} // namespace onnxruntime |
Oops, something went wrong.