From bc2c8207bde69de1de426bba10b8783401ef1d99 Mon Sep 17 00:00:00 2001
From: Alex Marin
Date: Wed, 6 Nov 2024 09:27:11 -0800
Subject: [PATCH] validate inputs from prepack instead of ctx if needed

---
NOTE(review): reconstructed from a mangled copy in which newlines and all
angle-bracket spans (the author e-mail, template parameter lists, the
Input<Tensor>/Data<T>/IAllocatorUniquePtr<float> type arguments) had been
stripped. Restored template arguments follow the onnxruntime CPU kernel API;
blank-line placement inside hunks is inferred to match the hunk headers'
line counts. Verify against upstream commit bc2c8207 before `git am`.

 .../contrib_ops/cpu/skip_layer_norm.cc        | 26 ++++++++++++-------
 onnxruntime/contrib_ops/cpu/skip_layer_norm.h |  4 +++
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
index 5bffe865b9c99..ded6992374fe6 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -184,7 +184,11 @@ SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
       prepacked_skip_fp32_data_(nullptr),
       prepacked_gamma_fp32_data_(nullptr),
       prepacked_beta_fp32_data_(nullptr),
-      prepacked_bias_fp32_data_(nullptr) {
+      prepacked_bias_fp32_data_(nullptr),
+      prepacked_skip_tensor_(nullptr),
+      prepacked_gamma_tensor_(nullptr),
+      prepacked_beta_tensor_(nullptr),
+      prepacked_bias_tensor_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
   ORT_ENFORCE(epsilon_ >= 0);
 }
@@ -192,10 +196,10 @@
 template <typename T, bool simplified>
 Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   const Tensor* input = p_ctx->Input<Tensor>(0);
-  const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
-  const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
-  const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
-  const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
+  const Tensor* skip = prepacked_skip_fp32_data_ ? prepacked_skip_tensor_ : p_ctx->Input<Tensor>(1);
+  const Tensor* gamma = prepacked_gamma_fp32_data_ ? prepacked_gamma_tensor_ : p_ctx->Input<Tensor>(2);
+  const Tensor* beta = prepacked_beta_fp32_data_ ? prepacked_beta_tensor_ : p_ctx->Input<Tensor>(3);
+  const Tensor* bias = prepacked_bias_fp32_data_ ? prepacked_bias_tensor_ : p_ctx->Input<Tensor>(4);
   Tensor* output = p_ctx->Output(0, input->Shape());
   // For inferencing, we support one more optional output which is the sum of the input and skip tensors
   Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
@@ -215,10 +219,10 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);
 
   const T* input_data = input->Data<T>();
-  const T* skip_data = skip == nullptr ? nullptr : skip->Data<T>();
-  const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data<T>();
-  const T* beta_data = beta == nullptr ? nullptr : beta->Data<T>();
-  const T* bias_data = bias == nullptr ? nullptr : bias->Data<T>();
+  const T* skip_data = prepacked_skip_fp32_data_ ? nullptr : skip->Data<T>();  // skip is mandatory
+  const T* gamma_data = prepacked_gamma_fp32_data_ ? nullptr : gamma->Data<T>();  // gamma is mandatory
+  const T* beta_data = (prepacked_beta_fp32_data_ || beta == nullptr) ? nullptr : beta->Data<T>();
+  const T* bias_data = (prepacked_bias_fp32_data_ || bias == nullptr) ? nullptr : bias->Data<T>();
 
   T* output_data = output->MutableData<T>();
 
@@ -295,12 +299,16 @@ Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx
   is_packed = false;
   if (input_idx == 1) {  // skip
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed);
+    prepacked_skip_tensor_ = &tensor;
   } else if (input_idx == 2) {  // gamma
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
+    prepacked_gamma_tensor_ = &tensor;
   } else if (input_idx == 3) {  // beta
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
+    prepacked_beta_tensor_ = &tensor;
   } else if (input_idx == 4) {  // bias
     ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
+    prepacked_bias_tensor_ = &tensor;
   }
 
   return Status::OK();
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
index fcbb00ee93938..fd6480ad239c8 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -25,6 +25,10 @@ class SkipLayerNorm final : public OpKernel {
   IAllocatorUniquePtr<float> prepacked_gamma_fp32_data_;
   IAllocatorUniquePtr<float> prepacked_beta_fp32_data_;
   IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
+  const Tensor* prepacked_skip_tensor_;
+  const Tensor* prepacked_gamma_tensor_;
+  const Tensor* prepacked_beta_tensor_;
+  const Tensor* prepacked_bias_tensor_;
 };
 
 } // namespace contrib