
Commit

Merge branch 'master' into yt/fix-tiny-random-llava-next
yatarkan committed Jan 31, 2025
commit 004f598 (2 parents: 1a0e35b + 38ab055)
Showing 11 changed files with 362 additions and 205 deletions.
14 changes: 13 additions & 1 deletion .github/workflows/causal_lm_cpp.yml
@@ -466,16 +466,22 @@ jobs:
- name: run and compare
run: |
source ./ov/setupvars.sh
echo Running speculative_decoding_lm C++ sample...
./build/samples/cpp/text_generation/speculative_decoding_lm ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_speculative.txt
echo Running greedy_causal_lm C++ sample...
./build/samples/cpp/text_generation/greedy_causal_lm ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
echo Running speculative_decoding_lm Python sample...
python ./samples/python/text_generation/speculative_decoding_lm.py ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_py.txt
echo All samples executed, checking result correctness...
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
with open('predictions_speculative.txt', 'r') as f:
predicted_speculative = f.readline()
with open('predictions_py.txt', 'r') as f:
predicted_py = f.readline()
print(f'Predicted greedy: {predicted_greedy}')
print(f'Predicted speculative: {predicted_speculative}')
assert predicted_greedy == predicted_speculative
assert predicted_greedy == predicted_py
assert predicted_speculative == predicted_py
@@ -523,17 +529,23 @@ jobs:
```
Question: Can you please add 2 and 3
A:' > ./prompt.txt
echo Running prompt_lookup_decoding_lm C++ sample...
./build/samples/cpp/text_generation/prompt_lookup_decoding_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_prompt_lookup.txt
echo Running greedy_causal_lm C++ sample...
./build/samples/cpp/text_generation/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_greedy.txt
echo Running prompt_lookup_decoding_lm Python sample...
python ./samples/python/text_generation/prompt_lookup_decoding_lm.py ./TinyLlama-1.1B-Chat-v1.0/ "$(<prompt.txt)" > predictions_py.txt
echo All samples executed, checking result correctness...
python -c "
with open('predictions_greedy.txt', 'r') as f:
predicted_greedy = f.readline()
with open('predictions_prompt_lookup.txt', 'r') as f:
predicted_prompt_lookup = f.readline()
with open('predictions_py.txt', 'r') as f:
predicted_prompt_lookup_py = f.readline()
print(f'Predicted greedy: {predicted_greedy}')
print(f'Predicted prompt lookup: {predicted_prompt_lookup}')
assert predicted_greedy == predicted_prompt_lookup
assert predicted_greedy == predicted_prompt_lookup_py
assert predicted_prompt_lookup == predicted_prompt_lookup_py
57 changes: 25 additions & 32 deletions src/cpp/src/continuous_batching_adapter.hpp
@@ -1,10 +1,10 @@

// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "llm_pipeline_base.hpp"

#include "openvino/genai/continuous_batching_pipeline.hpp"
#include <memory>

namespace ov::genai {

@@ -17,29 +17,27 @@ template<class... Ts> struct overloaded : Ts... {using Ts::operator()...;};
template<class... Ts> overloaded(Ts...) -> overloaded<Ts...>;

class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
ContinuousBatchingPipeline m_impl;
std::unique_ptr<ContinuousBatchingPipeline> m_impl;
public:
ContinuousBatchingAdapter(
const ov::InferRequest& request,
const Tokenizer& tokenizer,
OptionalGenerationConfig generation_config
): LLMPipelineImplBase{dont_construct(), GenerationConfig{}},
m_impl{std::filesystem::path{}, SchedulerConfig{}, std::string{}} { }

): LLMPipelineImplBase{dont_construct(), GenerationConfig{}},
m_impl{std::make_unique<ContinuousBatchingPipeline>(std::filesystem::path{}, SchedulerConfig{}, std::string{})} { }
ContinuousBatchingAdapter(
const std::filesystem::path& models_path,
const Tokenizer& tokenizer,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config
): LLMPipelineImplBase{tokenizer, GenerationConfig()}, m_impl{
models_path,
tokenizer,
scheduler_config,
device,
plugin_config} {
m_generation_config = m_impl.get_config();
}
): LLMPipelineImplBase{tokenizer, GenerationConfig()} {
auto mutable_plugin_config = plugin_config;
mutable_plugin_config["sampler_num_threads"] = 1;
m_impl = std::make_unique<ContinuousBatchingPipeline>(models_path, tokenizer, scheduler_config, device, mutable_plugin_config);
m_generation_config = m_impl->get_config();
}

ContinuousBatchingAdapter(
const std::string& model_str,
@@ -49,27 +47,22 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
const std::string& device,
const ov::AnyMap& plugin_config,
const ov::genai::GenerationConfig& generation_config
): LLMPipelineImplBase{tokenizer, GenerationConfig()}, m_impl{
model_str,
weights_tensor,
tokenizer,
scheduler_config,
device,
plugin_config,
generation_config} {}
): LLMPipelineImplBase{tokenizer, GenerationConfig()} {
auto mutable_plugin_config = plugin_config;
mutable_plugin_config["sampler_num_threads"] = 1;
m_impl = std::make_unique<ContinuousBatchingPipeline>(model_str, weights_tensor, tokenizer, scheduler_config, device, mutable_plugin_config, generation_config);
}

ContinuousBatchingAdapter(
const std::filesystem::path& models_path,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config
): LLMPipelineImplBase{Tokenizer(models_path), GenerationConfig()}, m_impl{
models_path,
m_tokenizer,
scheduler_config,
device,
plugin_config} {
m_generation_config = m_impl.get_config();
): LLMPipelineImplBase{Tokenizer(models_path), GenerationConfig()} {
auto mutable_plugin_config = plugin_config;
mutable_plugin_config["sampler_num_threads"] = 1;
m_impl = std::make_unique<ContinuousBatchingPipeline>(models_path, m_tokenizer, scheduler_config, device, mutable_plugin_config);
m_generation_config = m_impl->get_config();
}

DecodedResults generate(
@@ -90,7 +83,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
}, inputs);
const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config;
// -1 == config.eos_token_id and config.validate() are handled in m_impl.
std::vector<GenerationResult> generated = m_impl.generate(prompts,
std::vector<GenerationResult> generated = m_impl->generate(prompts,
std::vector<GenerationConfig>{prompts.size(), config},
streamer
);
@@ -181,7 +174,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {

const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config;
// -1 == config.eos_token_id and config.validate() are handled in m_impl.
std::vector<EncodedGenerationResult> generated = m_impl.generate(input_ids,
std::vector<EncodedGenerationResult> generated = m_impl->generate(input_ids,
std::vector<GenerationConfig>{input_ids.size(), config},
streamer
);
@@ -210,11 +203,11 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
}

void start_chat(const std::string& system_message) override {
m_impl.start_chat();
m_impl->start_chat();
};

void finish_chat() override {
m_impl.finish_chat();
m_impl->finish_chat();
};
};

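Note on the continuous_batching_adapter.hpp change above: the pipeline is now held as a std::unique_ptr<ContinuousBatchingPipeline> so it can be constructed inside each constructor body, after a copy of the caller's plugin config has been extended with sampler_num_threads = 1 (presumably because the LLMPipeline adapter drives a single generate call at a time and does not need a multi-threaded sampler). A minimal sketch of that copy-and-pin pattern; the helper name make_pipeline_properties is illustrative, not part of the commit:

```cpp
#include <openvino/core/any.hpp>

// Illustrative helper mirroring what each ContinuousBatchingAdapter constructor
// now does before building its ContinuousBatchingPipeline: copy the user-supplied
// plugin config and pin the sampler to a single thread.
inline ov::AnyMap make_pipeline_properties(const ov::AnyMap& plugin_config) {
    ov::AnyMap mutable_plugin_config = plugin_config;  // keep the caller's map untouched
    mutable_plugin_config["sampler_num_threads"] = 1;  // one active request -> one sampler thread
    return mutable_plugin_config;
}
```

The same two lines are repeated in all three pipeline-building constructors, so a shared helper like the one above would be a possible follow-up refactor rather than something this commit introduces.
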
19 changes: 14 additions & 5 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -142,17 +142,25 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
const std::vector<KVHeadConfig>& kv_cache_config) {
ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model;
ov::AnyMap mutable_properties = properties;
// Extract sampler_num_threads property if exists and remove it from properties
size_t sampler_num_threads = std::thread::hardware_concurrency();
auto sampler_num_threads_it = mutable_properties.find("sampler_num_threads");
if (sampler_num_threads_it != mutable_properties.end()) {
sampler_num_threads = sampler_num_threads_it->second.as<size_t>();
mutable_properties.erase(sampler_num_threads_it);
}

// TODO: remove once plugin automatically set KV cache precisions
apply_kv_cache_precision(model, device, properties);
apply_kv_cache_precision(model, device, mutable_properties);

// apply LoRA
if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
if (auto filtered_properties = extract_adapters_from_properties(mutable_properties, &m_generation_config.adapters)) {
m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device); // TODO: Make the prefix name configurable
compiled_model = core.compile_model(model, device, *filtered_properties);
} else {
compiled_model = core.compile_model(model, device, properties);
compiled_model = core.compile_model(model, device, mutable_properties);
}

ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");
@@ -227,7 +235,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
std::make_shared<ModelRunner>(infer_request, m_block_size, m_num_decoder_layers);
}

m_sampler = std::make_shared<Sampler>(m_tokenizer);
m_sampler = std::make_shared<Sampler>(m_tokenizer, sampler_num_threads);
m_sampler->set_seed(m_generation_config.rng_seed);

// If eos_token_id was not provided, take value
@@ -282,8 +290,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {

_pull_awaiting_requests();


Scheduler::Output scheduler_output;

{
static ManualTimer scheduling_timer("scheduling");
scheduling_timer.start();
@@ -318,6 +326,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() {
return;
}
ov::Tensor logits;

{
static ManualTimer timer("forward");
timer.start();
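On the pipeline side, initialize_pipeline now pops sampler_num_threads off the property map, defaulting to std::thread::hardware_concurrency(), erases the key so the remaining properties can be forwarded to core.compile_model unchanged, and passes the value to the Sampler constructor. A rough sketch of that extract-with-default pattern over an ov::AnyMap, assuming a hypothetical helper named take_size_t_property:

```cpp
#include <openvino/core/any.hpp>

#include <string>
#include <thread>

// Hypothetical helper showing the extract-with-default pattern from
// initialize_pipeline(): read an optional property, erase it so the remaining
// map can be handed to core.compile_model() without unknown keys, and fall
// back to a default when the property is absent.
inline size_t take_size_t_property(ov::AnyMap& properties, const std::string& key, size_t default_value) {
    auto it = properties.find(key);
    if (it == properties.end())
        return default_value;
    size_t value = it->second.as<size_t>();
    properties.erase(it);
    return value;
}

// Usage roughly matching the diff above:
//   ov::AnyMap mutable_properties = properties;
//   size_t sampler_num_threads = take_size_t_property(
//       mutable_properties, "sampler_num_threads", std::thread::hardware_concurrency());
//   m_sampler = std::make_shared<Sampler>(m_tokenizer, sampler_num_threads);
```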