[ CB ] Fix streaming in case of empty outputs
iefode committed Jan 29, 2025
1 parent 5cbadd1 commit 7f71f12
Showing 6 changed files with 38 additions and 4 deletions.
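Every C++ streaming loop below gains the same guard: generation->read() (or handle->back()) can now return an empty map, because a finished sequence group may push an empty output (see the sequence_group.hpp hunk), and calling token.begin()->second on an empty std::unordered_map dereferences the end iterator, which is undefined behavior. Below is a minimal, self-contained sketch of the guarded pattern; GenerationOutput and put() are simplified stand-ins for the real openvino.genai types, not the actual definitions.

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

// Simplified stand-in for the real GenerationOutput type.
struct GenerationOutput {
    std::vector<int64_t> generated_ids;
};

// Stand-in for streamer_ptr->put(): returns true when streaming should stop.
bool put(int64_t token_id) {
    std::cout << token_id << ' ';
    return false;
}

void stream(const std::unordered_map<uint64_t, GenerationOutput>& token) {
    // Without this guard, token.begin()->second would dereference the end
    // iterator of an empty map, which is undefined behavior.
    if (token.empty()) {
        return;
    }
    for (const auto& gen_token : token.begin()->second.generated_ids) {
        if (put(gen_token)) {
            break;  // the real loops also drop the generation handle here
        }
    }
}

int main() {
    stream({});                                  // empty output: safely skipped
    stream({{0, GenerationOutput{{1, 2, 3}}}});  // regular output: tokens are streamed
    std::cout << '\n';
}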
3 changes: 3 additions & 0 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -466,6 +466,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
 
     if (generation->can_read()) {
         std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
+        if (token.empty()) {
+            continue;
+        }
         for (const auto& gen_token : token.begin()->second.generated_ids) {
             if (streamer_ptr->put(gen_token)) {
                 generation->drop();
3 changes: 3 additions & 0 deletions src/cpp/src/lm_encoding.cpp
@@ -95,6 +95,9 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
     GenerationHandle& handle = generations.at(0);
     if (streamer_ptr && handle->can_read()) {
         std::unordered_map<uint64_t, GenerationOutput> token = handle->back();
+        if (token.empty()) {
+            return;
+        }
         for (const auto& gen_token : token.begin()->second.generated_ids) {
             if (streamer_ptr->put(gen_token)) {
                 handle->drop();
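The lm_encoding.cpp variant bails out with return rather than continue and peeks the handle with back() instead of read(); this block appears to sit inside a per-step streaming helper, so returning from the helper is the equivalent of skipping the empty output in the other loops.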
3 changes: 3 additions & 0 deletions src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
@@ -152,6 +152,9 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<ov::Ten
 
     if (generation->can_read()) {
         std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
+        if (token.empty()) {
+            continue;
+        }
         for (const auto& gen_token : token.begin()->second.generated_ids) {
             if (streamer_ptr->put(gen_token)) {
                 generation->drop();
4 changes: 4 additions & 0 deletions src/cpp/src/sequence_group.hpp
@@ -649,7 +649,11 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
         if (has_finished()) {
             m_stream_window_size = 0;
         }
+        // push empty output in case we won't stream generation res
         if (generated_len <= (m_num_streamed_tokens + m_stream_window_size)) {
+            if (has_finished()) {
+                push_empty_outputs();
+            }
             return;
         }
         // speculative decoding draft handling
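This hunk is the producer side of the fix: when the sequence group has finished but there is nothing new to stream (every generated token has already been streamed, or the remaining ones are still withheld by m_stream_window_size, which presumably buffers tokens while a stop-string match is pending), push_empty_outputs() still notifies the waiting reader; the new empty-map guards in the consumers then skip that notification instead of dereferencing it.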
5 changes: 4 additions & 1 deletion src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -277,7 +277,10 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
 
     if (generation->can_read()) {
         std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
-        for (const auto& gen_token : token.begin()->second.generated_ids) {
+        if (token.empty()) {
+            continue;
+        }
+        for (const auto& gen_token : token.begin()->second.generated_ids) {
             if (streamer_ptr->put(gen_token)) {
                 generation->drop();
                 break;
24 changes: 21 additions & 3 deletions tests/python_tests/test_continuous_batching.py
@@ -340,10 +340,9 @@ def test_preemption_with_multinomial_n_seq(tmp_path, dynamic_split_fuse):
     scheduler_config = get_scheduler_config({"num_kv_blocks": 8, "dynamic_split_fuse": dynamic_split_fuse, "max_num_batched_tokens": 256, "max_num_seqs": 256})
     generate_and_compare_with_reference_text(models_path, multinomial_params_n_seq.prompts, multinomial_params_n_seq.ref_texts, multinomial_params_n_seq.generation_config, scheduler_config)
 
-def get_data_by_pipeline_type(model_path: Path, pipeline_type: str):
+def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_config: GenerationConfig = GenerationConfig()):
     device = "CPU"
-    prompt = "Prompt example is"
-    generation_config = GenerationConfig()
+    prompt = "Prompt example is"
     generation_config.max_new_tokens = 10
     generation_config.do_sample = True
     pipe = None
@@ -380,3 +379,22 @@ def test_pipelines_generate_with_streaming(tmp_path, pipeline_type):
 
     del pipe
     rmtree(models_path)
+
+@pytest.mark.parametrize("pipeline_type", ["continuous_batching", "speculative_decoding", "prompt_lookup_decoding", "llm_pipeline"])
+@pytest.mark.precommit
+def test_pipelines_generate_with_streaming_empty_output(tmp_path, pipeline_type):
+    model_id : str = "facebook/opt-125m"
+    opt_model, hf_tokenizer = get_hugging_face_models(model_id)
+
+    models_path : Path = tmp_path / "t_streaming" / model_id
+    convert_models(opt_model, hf_tokenizer, models_path)
+
+    generation_config = GenerationConfig()
+    generation_config.stop_strings = {" the "}
+
+    pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
+    py_streamer = lambda x: False
+    _ = pipe.generate(input, generation_config=gen_config, streamer=py_streamer)
+
+    del pipe
+    rmtree(models_path)
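The new test exercises exactly that path: stop_strings = {" the "} makes it likely that a request finishes while tokens are still being withheld from the streamer, so the handle reports an output with nothing to stream, which is the empty-output case fixed above; the non-cancelling streamer (lambda x: False) keeps streaming active across all four pipeline types.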
