Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iefode committed Jan 30, 2025
1 parent 7f71f12 commit f5d8256
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 35 deletions.
17 changes: 9 additions & 8 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,15 +465,16 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
});

if (generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
if (token.empty()) {
continue;
}
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
generation->drop();
break;
std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
OPENVINO_ASSERT(generation_outputs.size() <= 1);
for (const auto& generation_output : generation_outputs) {
for (const auto& generated_token_id : generation_output.second.generated_ids) {
if (streamer_ptr->put(generated_token_id)) {
generation->drop();
break;
}
}
break;
}
}
};
Expand Down
17 changes: 9 additions & 8 deletions src/cpp/src/lm_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,16 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
auto stream_generated_tokens = [&streamer_ptr, &generations, &active_sequence_groups]() {
GenerationHandle& handle = generations.at(0);
if (streamer_ptr && handle->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = handle->back();
if (token.empty()) {
return;
}
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
handle->drop();
break;
std::unordered_map<uint64_t, GenerationOutput> generation_outputs = handle->read();
OPENVINO_ASSERT(generation_outputs.size() <= 1);
for (const auto& generation_output : generation_outputs) {
for (const auto& generated_token_id : generation_output.second.generated_ids) {
if (streamer_ptr->put(generated_token_id)) {
handle->drop();
break;
}
}
break;
}
}
};
Expand Down
17 changes: 9 additions & 8 deletions src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,16 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<ov::Ten
});

if (generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
if (token.empty()) {
continue;
}
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
generation->drop();
break;
std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
OPENVINO_ASSERT(generation_outputs.size() <= 1);
for (const auto& generation_output : generation_outputs) {
for (const auto& generated_token_id : generation_output.second.generated_ids) {
if (streamer_ptr->put(generated_token_id)) {
generation->drop();
break;
}
}
break;
}
}
};
Expand Down
23 changes: 13 additions & 10 deletions src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
auto all_requests = get_awaiting_requests();

std::atomic<bool> has_active_requests = has_non_finished_requests();
auto& generation = main_generations.at(0);
GenerationHandle& generation = main_generations.at(0);

// create variables to make optimal thread-safe streaming
std::mutex mutex;
std::unique_lock lock(mutex);
std::unique_lock<std::mutex> lock(mutex);
std::condition_variable cv;

// auto t_stream_ptr = create_streaming_thread(streamer_ptr, lock, cv, main_generations.at(0), has_active_requests);

std::shared_ptr<std::thread> t_stream_ptr = nullptr;
if (streamer_ptr) {
// define stream token lambda to use in `t_stream_ptr`
Expand All @@ -276,15 +278,16 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
});

if (generation->can_read()) {
std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
if (token.empty()) {
continue;
}
for (const auto& gen_token : token.begin()->second.generated_ids) {
if (streamer_ptr->put(gen_token)) {
generation->drop();
break;
std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
OPENVINO_ASSERT(generation_outputs.size() <= 1);
for (const auto& generation_output : generation_outputs) {
for (const auto& generated_token_id : generation_output.second.generated_ids) {
if (streamer_ptr->put(generated_token_id)) {
generation->drop();
break;
}
}
break;
}
}
};
Expand Down
3 changes: 2 additions & 1 deletion tests/python_tests/test_continuous_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ def test_pipelines_generate_with_streaming(tmp_path, pipeline_type):
models_path : Path = tmp_path / "t_streaming" / model_id
convert_models(opt_model, hf_tokenizer, models_path)

pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type)
generation_config = GenerationConfig()
pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
py_streamer = lambda x: False
_ = pipe.generate(input, generation_config=gen_config, streamer=py_streamer)

Expand Down

0 comments on commit f5d8256

Please sign in to comment.