From 6d8dfe729eff0a4d1f6ac8f79bee77009cb0ef30 Mon Sep 17 00:00:00 2001 From: andrea_mancini Date: Mon, 13 May 2024 11:24:36 +0000 Subject: [PATCH 1/8] Fix BeamSearch on T5 with sequence_as_input_ids (#20667) --- .../cpu/transformers/beam_search_impl_t5.h | 26 +++++++---- .../cpu/transformers/generation_shared.h | 1 + .../contrib_ops/cpu/transformers/sequences.cc | 4 ++ .../contrib_ops/cpu/transformers/sequences.h | 3 ++ .../cpu/transformers/subgraph_t5_decoder.cc | 42 ++++++++++++------ .../cpu/transformers/subgraph_t5_decoder.h | 3 +- .../transformers/generation_device_helper.cc | 18 ++++---- .../test/contrib_ops/beam_search_test.cc | 12 +++++ .../tiny_t5_with_sequence_input_ids.onnx | Bin 0 -> 7633 bytes 9 files changed, 76 insertions(+), 33 deletions(-) create mode 100644 onnxruntime/test/testdata/tiny_t5_with_sequence_input_ids.onnx diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 8f5cdc97f27e5..b67d003eaceeb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -258,7 +258,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches current_length, cpu_state.sequences, parameters->max_length, - decoder_subgraph_.has_decoder_masked_attention_)); + decoder_subgraph_.has_decoder_masked_attention_, + this->cuda_device_prop_ != nullptr)); if (decoder_subgraph_.past_present_share_buffer_) { decoder_fetches.reserve(static_cast(decoder_subgraph_.GetFirstPresentOutputIndex()) + @@ -302,17 +303,24 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches auto cur_len = std::to_string(current_length); dumper->Print("***CurrentLength", cur_len, true); - for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) { + for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) { dumper->Print("decoder_feeds", i, true); dumper->Print("", decoder_feeds[i]); } - auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers; - dumper->Print("past_sequence_length", offset, true); - dumper->Print("", decoder_feeds[offset]); - dumper->Print("beam_width", offset + 1, true); - dumper->Print("", decoder_feeds[offset + 1]); - dumper->Print("cache_redir", offset + 2, true); - dumper->Print("", decoder_feeds[offset + 2]); + for (int i = 0; i < decoder_subgraph_.num_layers; i++) { + int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i; + int self_value_idx = self_key_idx + 1; + dumper->Print("past_key_self", i, true); + dumper->Print("", decoder_feeds[self_key_idx]); + dumper->Print("past_value_self", i + 1, true); + dumper->Print("", decoder_feeds[self_value_idx]); + int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i; + int cross_value_idx = cross_key_idx + 1; + dumper->Print("past_key_cross", i, true); + dumper->Print("", decoder_feeds[cross_key_idx]); + dumper->Print("past_value_cross", i, true); + dumper->Print("", decoder_feeds[cross_value_idx]); + } #endif #ifdef DEBUG_NODE_INPUTS_OUTPUTS diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index b1dd55eb20f34..ea3f1b2d2b62a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -99,6 +99,7 @@ struct ISequences { virtual gsl::span 
GetCurrentDeviceSequences() const = 0; // Get all current beam_index sequences in one continuous block (to pass to CUDA) virtual gsl::span GetNextDeviceSequences() = 0; // Get all next beam_index sequences in one continuous block (to pass to CUDA) virtual int GetSequenceLength() const = 0; + virtual int GetMaxLength() const = 0; }; struct ILogitsProcessorList { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index 723c271897a78..ecad146da6777 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -36,6 +36,10 @@ int Sequences::GetSequenceLength() const { return current_length_; } +int Sequences::GetMaxLength() const { + return max_length_; +} + #ifdef DEBUG_GENERATION void Sequences::PrintSequences(const IConsoleDumper* dumper) const { for (int i = 0; i < batch_beam_size_; i++) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 99c9474a2ca4d..53b352a9b56f3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -25,6 +25,9 @@ class Sequences : public ISequences { // Returns current sequence length. int GetSequenceLength() const override; + // Returns max sequence length. + int GetMaxLength() const override; + #ifdef DEBUG_GENERATION // Print the sequences to StdOut in debug mode void PrintSequences(const IConsoleDumper* dumper) const; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 4d61ce71c69be..37e225f8d8910 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -156,7 +156,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds( int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len, - bool need_cache_indir) { + bool need_cache_indir, + bool use_cuda) { ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); // Allocate subgraph inputs from same device as inputs of encoder subgraph. 
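// Editorial sketch (not part of the diff): the hunk below reworks the
// sequence-as-input_ids copy in CreateInitialFeeds. Using the patch's own
// names, the intended layout is roughly:
//
//   sequences (device buffer from GetCurrentDeviceSequences()):
//     beam i starts at offset i * sequences.GetMaxLength(), padded out to max length
//   input_ids (device buffer being filled):
//     beam i starts at offset i * seq_size, tightly packed
//
//   // CUDA path: one device-to-device copy of seq_size ids per beam
//   for (int i = 0; i < batch_beam_size; ++i) {
//     copy(sequences + i * max_length, input_ids + i * seq_size, seq_size);  // copy() is a hypothetical helper
//   }
//
// On CPU the sequences still live in host memory, so the existing staging
// memcpy into seq_copy followed by a single hostToDevice copy is kept.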
@@ -182,19 +183,34 @@ Status T5DecoderSubgraph::CreateInitialFeeds( stream, DeviceCopyDirection::hostToDevice)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - long long seq_index = (long long)i * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + if (use_cuda) { + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + for (int i = 0; i < batch_beam_size; i++) { + long long batch_beam_stride = (long long)i * sequences.GetMaxLength(); + int seq_size = sequences.GetSequenceLength(); + gsl::span sequence = sequences_buffer.subspan(batch_beam_stride, seq_size); + gsl::span temp_input(input_ids_data + i * seq_size, seq_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + sequence, + stream, + DeviceCopyDirection::deviceToDevice)); + } + } else { + for (int i = 0; i < batch_beam_size; i++) { + gsl::span sequence = sequences.GetSequence(i); + const int32_t* sequence_data = sequence.data(); + long long seq_index = (long long)i * cur_len; + memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + } + gsl::span temp_input(input_ids_data, total_size); + gsl::span temp_sequence(seq_copy_ptr, total_size); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + temp_input, + temp_sequence, + stream, + DeviceCopyDirection::hostToDevice)); } - gsl::span temp_input(input_ids_data, total_size); - gsl::span temp_sequence(seq_copy_ptr, total_size); - ORT_RETURN_IF_ERROR(device_copy_int32_func( - temp_input, - temp_sequence, - stream, - DeviceCopyDirection::hostToDevice)); } // The ordering is the same as used in Setup. diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h index 83dae49c7dcbd..a72ce37a93aba 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -48,7 +48,8 @@ class T5DecoderSubgraph : public Subgraph { int cur_len, transformers::Sequences& sequences, int past_present_share_buffer_max_seq_len = -1, - bool need_cache_indir = false); + bool need_cache_indir = false, + bool use_cuda = false); Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 7adc2fe0a67ea..ff327da805344 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -1264,16 +1264,14 @@ Status UpdateDecoderFeeds( CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), cudaMemcpyHostToDevice, cuda_stream)); } else { - for (int i = 0; i < batch_beam_size; i++) { - gsl::span sequence = sequences.GetSequence(i); - const int32_t* sequence_data = sequence.data(); - CUDA_RETURN_IF_ERROR( - cudaMemcpyAsync(input_ids_data + static_cast(i) * current_length, - sequence_data, - current_length * sizeof(int32_t), - cudaMemcpyHostToDevice, - cuda_stream)); - } + // We expect sequences to point directly to device memory + int max_length = sequences.GetMaxLength(); + auto sequences_buffer = sequences.GetCurrentDeviceSequences(); + CUDA_RETURN_IF_ERROR( + cudaMemcpy2DAsync(input_ids_data, current_length * sizeof(int32_t), + sequences_buffer.data(), max_length * 
sizeof(int32_t), + current_length * sizeof(int32_t), batch_beam_size, + cudaMemcpyDeviceToDevice, cuda_stream)); } next_inputs[0] = input_ids; diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 6ce9f5de68f11..de0d0747999c9 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -7,6 +7,8 @@ #include "core/common/gsl.h" #include "core/session/onnxruntime_cxx_api.h" #include "test/common/cuda_op_test_utils.h" +#include "test/providers/model_tester.h" +#include "test/util/include/current_test_name.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" @@ -388,5 +390,15 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { } } +TEST(BeamSearchTest, T5WithSequenceInputIds) { + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/tiny_t5_with_sequence_input_ids.onnx")); + tester.AddInput("encoder_input_ids", {1, 5}, {16, 9, 16, 14, 15}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 1, 6, 1, 6, 1, 6, 1, 6, 1, 2, 1, 6, 1, 6, 1, 6, 1, 1, 6, 2, 1, 6, 1, 6, 1, 1, 6, 1, 6}); +#ifdef USE_CUDA + tester.ConfigEp(DefaultCudaExecutionProvider()); +#endif + tester.RunWithConfig(); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/tiny_t5_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/tiny_t5_with_sequence_input_ids.onnx new file mode 100644 index 0000000000000000000000000000000000000000..f168df3246baf9c41eee656539d9d7d563e1d29d GIT binary patch literal 7633 zcmeHMYgALm7A6oN44`m$Qy<7f1VK>ZIxEks#URC*3$Z_4-hYCQSo(80)ZgGwyxekcgc@1C)xX(J$v?< z`Mw#R(b(?;$t{0k&d5@;L`;rAEX&SS2*e^eAyQ$!K*Gq<6&Y+>EK|3QWVup73L})t ziAJ+uhovB!kqIS=0?J6v%+6&&8M(c&(pVs8ggL1hjwFPAh*7oO8*GT7h*9-B>bMO_ zAhw-P0!Es`h(uypx?m0?PR~#fqpr9pEk;?n3N~tc^Iq1 zMkbIegbGGZ*=zk;AZ3IyM>35O<)$)`Y}?(KaE1AdJcv`lF=WyD41&A_kRoz6b)CWL zv291aV3CLnCmP>VBt@u5%@D}Na~Vhdm<(Yy<4%YkBNN%OUF(v`lwF#bO|d|({7EoJ ztjJ&i1u7>;-nc9oJ1&{R-G%7q2_?CVyB$x@%y6v9bM+rbmmAXYQw4`h!iYw9q;F)( zSff)%1Gbw~p`s%vV!5r3GO!Vvqad6(%2XnjvGc4RuR+J7qd}xl5t%C?ZltFHWgy8) z7c1oUhVra51zUTLhA|xt64(P;GZcz+WK)pI6xbWHAA%f4l^93EX&nuMeLI_UI;%2e zZ>sv&8z7iybV6BmcPLQg2xao@EIA``G@7n%yPFekHj^U_(h1TH;sog_>m`)*=!&w@ z?D)^aOd_V)Lb-x{l_Y7(5fkmfP{!G878sdAY13=6bd{4eh4k-6QjMGR6%%EuMpM~o z#iLBrZ=PJiWV?T=&$Hz5%s7to^zSGB4#N#nu+?RDz{KP?;L++Qa3w1VbKE>|?L5R# zuTuU>>mgWw%nA;>1mb5*9^$r<*bwH8LHRTmes6^VIwc_bYBX-H8;$Yi6Y!(`S>71M!LFOx$9JAgShJlaPy7ON*dvK^_*xIH29oNGzvj!o|nASb28@ zo(ri0o%7*1%y=_+KQ6-AB~Rg5Y%Z1=harFc7)a&a0AJVN`6F}o!G6Q*^aI^$*!S_n zz~i;HcvMmj6-KT&c6bzCaPdVsHw_0_`ryAUc7oq!5oR8`1&7)=^wVWiv8gc-JzB~k ztnLoYJa&Ntm1bz(MB${9tLWM-#^}Jc!pnM=INoA6tV?UAPwpR%bwLfV`M`f*<&j_D zi^xHEi}@Ia*k)jWNgF6Q`Oxgt4{r*i@!EYOjNe=W)#7;&Qe_8I$#vN9vk2q1^uwnc zoFRO46kf%RkY+IfoeG~okh2rETtIwwLW=E$X;{Ad4pc3$!6mKbV2jgnX910O=NRBI zNWcT8zWA^$5n5JkgEn{{o`W-Ow56E#NxuNDp*Gk*eKkGlx&_L1i{N1x1I6n-;F@1C zm@jPr^L3`!k}AZJ#}q^@JTZ5fKb8w=u-cuB7O7Y1Q*$RG1e>C5lNau~ z8;Zl17^Bbkm!R}yAzr&19e7rs3&-D+U{bw4eRJV?*eUnLt8fgwS90h`Qy$Kmv7Mgv z<3OxGl7Yh}`QoX9cq}LUcO;v*E+=RNDX)Jsg9TgX6Gn z=1e?uehE#lJ`Qc=_u!hh8)gPmxHPN?EJNd=g*yiyI84MxK~nTLipDWz)6s58KP>WT z0(tO7`r~je9y#_1;?C}XkkqSC+CKp6PfKx^yd2ISn1UX&Lh*8-6>f7#!@9$9IJn#n zgK;uu?MlFa(Ej*A?)&_P2^Zkr)oiFPUIJUcvc@)tDt=2#2;Tj~0gE=pLY%l3lD_*1 z9#^e{!)K}>DKP>nH%-Bt1qHYzb24sp{Q+iDGobuzC_caIMNdmtVDs8g{AZ09%=ym- z&|7K=4QI?i*LgYAlmPgat)NqX^o9?s7eY>LG=$qlpyM}{kkLL3^25hrMKOb0SA;;^ z4+(fOAQR)qCStn_A8OVn;bpHvXq`M7b%jUZESvYmt8T%;vS3_(J`KkmEe?#Un24(q zM`E?%96W8&AKMcwaQxfX==)brfcfK9P{YrInfK?SlhqVDc})y7HrxQp@;5l8>kRR* 
[base85-encoded binary patch data for tiny_t5_with_sequence_input_ids.onnx elided]
Q#SviiYaZTNl#fC0TEiE|-@=WD3(@WICTLh(gcXYh zVc-WhVE;FxaA2(iM%(ye^F|68@+tjvQ#2+`cSLwtgwm=sn9kHfP1IP7t$D*aIP)xP(AWjx&OI@C>M7#6*)05L?Q+W1VGpfDUri$;lEX5oXo4;k$$SJFEuAiE?4f5FGuQgblik+Vy1zl+8tjND}_%w=?~Rsp0!Us zn;y%Q2}E7!c*(xcByA|ojvO`G$`(sviHcRzYG}3*=U}D%=%A@`P}nyk z$m-_FT?Duf${*f0j!kbIRi(d#{cA6c?uzNuYs_rrV*G++Q<3dJic(7E3`QuD3xqOJ zhqmi3`^_d{bq9>Rtz?Z{Omg-7T3pcS$tZ-l3E!=*4?;1T|R~^jX?Sje9 zMff#$FbX%kgUxrR(B@w^!=*c9Fl&!QvwmtvjpTzC!5Ex8U?>XYzvMF~vsh{ww( zPJyUV2Ul$o!TjVkux73k9vm2gjYD*BM{FD}S`d!SPo-!%VGRD1o=5+Duo8Ul=b_#s zCtTAOj4nIHcze|f$iI+@Wke5!cJ=&thiE9I4uQj-D!S>RKbrEqAwPH|dK~k|qtd-# zKzz_^sX5L#mj|ndwm=9NLb+)$etG6O80_*xJ3l?lkvij=?d7m>?|8gCE(-l;w9{d~ zt$`7jqVT%;9+>P|h)u5h=$HYg;CKoTOgHpLeQF#;d+I^J0zWX`G(#IORJDezGM3hBXts7jzbu1P;Pk`**|k#xZEcX{F7snB%#_yWqBCD_jf|Vfc6R zq4o3*X!?CT_^dWS-u*aiUA+(H-S$GyE6eGj$qBSozsIzXQz}&7^umN)5ja~h5@X-K z3Hz#3pmg|jEZ-4@d^w~B9Og(d_@g2`>voX#a@#{(-9Q`RDzx~k((>=G+ z3Y%%PA-@ESO7Furv3@vO@{I3a6N}FO;$mZC3|42i!19V3u;W)j*}ILf_3d1gOHx_! z|2ZAfYukQ(tnm6+;q|e?>tlu2#|nStSYe!&tDV1mDDd|kFL)Ad8i`o5ryW!;y~;@$ z+9itz5PMeBTKBe53E;o;)FNdJ(Lx+42Qr8az`j=ix@-U(m3sCkS~=pM@Qfh6`QR%b zB}5W!eDGD_luF)$m@rxL7xcS1dnw4?7)k_+f^0^0H>I@87K*y Date: Tue, 26 Nov 2024 10:47:55 +0000 Subject: [PATCH 2/8] Fixing test --- .../test/contrib_ops/beam_search_test.cc | 6 ++++-- .../tiny_t5_with_sequence_input_ids.onnx | Bin 7633 -> 6754 bytes 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 71906c24236e4..29dee2daeac61 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -398,8 +398,10 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { TEST(BeamSearchTest, T5WithSequenceInputIds) { ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/tiny_t5_with_sequence_input_ids.onnx")); - tester.AddInput("encoder_input_ids", {1, 5}, {16, 9, 16, 14, 15}); - tester.AddOutput("sequences", {1, 3, 10}, {2, 1, 6, 1, 6, 1, 6, 1, 6, 1, 2, 1, 6, 1, 6, 1, 6, 1, 1, 6, 2, 1, 6, 1, 6, 1, 1, 6, 1, 6}); + tester.AddInput("encoder_input_ids", {1, 5}, {6, 4, 12, 5, 1}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 18, 3, 8, 8, 8, 8, 8, 8, 8, + 2, 18, 17, 12, 18, 3, 8, 8, 8, 8, + 2, 18, 17, 12, 18, 17, 12, 18, 3, 8}); #ifdef USE_CUDA tester.ConfigEp(DefaultCudaExecutionProvider()); #endif diff --git a/onnxruntime/test/testdata/tiny_t5_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/tiny_t5_with_sequence_input_ids.onnx index f168df3246baf9c41eee656539d9d7d563e1d29d..5a5c302914890e5f1c381eae3676c985e85cc0c6 100644 GIT binary patch literal 6754 zcmds6d0bOR*M_i#D+pXxH$;4EAu1w*2$DGgcLWy%t5{3}2?9a_Nzfv#qT+%Z6$PuJ zEUs8;QPEZ*bE0Auwc=K_E|gZO`%)EAtM$EX6$SKt`}U9T_vN3Pd+(WN&YXMB%=66f z8SgT8CdZAjkjWzyQkha5EssxBi=(9~##AknizO;CcTp*njHyB{Yj$g$a@1N|unwbtqG@hN=jKQl@fb{A7~2aWaWAB1&-AmNAmb>fLdx zv}FjxHr|P4ap5wlG+I7I#TftPC~MvHIx1Q!<$PADC29`RaF|3LB~vm1jCIrC-Y{T0 zw)$8cCzHq>nIM@oF+vu|{TH-j^d+ClR3hD6T_YxtX}vq<9cZj&Q%_BvsuJTI^?apL z#;XlD*0jkqT%wMMax@qhC5e{_7)DPfm)dcj@R$VVO&i8pC5w%07`-@L7Rkj`tX4|o z@d}kptg34+mMg^daIC}oOdSmeDdZ6nb;IDJRdzacUG`;K8wtZiMsuH79IsTwu!gaU zDbZ>!`hgO4U}7xuaqCUq>j2xaZ5JJlf@Jjr;=aZ^%vcK*mnyN`;OpWtRhA@nbuj#NLQp+_i;zR6V9?Z+%!pSqU^nZfs9$aM5S)#0n^ecmW}61 zKqgn$`&jR+Ym4~v>f%3+v3L#F=mFE3TdY+RWP{Z^K-s2R6Rb1fTk-kkx_+(Hn}3SI zhRd_)9jzP-&X=By*?eUofg)3tic_c zY>ns4a9lXOOc>lb2%S% zVpJ;47JzBLfKuF*2YFsU!h@1b+!c@m36ppFXf87FT&a&vrkn8R?sMq=M_Sx@drst2ZaoeWF-Ea3Kl9GYEYA$)Q^ zS<9>K459APLUX7U9@m>olYR3vYeUxH@c2?;)>9pbbnXrEeFdbSPClu6_6ONM$_LhY 
[base85-encoded binary patch data for tiny_t5_with_sequence_input_ids.onnx (new "literal 6754" and previous "literal 7633" blobs) elided]
Q#SviiYaZTNl#fC0TEiE|-@=WD3(@WICTLh(gcXYh zVc-WhVE;FxaA2(iM%(ye^F|68@+tjvQ#2+`cSLwtgwm=sn9kHfP1IP7t$D*aIP)xP(AWjx&OI@C>M7#6*)05L?Q+W1VGpfDUri$;lEX5oXo4;k$$SJFEuAiE?4f5FGuQgblik+Vy1zl+8tjND}_%w=?~Rsp0!Us zn;y%Q2}E7!c*(xcByA|ojvO`G$`(sviHcRzYG}3*=U}D%=%A@`P}nyk z$m-_FT?Duf${*f0j!kbIRi(d#{cA6c?uzNuYs_rrV*G++Q<3dJic(7E3`QuD3xqOJ zhqmi3`^_d{bq9>Rtz?Z{Omg-7T3pcS$tZ-l3E!=*4?;1T|R~^jX?Sje9 zMff#$FbX%kgUxrR(B@w^!=*c9Fl&!QvwmtvjpTzC!5Ex8U?>XYzvMF~vsh{ww( zPJyUV2Ul$o!TjVkux73k9vm2gjYD*BM{FD}S`d!SPo-!%VGRD1o=5+Duo8Ul=b_#s zCtTAOj4nIHcze|f$iI+@Wke5!cJ=&thiE9I4uQj-D!S>RKbrEqAwPH|dK~k|qtd-# zKzz_^sX5L#mj|ndwm=9NLb+)$etG6O80_*xJ3l?lkvij=?d7m>?|8gCE(-l;w9{d~ zt$`7jqVT%;9+>P|h)u5h=$HYg;CKoTOgHpLeQF#;d+I^J0zWX`G(#IORJDezGM3hBXts7jzbu1P;Pk`**|k#xZEcX{F7snB%#_yWqBCD_jf|Vfc6R zq4o3*X!?CT_^dWS-u*aiUA+(H-S$GyE6eGj$qBSozsIzXQz}&7^umN)5ja~h5@X-K z3Hz#3pmg|jEZ-4@d^w~B9Og(d_@g2`>voX#a@#{(-9Q`RDzx~k((>=G+ z3Y%%PA-@ESO7Furv3@vO@{I3a6N}FO;$mZC3|42i!19V3u;W)j*}ILf_3d1gOHx_! z|2ZAfYukQ(tnm6+;q|e?>tlu2#|nStSYe!&tDV1mDDd|kFL)Ad8i`o5ryW!;y~;@$ z+9itz5PMeBTKBe53E;o;)FNdJ(Lx+42Qr8az`j=ix@-U(m3sCkS~=pM@Qfh6`QR%b zB}5W!eDGD_luF)$m@rxL7xcS1dnw4?7)k_+f^0^0H>I@87K*y Date: Tue, 26 Nov 2024 10:56:36 +0000 Subject: [PATCH 3/8] Fix linting --- onnxruntime/test/contrib_ops/beam_search_test.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 29dee2daeac61..9c04b54bda77d 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -399,9 +399,7 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { TEST(BeamSearchTest, T5WithSequenceInputIds) { ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/tiny_t5_with_sequence_input_ids.onnx")); tester.AddInput("encoder_input_ids", {1, 5}, {6, 4, 12, 5, 1}); - tester.AddOutput("sequences", {1, 3, 10}, {2, 18, 3, 8, 8, 8, 8, 8, 8, 8, - 2, 18, 17, 12, 18, 3, 8, 8, 8, 8, - 2, 18, 17, 12, 18, 17, 12, 18, 3, 8}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 18, 3, 8, 8, 8, 8, 8, 8, 8, 2, 18, 17, 12, 18, 3, 8, 8, 8, 8, 2, 18, 17, 12, 18, 17, 12, 18, 3, 8}); #ifdef USE_CUDA tester.ConfigEp(DefaultCudaExecutionProvider()); #endif From c55257fb076db5f0ff340770aa701774b14b5090 Mon Sep 17 00:00:00 2001 From: amancini-N Date: Thu, 28 Nov 2024 18:38:00 +0000 Subject: [PATCH 4/8] Changes: - Fixing initialization of encoder_hidden_states feed in T5 decoder - Re-generating data for T5 BeamSearch test --- .../cpu/transformers/subgraph_t5_decoder.cc | 12 +++++++----- onnxruntime/test/contrib_ops/beam_search_test.cc | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index fc638a84fb972..9094adf484ee2 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -172,8 +172,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); - size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size * sizeof(int)); - auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size, false, stream); + size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size); + size_t 
total_size_bytes = total_size * sizeof(int); + auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); int* seq_copy_ptr = seq_copy.get(); if (!use_sequence_as_input_ids_) { @@ -197,11 +198,12 @@ Status T5DecoderSubgraph::CreateInitialFeeds( DeviceCopyDirection::deviceToDevice)); } } else { + const size_t cur_len_bytes = cur_len * sizeof(int); for (int i = 0; i < batch_beam_size; i++) { gsl::span sequence = sequences.GetSequence(i); const int32_t* sequence_data = sequence.data(); long long seq_index = (long long)i * cur_len; - memcpy(seq_copy_ptr + seq_index, sequence_data, total_size); + memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes); } gsl::span temp_input(input_ids_data, total_size); gsl::span temp_sequence(seq_copy_ptr, total_size); @@ -246,7 +248,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } else { ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream, @@ -254,7 +256,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( num_beam, allocator, expanded_hidden_states, - true, + false, 0 /*max_sequence_length*/)); } decoder_feeds.push_back(expanded_hidden_states); diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 9c04b54bda77d..cb2abd09bd418 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -398,8 +398,8 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { TEST(BeamSearchTest, T5WithSequenceInputIds) { ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/tiny_t5_with_sequence_input_ids.onnx")); - tester.AddInput("encoder_input_ids", {1, 5}, {6, 4, 12, 5, 1}); - tester.AddOutput("sequences", {1, 3, 10}, {2, 18, 3, 8, 8, 8, 8, 8, 8, 8, 2, 18, 17, 12, 18, 3, 8, 8, 8, 8, 2, 18, 17, 12, 18, 17, 12, 18, 3, 8}); + tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); + tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13}); #ifdef USE_CUDA tester.ConfigEp(DefaultCudaExecutionProvider()); #endif From f38c9d33fb879c94f8f6cefc59274c8031a5752d Mon Sep 17 00:00:00 2001 From: amancini-N Date: Tue, 3 Dec 2024 10:54:03 +0000 Subject: [PATCH 5/8] Fix failing variable type on some builds --- onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 9094adf484ee2..54d4a731f2ad6 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -187,7 +187,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( if (use_cuda) { auto sequences_buffer = sequences.GetCurrentDeviceSequences(); for (int i = 0; i < batch_beam_size; i++) { - long long batch_beam_stride = (long long)i * sequences.GetMaxLength(); + size_t batch_beam_stride = (size_t)i * sequences.GetMaxLength(); int seq_size = sequences.GetSequenceLength(); gsl::span sequence = sequences_buffer.subspan(batch_beam_stride, seq_size); gsl::span temp_input(input_ids_data + i * seq_size, seq_size); From 096166912d4239c491cf899b983b631a46e3df8e Mon Sep 17 00:00:00 2001 From: amancini-N Date: Fri, 6 Dec 2024 11:44:14 +0000 Subject: [PATCH 6/8] Ensure test is executed on cpu and cuda 
EPs --- onnxruntime/test/contrib_ops/beam_search_test.cc | 5 +++-- ...s.onnx => dummy_t5_with_sequence_input_ids.onnx} | Bin 2 files changed, 3 insertions(+), 2 deletions(-) rename onnxruntime/test/testdata/{tiny_t5_with_sequence_input_ids.onnx => dummy_t5_with_sequence_input_ids.onnx} (100%) diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index cb2abd09bd418..9f0dd909c18be 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -396,8 +396,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { } } -TEST(BeamSearchTest, T5WithSequenceInputIds) { - ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/tiny_t5_with_sequence_input_ids.onnx")); +TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { + ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx")); + tester.ConfigEp(DefaultCpuExecutionProvider()); tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13}); #ifdef USE_CUDA diff --git a/onnxruntime/test/testdata/tiny_t5_with_sequence_input_ids.onnx b/onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx similarity index 100% rename from onnxruntime/test/testdata/tiny_t5_with_sequence_input_ids.onnx rename to onnxruntime/test/testdata/dummy_t5_with_sequence_input_ids.onnx From a9d5f7d40023ee4c508dba395b1f9d7131b27f84 Mon Sep 17 00:00:00 2001 From: amancini-N Date: Mon, 9 Dec 2024 10:34:44 +0000 Subject: [PATCH 7/8] Skip tests on DML --- onnxruntime/test/contrib_ops/beam_search_test.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/test/contrib_ops/beam_search_test.cc b/onnxruntime/test/contrib_ops/beam_search_test.cc index 9f0dd909c18be..76c02c565f632 100644 --- a/onnxruntime/test/contrib_ops/beam_search_test.cc +++ b/onnxruntime/test/contrib_ops/beam_search_test.cc @@ -397,6 +397,9 @@ TEST(BeamSearchTest, GptBeamSearchFp16_VocabPadded) { } TEST(BeamSearchTest, DummyT5WithSequenceInputIds) { +#if defined(USE_CUDA) && defined(USE_DML) + SKIP_CUDA_TEST_WITH_DML; +#endif ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx")); tester.ConfigEp(DefaultCpuExecutionProvider()); tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8}); From 974ed53e3a96311261390764198e43d86c54e7df Mon Sep 17 00:00:00 2001 From: amancini-N Date: Tue, 10 Dec 2024 10:49:24 +0000 Subject: [PATCH 8/8] Applying suggestions --- .../contrib_ops/cpu/transformers/subgraph_t5_decoder.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index 0feab9aaacd9a..f4e7173c917c1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -172,7 +172,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); AllocatorPtr buffer_allocator = std::make_shared(); - size_t total_size = static_cast(static_cast(cur_len) * batch_beam_size); + size_t total_size = static_cast(cur_len) * static_cast(batch_beam_size); size_t total_size_bytes = total_size * 
sizeof(int); auto seq_copy = IAllocator::MakeUniquePtr(buffer_allocator, total_size_bytes, false, stream); int* seq_copy_ptr = seq_copy.get(); @@ -187,10 +187,10 @@ Status T5DecoderSubgraph::CreateInitialFeeds( if (use_cuda) { auto sequences_buffer = sequences.GetCurrentDeviceSequences(); for (int i = 0; i < batch_beam_size; i++) { - size_t batch_beam_stride = (size_t)i * sequences.GetMaxLength(); + size_t batch_beam_stride = static_cast(i) * static_cast(sequences.GetMaxLength()); int seq_size = sequences.GetSequenceLength(); gsl::span sequence = sequences_buffer.subspan(batch_beam_stride, seq_size); - gsl::span temp_input(input_ids_data + i * seq_size, seq_size); + gsl::span temp_input(input_ids_data + static_cast(i) * seq_size, seq_size); ORT_RETURN_IF_ERROR(device_copy_int32_func( temp_input, sequence, @@ -202,7 +202,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( for (int i = 0; i < batch_beam_size; i++) { gsl::span sequence = sequences.GetSequence(i); const int32_t* sequence_data = sequence.data(); - long long seq_index = (long long)i * cur_len; + ptrdiff_t seq_index = static_cast(i) * cur_len; memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes); } gsl::span temp_input(input_ids_data, total_size);
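For reference, a minimal self-contained sketch of the strided device-to-device copy that patch 1 introduces in UpdateDecoderFeeds (generation_device_helper.cc). The cudaMemcpy2DAsync call and its pitches mirror the diff above; the surrounding harness, buffer names, and sizes are made up for illustration and are not taken from the ONNX Runtime sources.

// Each beam's sequence is stored with a max_length stride in the device-side
// sequences buffer; one 2D copy packs the first current_length tokens of every
// beam contiguously into input_ids.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  const int batch_beam_size = 2, max_length = 8, current_length = 3;

  int32_t h_seq[batch_beam_size * max_length] = {
      1, 2, 3, 0, 0, 0, 0, 0,   // beam 0, padded to max_length
      4, 5, 6, 0, 0, 0, 0, 0};  // beam 1, padded to max_length
  int32_t h_ids[batch_beam_size * current_length] = {};

  int32_t *d_seq = nullptr, *d_ids = nullptr;
  cudaMalloc(&d_seq, sizeof(h_seq));
  cudaMalloc(&d_ids, sizeof(h_ids));
  cudaMemcpy(d_seq, h_seq, sizeof(h_seq), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // width = current_length tokens (in bytes) per row, height = batch_beam_size rows,
  // source pitch = max_length tokens, destination pitch = current_length tokens.
  cudaMemcpy2DAsync(d_ids, current_length * sizeof(int32_t),
                    d_seq, max_length * sizeof(int32_t),
                    current_length * sizeof(int32_t), batch_beam_size,
                    cudaMemcpyDeviceToDevice, stream);
  cudaStreamSynchronize(stream);

  cudaMemcpy(h_ids, d_ids, sizeof(h_ids), cudaMemcpyDeviceToHost);
  for (int i = 0; i < batch_beam_size * current_length; ++i)
    printf("%d ", h_ids[i]);  // expected: 1 2 3 4 5 6
  printf("\n");

  cudaFree(d_seq);
  cudaFree(d_ids);
  cudaStreamDestroy(stream);
  return 0;
}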