diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
index 813b3ea5ec8a4..7762173e98e69 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc
@@ -244,7 +244,7 @@ Status QOrderedAttention::ComputeInternal(OpKernelContext* context) const {
   int64_t size_of_attention_scores = ((int64_t)batch_size) * num_heads_ * sequence_length * sequence_length;
 
   // transposed qkv_layer, union(stacked, attention probs + attention scores)
-  auto gemm_buffer_quantized = GetScratchBuffer<int8_t>(m * n + std::max((int64_t)m * n, 2 * size_of_attention_scores));
+  auto gemm_buffer_quantized = GetScratchBuffer<int8_t>((int64_t)m * n + std::max((int64_t)m * n, 2 * size_of_attention_scores));
   int8_t* stacked_qkv_layers = gemm_buffer_quantized.get() + ((int64_t)m * n);
   int8_t* tranposed_qkv_layers = gemm_buffer_quantized.get();
diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc
index 263b914179250..0aac49ade69af 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_longformer_attention.cc
@@ -119,7 +119,7 @@ QOrderedLongformerAttention::ComputeInternal(OpKernelContext* context) const {
 
   // TODO: only calculate once per model.
   // Build Global Index
-  auto global_index_buffer = GetScratchBuffer<int>(batch_size * sequence_length);
+  auto global_index_buffer = GetScratchBuffer<int>(static_cast<size_t>(batch_size) * static_cast<size_t>(sequence_length));
   auto batch_global_num_buffer = GetScratchBuffer<int>(batch_size);
 
   size_t global_scratch_bytes = GetGlobalScratchSize(sequence_length);
diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc
index 18e91a98c21a0..7d851beda10d0 100644
--- a/onnxruntime/core/providers/cuda/tensor/pad.cc
+++ b/onnxruntime/core/providers/cuda/tensor/pad.cc
@@ -104,7 +104,7 @@ Status Pad<T>::ComputeInternal(OpKernelContext* ctx) const {
     ORT_ENFORCE(pads_size == 2 * static_cast<int64_t>(dimension_count),
                 "Pads tensor size should be equal to twice the input dimension count ");
-    pads.reserve(2 * dimension_count);
+    pads.reserve(2LL * dimension_count);
     for (size_t i = 0; i < pads_size; ++i) {
      pads.push_back(pads_tensor_raw_data[i]);
    }
diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc
index b197a87e4b93d..38a8d6aa8611d 100644
--- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc
+++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc
@@ -12,17 +12,17 @@ namespace onnxruntime {
 namespace test {
 
-static const int64_t batch_size = 1;
-static const int64_t sequence_len = 16;
-static const int64_t input_hidden_size = 32;
-static const int64_t num_heads = 2;
-static const int64_t head_size = 16;
-static const int64_t hidden_size = num_heads * head_size;
+static constexpr int64_t batch_size = 1;
+static constexpr int64_t sequence_len = 16;
+static constexpr int64_t input_hidden_size = 32;
+static constexpr int64_t num_heads = 2;
+static constexpr int64_t head_size = 16;
+static constexpr int64_t hidden_size = num_heads * head_size;
 
 static std::vector<int32_t> input_mask = {  // [1, 16]
     1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0};
 
-static float input_scale = 0.025f;
+static constexpr float input_scale = 0.025f;
 static std::vector<int8_t> inputq = {  // [1, 16, 32]
     -33, 7, -54, 29, 14, 6, 14, 16, 1, 16, 22, 0, 16, 49, -14, -15, 68, 11, -18, -9, -42, 6, 6, 58, 22, 31, 0, -13, 42, 40, 4, 0,
@@ -180,13 +180,13 @@ static std::vector<float> v_bias = {
     -1.5637541858090884f, 0.053171526292804416f, -1.5821961194911058f, -1.2062417346542489f, 0.23029741928149683f, -0.8920457050782132f, -0.06220760650838387f, 0.2942590084687021f,
     -0.4362228349183151f, -0.2344379226413643f, -0.586149329261036f, -1.5243876669794532f, 0.22378084867382358f, -1.715499198175354f, -1.3795418183607775f, -1.2237706022285266f};
 
-static float qlayer_scale = 0.250f;
-static float klayer_scale = 0.250f;
-static float vlayer_scale = 0.125f;
+static constexpr float qlayer_scale = 0.250f;
+static constexpr float klayer_scale = 0.250f;
+static constexpr float vlayer_scale = 0.125f;
 
-static float qk_scale = 0.5f;
-static float probs_scale = 0.0078125f;
-static float attn_out_scale = 0.05f;
+static constexpr float qk_scale = 0.5f;
+static constexpr float probs_scale = 0.0078125f;
+static constexpr float attn_out_scale = 0.05f;
 
 static std::vector<int8_t> attn_out_q8 = {
     -39, 8, -75, 2, -69, -31, -42, -29, 44, 6, 0, -61, -102, 61, 28, 76,
diff --git a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc
index 73a51ab796245..fa32536e4d472 100644
--- a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc
@@ -79,19 +79,19 @@ static void run_qordered_longformer_attention_op_test(
 }
 
 TEST(QOrderedTest, LongformerAttention_1x128x2x16_window_32) {
-  const float scale_input = 1.0f / 32.0f;
-  const float scale_weight = 1.0f / 64.0f;
-  const float scale_bias = 1.0f / 8.0f;
-  const float scale_qkv_gemm = 1.0f / 4.0f;
-  const float scale_global_weight = 1.0f / 64.0f;
-  const float scale_global_gemm = 1.0f / 4.0f;
-  const float scale_output = 1.0f / 8.0f;
-  const int64_t batch_size = 1;
-  const int64_t sequence_len = 128;
-  const int64_t num_heads = 2;
-  const int64_t head_size = 16;
-  const int64_t window = 32;
-  const int64_t input_hidden_size = 0;  // same as hidden_size
+  constexpr float scale_input = 1.0f / 32.0f;
+  constexpr float scale_weight = 1.0f / 64.0f;
+  constexpr float scale_bias = 1.0f / 8.0f;
+  constexpr float scale_qkv_gemm = 1.0f / 4.0f;
+  constexpr float scale_global_weight = 1.0f / 64.0f;
+  constexpr float scale_global_gemm = 1.0f / 4.0f;
+  constexpr float scale_output = 1.0f / 8.0f;
+  constexpr int64_t batch_size = 1;
+  constexpr int64_t sequence_len = 128;
+  constexpr int64_t num_heads = 2;
+  constexpr int64_t head_size = 16;
+  constexpr int64_t window = 32;
+  constexpr int64_t input_hidden_size = 0;  // same as hidden_size
 
   // Following code generate the input data vectors: (Keep it here in case)
   // #include
@@ -154,7 +154,7 @@ TEST(QOrderedTest, LongformerAttention_1x128x2x16_window_32) {
   // debug_print(global_attention_mask.data(), batch_size, sequence_len, "global_attention_mask");
 
   // float scale_output = 1.0f / 8.0f;
-
+
   //========inputq : 128x32 ============
   std::vector<int8_t> inputq = {