[CUDA EP] Fix BeamSearch on T5 with sequence_as_input_ids (#20667) #20668

Merged · 10 commits · Dec 11, 2024
26 changes: 17 additions & 9 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -258,7 +258,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
current_length,
cpu_state.sequences,
parameters->max_length,
decoder_subgraph_.has_decoder_masked_attention_));
decoder_subgraph_.has_decoder_masked_attention_,
this->cuda_device_prop_ != nullptr));

if (decoder_subgraph_.past_present_share_buffer_) {
decoder_fetches.reserve(static_cast<size_t>(decoder_subgraph_.GetFirstPresentOutputIndex()) +
@@ -302,17 +303,24 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
auto cur_len = std::to_string(current_length);
dumper->Print("***CurrentLength", cur_len, true);

for (int i = 0; i <= decoder_subgraph_.GetFirstPastInputIndex(); i++) {
for (int i = 0; i < decoder_subgraph_.GetFirstPastInputIndex(); i++) {
dumper->Print("decoder_feeds", i, true);
dumper->Print("", decoder_feeds[i]);
}
auto offset = decoder_subgraph_.GetFirstPastInputIndex() + 4 * decoder_subgraph_.num_layers;
dumper->Print("past_sequence_length", offset, true);
dumper->Print("", decoder_feeds[offset]);
dumper->Print("beam_width", offset + 1, true);
dumper->Print("", decoder_feeds[offset + 1]);
dumper->Print("cache_redir", offset + 2, true);
dumper->Print("", decoder_feeds[offset + 2]);
for (int i = 0; i < decoder_subgraph_.num_layers; i++) {
int self_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * i;
int self_value_idx = self_key_idx + 1;
dumper->Print("past_key_self", i, true);
dumper->Print("", decoder_feeds[self_key_idx]);
dumper->Print("past_value_self", i + 1, true);
dumper->Print("", decoder_feeds[self_value_idx]);
int cross_key_idx = decoder_subgraph_.GetFirstPastInputIndex() + 2 * decoder_subgraph_.num_layers + 2 * i;
int cross_value_idx = cross_key_idx + 1;
dumper->Print("past_key_cross", i, true);
dumper->Print("", decoder_feeds[cross_key_idx]);
dumper->Print("past_value_cross", i, true);
dumper->Print("", decoder_feeds[cross_value_idx]);
}
#endif

#ifdef DEBUG_NODE_INPUTS_OUTPUTS
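
The rewritten DEBUG_GENERATION loop above assumes the decoder feeds hold 2 * num_layers self-attention past K/V tensors followed by 2 * num_layers cross-attention K/V tensors, starting at GetFirstPastInputIndex(). A minimal standalone sketch of that indexing, with made-up values standing in for GetFirstPastInputIndex() and num_layers (not ORT code):

#include <cstdio>

int main() {
  // Hypothetical values standing in for GetFirstPastInputIndex() and num_layers.
  const int first_past_idx = 3;
  const int num_layers = 2;
  for (int i = 0; i < num_layers; ++i) {
    const int self_key_idx = first_past_idx + 2 * i;                    // past_key_self, layer i
    const int cross_key_idx = first_past_idx + 2 * num_layers + 2 * i;  // past_key_cross, layer i
    std::printf("layer %d: self K/V at feeds %d/%d, cross K/V at feeds %d/%d\n",
                i, self_key_idx, self_key_idx + 1, cross_key_idx, cross_key_idx + 1);
  }
  return 0;
}
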
@@ -100,6 +100,7 @@ struct ISequences {
virtual gsl::span<const int32_t> GetCurrentDeviceSequences() const = 0; // Get all current beam_index sequences in one continuous block (to pass to CUDA)
virtual gsl::span<int32_t> GetNextDeviceSequences() = 0; // Get all next beam_index sequences in one continuous block (to pass to CUDA)
virtual int GetSequenceLength() const = 0;
virtual int GetMaxLength() const = 0;
};

struct ILogitsProcessorList {
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/sequences.cc
@@ -36,6 +36,10 @@ int Sequences::GetSequenceLength() const {
return current_length_;
}

int Sequences::GetMaxLength() const {
return max_length_;
}

#ifdef DEBUG_GENERATION
void Sequences::PrintSequences(const IConsoleDumper* dumper) const {
for (int i = 0; i < batch_beam_size_; i++) {
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/sequences.h
@@ -25,6 +25,9 @@ class Sequences : public ISequences {
// Returns current sequence length.
int GetSequenceLength() const override;

// Returns max sequence length.
int GetMaxLength() const override;

#ifdef DEBUG_GENERATION
// Print the sequences to StdOut in debug mode
void PrintSequences(const IConsoleDumper* dumper) const;
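
GetMaxLength() is added because the device-side buffer returned by GetCurrentDeviceSequences() is laid out row-major as [batch_beam_size, max_length]: code that slices out one beam's live tokens needs both the current length and the row stride, as the changes below use. A tiny sketch of that indexing (all values are made up; beam_offset is a hypothetical helper, not part of the ORT API):

#include <cstdio>

// Hypothetical helper: rows of the device sequence buffer are max_length apart,
// but only the first sequence_length entries of each row hold live tokens.
int beam_offset(int beam, int max_length) { return beam * max_length; }

int main() {
  const int max_length = 10;       // would come from GetMaxLength()
  const int sequence_length = 4;   // would come from GetSequenceLength()
  for (int beam = 0; beam < 3; ++beam) {
    std::printf("beam %d: valid tokens at [%d, %d)\n", beam,
                beam_offset(beam, max_length),
                beam_offset(beam, max_length) + sequence_length);
  }
  return 0;
}
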
52 changes: 35 additions & 17 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -156,7 +156,8 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
int cur_len,
transformers::Sequences& sequences,
int past_present_share_buffer_max_seq_len,
bool need_cache_indir) {
bool need_cache_indir,
bool use_cuda) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");

// Allocate subgraph inputs from same device as inputs of encoder subgraph.
@@ -171,8 +172,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
AllocatorPtr buffer_allocator = std::make_shared<onnxruntime::CPUAllocator>();
size_t total_size = static_cast<size_t>(static_cast<long long>(cur_len) * batch_beam_size * sizeof(int));
auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size, false, stream);
size_t total_size = static_cast<size_t>(cur_len) * static_cast<size_t>(batch_beam_size);
size_t total_size_bytes = total_size * sizeof(int);
auto seq_copy = IAllocator::MakeUniquePtr<int>(buffer_allocator, total_size_bytes, false, stream);
int* seq_copy_ptr = seq_copy.get();

if (!use_sequence_as_input_ids_) {
@@ -182,19 +184,35 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
stream,
DeviceCopyDirection::hostToDevice));
} else {
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const int32_t> sequence = sequences.GetSequence(i);
const int32_t* sequence_data = sequence.data();
long long seq_index = (long long)i * cur_len;
memcpy(seq_copy_ptr + seq_index, sequence_data, total_size);
if (use_cuda) {
auto sequences_buffer = sequences.GetCurrentDeviceSequences();
for (int i = 0; i < batch_beam_size; i++) {
size_t batch_beam_stride = static_cast<size_t>(i) * static_cast<size_t>(sequences.GetMaxLength());
int seq_size = sequences.GetSequenceLength();
gsl::span<const int32_t> sequence = sequences_buffer.subspan(batch_beam_stride, seq_size);
gsl::span<int> temp_input(input_ids_data + static_cast<ptrdiff_t>(i) * seq_size, seq_size);
ORT_RETURN_IF_ERROR(device_copy_int32_func(
temp_input,
sequence,
stream,
DeviceCopyDirection::deviceToDevice));
}
} else {
const size_t cur_len_bytes = cur_len * sizeof(int);
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const int32_t> sequence = sequences.GetSequence(i);
const int32_t* sequence_data = sequence.data();
ptrdiff_t seq_index = static_cast<ptrdiff_t>(i) * cur_len;
memcpy(seq_copy_ptr + seq_index, sequence_data, cur_len_bytes);
}
gsl::span<int> temp_input(input_ids_data, total_size);
gsl::span<int> temp_sequence(seq_copy_ptr, total_size);
ORT_RETURN_IF_ERROR(device_copy_int32_func(
temp_input,
temp_sequence,
stream,
DeviceCopyDirection::hostToDevice));
}
gsl::span<int> temp_input(input_ids_data, total_size);
gsl::span<int> temp_sequence(seq_copy_ptr, total_size);
ORT_RETURN_IF_ERROR(device_copy_int32_func(
temp_input,
temp_sequence,
stream,
DeviceCopyDirection::hostToDevice));
}

// The ordering is the same as used in Setup.
@@ -230,15 +248,15 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
num_beam,
allocator,
expanded_hidden_states,
true,
false,
0 /*max_sequence_length*/));
} else {
ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
encoder_fetches[j],
num_beam,
allocator,
expanded_hidden_states,
true,
false,
0 /*max_sequence_length*/));
}
decoder_feeds.push_back(expanded_hidden_states);
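
Two things change in CreateInitialFeeds above when sequences are used as input_ids: on CUDA, each beam's tokens are now copied device-to-device straight out of the [batch_beam_size, max_length] device sequence buffer (row stride GetMaxLength()), and the CPU fallback copies cur_len * sizeof(int) bytes per beam rather than the whole-buffer byte count used before. The repacking itself is a strided row copy; here is a CPU-only sketch with made-up dimensions (not ORT code):

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  const int batch_beam_size = 2, max_length = 6, cur_len = 3;  // made-up sizes

  // Source: [batch_beam_size, max_length]; only the first cur_len tokens per row are live.
  std::vector<int> sequences(batch_beam_size * max_length);
  for (size_t i = 0; i < sequences.size(); ++i) sequences[i] = static_cast<int>(i);

  // Destination: packed [batch_beam_size, cur_len] input_ids buffer.
  std::vector<int> input_ids(static_cast<size_t>(batch_beam_size) * cur_len);
  const size_t cur_len_bytes = static_cast<size_t>(cur_len) * sizeof(int);
  for (int i = 0; i < batch_beam_size; ++i) {
    std::memcpy(input_ids.data() + static_cast<ptrdiff_t>(i) * cur_len,
                sequences.data() + static_cast<ptrdiff_t>(i) * max_length,
                cur_len_bytes);
  }

  for (int v : input_ids) std::printf("%d ", v);  // expected: 0 1 2 6 7 8
  std::printf("\n");
  return 0;
}
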
@@ -48,7 +48,8 @@ class T5DecoderSubgraph : public Subgraph {
int cur_len,
transformers::Sequences& sequences,
int past_present_share_buffer_max_seq_len = -1,
bool need_cache_indir = false);
bool need_cache_indir = false,
bool use_cuda = false);

Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;
@@ -1264,16 +1264,14 @@ Status UpdateDecoderFeeds(
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream));
} else {
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const int32_t> sequence = sequences.GetSequence(i);
const int32_t* sequence_data = sequence.data();
CUDA_RETURN_IF_ERROR(
cudaMemcpyAsync(input_ids_data + static_cast<ptrdiff_t>(i) * current_length,
sequence_data,
current_length * sizeof(int32_t),
cudaMemcpyHostToDevice,
cuda_stream));
}
// We expect sequences to point directly to device memory
int max_length = sequences.GetMaxLength();
auto sequences_buffer = sequences.GetCurrentDeviceSequences();
CUDA_RETURN_IF_ERROR(
cudaMemcpy2DAsync(input_ids_data, current_length * sizeof(int32_t),
sequences_buffer.data(), max_length * sizeof(int32_t),
current_length * sizeof(int32_t), batch_beam_size,
cudaMemcpyDeviceToDevice, cuda_stream));
}
next_inputs[0] = input_ids;

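
The replaced loop above issued one host-to-device cudaMemcpyAsync per beam; since the sequences now live on the device, a single strided cudaMemcpy2DAsync moves current_length tokens per row, with source pitch max_length * sizeof(int32_t) and destination pitch current_length * sizeof(int32_t). A standalone sketch of the same copy with made-up sizes (not the ORT code; error handling omitted for brevity):

#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

int main() {
  const int batch_beam_size = 2, max_length = 8, current_length = 3;  // made-up sizes

  // Device sequence buffer: [batch_beam_size, max_length], row stride max_length.
  std::vector<int32_t> h_seq(batch_beam_size * max_length);
  for (size_t i = 0; i < h_seq.size(); ++i) h_seq[i] = static_cast<int32_t>(i);

  int32_t *d_seq = nullptr, *d_input_ids = nullptr;
  cudaMalloc(&d_seq, h_seq.size() * sizeof(int32_t));
  cudaMalloc(&d_input_ids, batch_beam_size * current_length * sizeof(int32_t));
  cudaMemcpy(d_seq, h_seq.data(), h_seq.size() * sizeof(int32_t), cudaMemcpyHostToDevice);

  // One strided copy: current_length tokens per row, source pitch max_length,
  // destination pitch current_length (packed), all on the device.
  cudaMemcpy2DAsync(d_input_ids, current_length * sizeof(int32_t),
                    d_seq, max_length * sizeof(int32_t),
                    current_length * sizeof(int32_t), batch_beam_size,
                    cudaMemcpyDeviceToDevice, /*stream=*/0);

  std::vector<int32_t> h_ids(batch_beam_size * current_length);
  cudaMemcpy(h_ids.data(), d_input_ids, h_ids.size() * sizeof(int32_t), cudaMemcpyDeviceToHost);
  for (int32_t v : h_ids) std::printf("%d ", v);  // expected: 0 1 2 8 9 10
  std::printf("\n");

  cudaFree(d_seq);
  cudaFree(d_input_ids);
  return 0;
}
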
14 changes: 14 additions & 0 deletions onnxruntime/test/contrib_ops/beam_search_test.cc
@@ -424,5 +424,19 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
tester.RunWithConfig();
}

TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8});
tester.AddOutput("sequences", {1, 3, 10}, {2, 19, 18, 3, 8, 8, 8, 8, 8, 8, 2, 19, 18, 3, 10, 19, 18, 3, 8, 8, 2, 19, 18, 15, 13, 13, 13, 13, 13, 13});
#ifdef USE_CUDA
tester.ConfigEp(DefaultCudaExecutionProvider());
#endif
tester.RunWithConfig();
}

} // namespace test
} // namespace onnxruntime
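
To exercise the new test locally, the usual route is a gtest filter on ORT's test binary, e.g. onnxruntime_test_all --gtest_filter=BeamSearchTest.DummyT5WithSequenceInputIds (binary name may vary with build configuration). If ModelTester behaves as in the neighboring DummyT5 tests, a CUDA build runs the model under both the CPU EP and the CUDA EP configured above.
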
Binary file not shown: the new test model testdata/dummy_t5_with_sequence_input_ids.onnx referenced by DummyT5WithSequenceInputIds above.