Parallel sampling with threadpool #1252
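
The diff below introduces a `sampler_num_threads` property for the continuous batching pipeline: `ContinuousBatchingImpl::initialize_pipeline` extracts it from the properties map (defaulting to `std::thread::hardware_concurrency()`) and passes it to the `Sampler`, while the `ContinuousBatchingAdapter` constructors pin it to 1, presumably to keep the single-request `LLMPipeline` path on one sampling thread. A rough usage sketch, assuming the property can simply be passed through the pipeline's properties map (the property name is taken from the diff; model path, value, and the rest are illustrative):

```cpp
#include <vector>
#include "openvino/genai/continuous_batching_pipeline.hpp"

int main() {
    ov::genai::SchedulerConfig scheduler_config;

    // Assumption: "sampler_num_threads" is consumed by the pipeline itself
    // (see initialize_pipeline below) and never reaches the plugin.
    ov::AnyMap properties;
    properties["sampler_num_threads"] = 8;

    ov::genai::ContinuousBatchingPipeline pipe(
        "./TinyLlama-1.1B-Chat-v1.0/", scheduler_config, "CPU", properties);

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 100;

    // Sampling for the batch is distributed over the sampler's thread pool.
    auto results = pipe.generate(
        {"Alan Turing was a"},
        std::vector<ov::genai::GenerationConfig>{config});
    return 0;
}
```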

Merged
14 changes: 13 additions & 1 deletion .github/workflows/causal_lm_cpp.yml
@@ -466,16 +466,22 @@ jobs:
- name: run and compare
run: |
source ./ov/setupvars.sh
echo Running speculative_decoding_lm C++ sample...
./build/samples/cpp/text_generation/speculative_decoding_lm ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_speculative.txt
echo Running greedy_causal_lm C++ sample...
./build/samples/cpp/text_generation/greedy_causal_lm ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
echo Running speculative_decoding_lm Python sample...
python ./samples/python/text_generation/speculative_decoding_lm.py ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_py.txt
echo All samples executed, checking result correctness...
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
with open('predictions_speculative.txt', 'r') as f:
predicted_speculative = f.readline()
with open('predictions_py.txt', 'r') as f:
predicted_py = f.readline()
print(f'Predicted greedy: {predicted_greedy}')
print(f'Predicted speculative: {predicted_speculative}')
assert predicted_greedy == predicted_speculative
assert predicted_greedy == predicted_py
assert predicted_speculative == predicted_py
@@ -523,17 +529,23 @@ jobs:
```
Question: Can you please add 2 and 3
A:' > ./prompt.txt

echo Running prompt_lookup_decoding_lm C++ sample...
./build/samples/cpp/text_generation/prompt_lookup_decoding_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_prompt_lookup.txt
echo Running greedy_causal_lm C++ sample...
./build/samples/cpp/text_generation/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_greedy.txt
echo Running prompt_lookup_decoding_lm Python sample...
python ./samples/python/text_generation/prompt_lookup_decoding_lm.py ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_py.txt
echo All samples executed, checking result correctness...
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
with open('predictions_prompt_lookup.txt', 'r') as f:
predicted_prompt_lookup = f.readline()
with open('predictions_py.txt', 'r') as f:
predicted_prompt_lookup_py = f.readline()

print(f'Predicted greedy: {predicted_greedy}')
print(f'Predicted prompt lookup: {predicted_prompt_lookup}')
assert predicted_greedy == predicted_prompt_lookup
assert predicted_greedy == predicted_prompt_lookup_py
assert predicted_prompt_lookup == predicted_prompt_lookup_py
57 changes: 25 additions & 32 deletions src/cpp/src/continuous_batching_adapter.hpp
@@ -1,10 +1,10 @@

// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "llm_pipeline_base.hpp"

#include "openvino/genai/continuous_batching_pipeline.hpp"
#include <memory>

namespace ov::genai {

@@ -17,29 +17,27 @@ template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
ContinuousBatchingPipeline m_impl;
std::unique_ptr<ContinuousBatchingPipeline> m_impl;
public:
ContinuousBatchingAdapter(
const ov::InferRequest& request,
const Tokenizer& tokenizer,
OptionalGenerationConfig generation_config
): LLMPipelineImplBase{dont_construct(), GenerationConfig{}},
m_impl{std::filesystem::path{}, SchedulerConfig{}, std::string{}} { }

): LLMPipelineImplBase{dont_construct(), GenerationConfig{}},
m_impl{std::make_unique<ContinuousBatchingPipeline>(std::filesystem::path{}, SchedulerConfig{}, std::string{})} { }
ContinuousBatchingAdapter(
const std::filesystem::path& models_path,
const Tokenizer& tokenizer,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config
): LLMPipelineImplBase{tokenizer, GenerationConfig()}, m_impl{
models_path,
tokenizer,
scheduler_config,
device,
plugin_config} {
m_generation_config = m_impl.get_config();
}
): LLMPipelineImplBase{tokenizer, GenerationConfig()} {
auto mutable_plugin_config = plugin_config;
mutable_plugin_config["sampler_num_threads"] = 1;
m_impl = std::make_unique<ContinuousBatchingPipeline>(models_path, tokenizer, scheduler_config, device, mutable_plugin_config);
m_generation_config = m_impl->get_config();
}

ContinuousBatchingAdapter(
const std::string& model_str,
@@ -49,27 +47,22 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
const std::string& device,
const ov::AnyMap& plugin_config,
const ov::genai::GenerationConfig& generation_config
): LLMPipelineImplBase{tokenizer, GenerationConfig()}, m_impl{
model_str,
weights_tensor,
tokenizer,
scheduler_config,
device,
plugin_config,
generation_config} {}
): LLMPipelineImplBase{tokenizer, GenerationConfig()} {
auto mutable_plugin_config = plugin_config;
mutable_plugin_config["sampler_num_threads"] = 1;
m_impl = std::make_unique<ContinuousBatchingPipeline>(model_str, weights_tensor, tokenizer, scheduler_config, device, mutable_plugin_config, generation_config);
}

ContinuousBatchingAdapter(
const std::filesystem::path& models_path,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config
): LLMPipelineImplBase{Tokenizer(models_path), GenerationConfig()}, m_impl{
models_path,
m_tokenizer,
scheduler_config,
device,
plugin_config} {
m_generation_config = m_impl.get_config();
): LLMPipelineImplBase{Tokenizer(models_path), GenerationConfig()} {
auto mutable_plugin_config = plugin_config;
mutable_plugin_config["sampler_num_threads"] = 1;
m_impl = std::make_unique<ContinuousBatchingPipeline>(models_path, m_tokenizer, scheduler_config, device, mutable_plugin_config);
m_generation_config = m_impl->get_config();
}

DecodedResults generate(
@@ -90,7 +83,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
}, inputs);
const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config;
// -1 == config.eos_token_id and config.validate() are handled in m_impl.
std::vector<GenerationResult> generated = m_impl.generate(prompts,
std::vector<GenerationResult> generated = m_impl->generate(prompts,
std::vector<GenerationConfig>{prompts.size(), config},
streamer
);
@@ -181,7 +174,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {

const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config;
// -1 == config.eos_token_id and config.validate() are handled in m_impl.
std::vector<EncodedGenerationResult> generated = m_impl.generate(input_ids,
std::vector<EncodedGenerationResult> generated = m_impl->generate(input_ids,
std::vector<GenerationConfig>{input_ids.size(), config},
streamer
);
@@ -210,11 +203,11 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
}

void start_chat(const std::string& system_message) override {
m_impl.start_chat();
m_impl->start_chat();
};

void finish_chat() override {
m_impl.finish_chat();
m_impl->finish_chat();
};
};

19 changes: 14 additions & 5 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -142,17 +142,25 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
const std::vector<KVHeadConfig>& kv_cache_config) {
ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model;
ov::AnyMap mutable_properties = properties;
// Extract sampler_num_threads property if exists and remove it from properties
size_t sampler_num_threads = std::thread::hardware_concurrency();
auto sampler_num_threads_it = mutable_properties.find("sampler_num_threads");
if (sampler_num_threads_it != mutable_properties.end()) {
sampler_num_threads = sampler_num_threads_it->second.as<size_t>();
mutable_properties.erase(sampler_num_threads_it);
}

// TODO: remove once plugin automatically set KV cache precisions
apply_kv_cache_precision(model, device, properties);
apply_kv_cache_precision(model, device, mutable_properties);

// apply LoRA
if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
if (auto filtered_properties = extract_adapters_from_properties(mutable_properties, &m_generation_config.adapters)) {
m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device); // TODO: Make the prefix name configurable
compiled_model = core.compile_model(model, device, *filtered_properties);
} else {
compiled_model = core.compile_model(model, device, properties);
compiled_model = core.compile_model(model, device, mutable_properties);
}

ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");
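
A side note on the property handling above: the custom key is erased before the map reaches `core.compile_model()`, so the plugin never sees an unknown option. A generic sketch of that extract-and-erase pattern, with a hypothetical helper name (`pop_or_default` is not part of the codebase):

```cpp
#include <string>
#include <thread>
#include "openvino/core/any.hpp"

// Hypothetical helper: remove a key from an ov::AnyMap if present and
// return its value, otherwise a default. Erasing matters because any
// leftover custom key would be forwarded to core.compile_model() and
// rejected by the device plugin.
template <typename T>
T pop_or_default(ov::AnyMap& properties, const std::string& key, T default_value) {
    auto it = properties.find(key);
    if (it == properties.end())
        return default_value;
    T value = it->second.as<T>();
    properties.erase(it);
    return value;
}

// Usage mirroring the diff above:
//   size_t sampler_num_threads = pop_or_default<size_t>(
//       mutable_properties, "sampler_num_threads", std::thread::hardware_concurrency());
```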
@@ -227,7 +235,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
std::make_shared<ModelRunner>(infer_request, m_block_size, m_num_decoder_layers);
}

m_sampler = std::make_shared<Sampler>(m_tokenizer);
m_sampler = std::make_shared<Sampler>(m_tokenizer, sampler_num_threads);
m_sampler->set_seed(m_generation_config.rng_seed);

// If eos_token_id was not provided, take value
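
The `Sampler` changes themselves are not shown in this excerpt, so the following is only an illustrative sketch of the idea behind `sampler_num_threads`: sampling for different sequence groups is independent, so it can be fanned out to worker threads and joined before the next scheduling step. It uses plain `std::async` instead of whatever thread pool the PR actually introduces:

```cpp
#include <future>
#include <vector>

// Illustrative only: run per-group sampling on separate tasks and join.
// The real Sampler presumably reuses a fixed-size thread pool sized by
// sampler_num_threads rather than spawning tasks with std::async.
template <typename SequenceGroup, typename SampleFn>
void sample_groups_in_parallel(std::vector<SequenceGroup>& groups, SampleFn sample_one) {
    std::vector<std::future<void>> futures;
    futures.reserve(groups.size());
    for (auto& group : groups) {
        futures.emplace_back(std::async(std::launch::async, [&group, &sample_one] {
            sample_one(group);  // per-group sampling has no cross-group data dependencies
        }));
    }
    for (auto& f : futures)
        f.get();  // wait for completion and propagate any exception
}
```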
@@ -282,8 +290,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {

_pull_awaiting_requests();


Scheduler::Output scheduler_output;

{
static ManualTimer scheduling_timer("scheduling");
scheduling_timer.start();
@@ -318,6 +326,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
return;
}
ov::Tensor logits;

{
static ManualTimer timer("forward");
timer.start();