Skip to content

Commit

Permalink
Validate inputs using the prepacked tensors instead of ctx when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
amarin16 committed Nov 6, 2024
1 parent d5850a4 commit bc2c820
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
26 changes: 17 additions & 9 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,22 @@ SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
prepacked_skip_fp32_data_(nullptr),
prepacked_gamma_fp32_data_(nullptr),
prepacked_beta_fp32_data_(nullptr),
prepacked_bias_fp32_data_(nullptr) {
prepacked_bias_fp32_data_(nullptr),
prepacked_skip_tensor_(nullptr),
prepacked_gamma_tensor_(nullptr),
prepacked_beta_tensor_(nullptr),
prepacked_bias_tensor_(nullptr) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}

template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const Tensor* input = p_ctx->Input<Tensor>(0);
const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
const Tensor* skip = prepacked_skip_fp32_data_ ? prepacked_skip_tensor_ : p_ctx->Input<Tensor>(1);
const Tensor* gamma = prepacked_gamma_fp32_data_ ? prepacked_gamma_tensor_ : p_ctx->Input<Tensor>(2);
const Tensor* beta = prepacked_beta_fp32_data_ ? prepacked_beta_tensor_ : p_ctx->Input<Tensor>(3);
const Tensor* bias = prepacked_bias_fp32_data_ ? prepacked_bias_tensor_ : p_ctx->Input<Tensor>(4);
Tensor* output = p_ctx->Output(0, input->Shape());
// For inference, we support one more optional output (index 3), which is the sum of the input and skip tensors
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
Expand All @@ -215,10 +219,10 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);

const T* input_data = input->Data<T>();
const T* skip_data = skip == nullptr ? nullptr : skip->Data<T>();
const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data<T>();
const T* beta_data = beta == nullptr ? nullptr : beta->Data<T>();
const T* bias_data = bias == nullptr ? nullptr : bias->Data<T>();
const T* skip_data = prepacked_skip_fp32_data_ ? nullptr : skip->Data<T>(); // skip is mandatory
const T* gamma_data = prepacked_gamma_fp32_data_ ? nullptr : gamma->Data<T>(); // gamma is mandatory
const T* beta_data = (prepacked_beta_fp32_data_ || beta == nullptr) ? nullptr : beta->Data<T>();
const T* bias_data = (prepacked_bias_fp32_data_ || bias == nullptr) ? nullptr : bias->Data<T>();

T* output_data = output->MutableData<T>();

Expand Down Expand Up @@ -295,12 +299,16 @@ Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx
is_packed = false;
if (input_idx == 1) { // skip
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed);
prepacked_skip_tensor_ = &tensor;
} else if (input_idx == 2) { // gamma
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
prepacked_gamma_tensor_ = &tensor;
} else if (input_idx == 3) { // beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
prepacked_beta_tensor_ = &tensor;
} else if (input_idx == 4) { // bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
prepacked_bias_tensor_ = &tensor;
}

return Status::OK();
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class SkipLayerNorm final : public OpKernel {
IAllocatorUniquePtr<float> prepacked_gamma_fp32_data_;
IAllocatorUniquePtr<float> prepacked_beta_fp32_data_;
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
const Tensor* prepacked_skip_tensor_;
const Tensor* prepacked_gamma_tensor_;
const Tensor* prepacked_beta_tensor_;
const Tensor* prepacked_bias_tensor_;
};

} // namespace contrib
Expand Down

0 comments on commit bc2c820

Please sign in to comment.