Add a choice of how to end streaming from callback: STOP or CANCEL #1476

Open · wants to merge 7 commits into base: master
6 changes: 3 additions & 3 deletions samples/cpp/text_generation/chat_sample.cpp
@@ -15,11 +15,11 @@ int main(int argc, char* argv[]) try {

ov::genai::GenerationConfig config;
config.max_new_tokens = 100;
std::function<bool(std::string)> streamer = [](std::string word) {

auto streamer = [](std::string word) {
std::cout << word << std::flush;
// Return flag corresponds whether generation should be stopped.
// false means continue generation.
return false;
return ov::genai::StreamingStatus::RUNNING;
};

pipe.start_chat();
2 changes: 1 addition & 1 deletion samples/cpp/text_generation/multinomial_causal_lm.cpp
@@ -21,7 +21,7 @@ int main(int argc, char* argv[]) try {
config.top_k = 30;
auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
return false;
return ov::genai::StreamingStatus::RUNNING;
};

// Since the streamer is set, the results will
2 changes: 1 addition & 1 deletion samples/cpp/text_generation/prompt_lookup_decoding_lm.cpp
@@ -29,7 +29,7 @@ int main(int argc, char* argv[]) try {

auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
return false;
return ov::genai::StreamingStatus::RUNNING;
};

// Since the streamer is set, the results will
2 changes: 1 addition & 1 deletion samples/cpp/text_generation/speculative_decoding_lm.cpp
@@ -33,7 +33,7 @@ int main(int argc, char* argv[]) try {

auto streamer = [](std::string subword) {
std::cout << subword << std::flush;
return false;
return ov::genai::StreamingStatus::RUNNING;
};

// Since the streamer is set, the results will
4 changes: 1 addition & 3 deletions samples/python/text_generation/chat_sample.py
@@ -9,9 +9,7 @@
def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
return False

return openvino_genai.StreamingStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
23 changes: 12 additions & 11 deletions samples/python/text_generation/multinomial_causal_lm.py
@@ -56,12 +56,12 @@ def __next__(self):

def get_stop_flag(self):
"""
Checks whether the generation process should be stopped.
Checks whether the generation process should be stopped or cancelled.

Returns:
bool: Always returns False in this implementation.
openvino_genai.StreamingStatus: Always returns RUNNING in this implementation.
"""
return False
return openvino_genai.StreamingStatus.RUNNING

def put_word(self, word: str):
"""
Expand All @@ -72,7 +72,7 @@ def put_word(self, word: str):
"""
self.text_queue.put(word)

def put(self, token_id: int) -> bool:
def write(self, token_id: int) -> openvino_genai.StreamingStatus:
"""
Processes a token and manages the decoding buffer. Adds decoded text to the queue.

@@ -106,12 +106,12 @@ def put(self, token_id: int) -> bool:
self.print_len = print_until
self.put_word(word)

if self.get_stop_flag():
stop_flag = self.get_stop_flag()
if stop_flag != openvino_genai.StreamingStatus.RUNNING:
# When generation is stopped from the streamer, end() is not called, so we need to call it here manually.
self.end()
return True # True means stop generation
else:
return False # False means continue generation

return stop_flag

def end(self):
"""
@@ -123,6 +123,7 @@ def end(self):
self.put_word(word)
self.tokens_cache = []
self.print_len = 0
self.put_word('\n')
self.put_word(None)


@@ -132,12 +133,12 @@ def __init__(self, tokenizer, tokens_len):
super().__init__(tokenizer)
self.tokens_len = tokens_len

def put(self, token_id: int) -> bool:
def write(self, token_id: int) -> openvino_genai.StreamingStatus:
if (len(self.tokens_cache) + 1) % self.tokens_len != 0:
self.tokens_cache.append(token_id)
self.decoded_lengths.append(-1)
return False
return super().put(token_id)
return openvino_genai.StreamingStatus.RUNNING
return super().write(token_id)


def main():
9 changes: 4 additions & 5 deletions samples/python/text_generation/prompt_lookup_decoding_lm.py
@@ -5,11 +5,10 @@
import argparse
import openvino_genai

def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
return False
def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
return openvino_genai.StreamingStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
3 changes: 1 addition & 2 deletions samples/python/text_generation/speculative_decoding_lm.py
@@ -9,8 +9,7 @@
def streamer(subword):
print(subword, end='', flush=True)
# Return flag corresponds whether generation should be stopped.
# False means continue generation.
return False
return openvino_genai.StreamingStatus.RUNNING

def main():
parser = argparse.ArgumentParser()
@@ -23,7 +23,7 @@ def streamer(subword: str) -> bool:
print(subword, end='', flush=True)

# No value is returned as in this example we don't want to stop the generation in this method.
# "return None" will be treated the same as "return False".
# "return None" will be treated the same as "return openvino_genai.StreamingStatus.RUNNING".


def read_image(path: str) -> Tensor:
3 changes: 1 addition & 2 deletions src/README.md
@@ -172,8 +172,7 @@ int main(int argc, char* argv[]) {
auto streamer = [](std::string word) {
std::cout << word << std::flush;
// Return flag corresponds whether generation should be stopped.
// false means continue generation.
return false;
return ov::genai::StreamingStatus::RUNNING;
};
std::cout << pipe.generate("The Sun is yellow because", ov::genai::streamer(streamer), ov::genai::max_new_tokens(200));
}
24 changes: 19 additions & 5 deletions src/cpp/include/openvino/genai/generation_handle.hpp
@@ -11,14 +11,17 @@
#include "openvino/genai/perf_metrics.hpp"

namespace ov::genai {

enum class GenerationStatus {
RUNNING = 0, // Default status for ongoing generation
FINISHED = 1, // Status set when generation has been finished
IGNORED = 2, // Status set when generation run into out-of-memory condition and could not be continued
DROPPED_BY_PIPELINE = 3, // Currently not used, TODO: implement abort functionality
DROPPED_BY_HANDLE = 4 // Status set when generation handle is dropped
CANCEL = 3, // Status set when the generation handle is cancelled. The last prompt and all generated tokens will be dropped from history; the KV cache will include the history but not the last step.
STOP = 4, // Status set when generation handle is stopped. History will be kept, KV cache will include the last prompt and generated tokens.
DROPPED_BY_HANDLE OPENVINO_ENUM_DEPRECATED("Please, use `STOP` instead of `DROPPED_BY_HANDLE`.") = GenerationStatus::STOP // Status set when generation handle is dropped.
};


struct EncodedGenerationResult {
// request ID - obsolete when handle API is approved as handle will connect results with prompts.
uint64_t m_request_id;
@@ -70,10 +73,10 @@ using GenerationOutputs = std::unordered_map<uint64_t, GenerationOutput>;

class GenerationStream;

class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
class OPENVINO_GENAI_EXPORTS
GenerationHandleImpl {
std::shared_ptr<GenerationStream> m_generation_stream;
ov::genai::GenerationConfig m_sampling_params;

ov::genai::GenerationConfig m_sampling_params;
public:
GenerationHandleImpl(std::shared_ptr<GenerationStream> generation_stream, const ov::genai::GenerationConfig& sampling_params) :
m_generation_stream(std::move(generation_stream)),
@@ -88,10 +91,21 @@ class OPENVINO_GENAI_EXPORTS GenerationHandleImpl {
GenerationStatus get_status();

bool can_read();

OPENVINO_DEPRECATED("Please, use `stop()` instead of `drop()`. Support will be removed in 2026.0.0 release.")
bool is_dropped();

bool is_stopped();

bool is_cancelled();

OPENVINO_DEPRECATED("Please, use `stop()` instead of `drop()`. Support will be removed in 2026.0.0 release.")
void drop();

void stop();

void cancel();

// Reads result of a generation for single iteration
GenerationOutputs read();
// Reads all generated tokens for all sequences
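For context, here is a minimal sketch of how the renamed handle methods might be used from the continuous-batching API. It assumes `GenerationHandle` is a shared pointer to `GenerationHandleImpl` as in the existing API, that a request has already been added (e.g. via `ContinuousBatchingPipeline::add_request`), and that generation itself is driven elsewhere (e.g. `step()` on a worker thread); the token budget is purely illustrative:

#include "openvino/genai/generation_handle.hpp"

// Drains a generation handle and stops it once a token budget is exceeded.
// Generation is assumed to be driven by another thread.
void consume(ov::genai::GenerationHandle& handle, size_t token_budget) {
    size_t tokens_seen = 0;
    while (handle->get_status() == ov::genai::GenerationStatus::RUNNING) {
        if (!handle->can_read())
            continue;  // busy-wait kept trivial for the sketch
        for (const auto& [request_id, output] : handle->read()) {
            tokens_seen += output.generated_ids.size();
        }
        if (tokens_seen > token_budget) {
            // stop() keeps the last prompt and generated tokens in history and KV cache;
            // cancel() would drop them from history instead.
            handle->stop();
            break;
        }
    }
}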
6 changes: 4 additions & 2 deletions src/cpp/include/openvino/genai/llm_pipeline.hpp
@@ -18,8 +18,10 @@
namespace ov {
namespace genai {

// Return flag corresponds whether generation should be stopped: false means continue generation, true means stop.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
// The returned flag indicates whether generation should be stopped. It can be:
// an ov::genai::StreamingStatus flag: RUNNING means continue generation, STOP means stop generation, CANCEL means stop generation and remove the last prompt and answer from history
// *DEPRECATED* a bool flag: false means continue generation, true means stop. Please use `ov::genai::StreamingStatus` instead.
using StreamerVariant = std::variant<std::function<bool(std::string)>, std::function<StreamingStatus(std::string)>, std::shared_ptr<StreamerBase>, std::monostate>;
using OptionalGenerationConfig = std::optional<GenerationConfig>;
using EncodedInputs = std::variant<ov::Tensor, TokenizedInputs>;
using StringInputs = std::variant<std::string, std::vector<std::string>>;
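For illustration, a minimal sketch of the new `std::function<StreamingStatus(std::string)>` callback variant in use, mirroring the sample updated in src/README.md above; the model directory and the 200-character budget are placeholders:

#include <iostream>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    ov::genai::LLMPipeline pipe("path/to/model", "CPU");  // placeholder model directory

    size_t printed = 0;
    // Matches the new std::function<ov::genai::StreamingStatus(std::string)> variant.
    auto streamer = [&printed](std::string word) {
        std::cout << word << std::flush;
        printed += word.size();
        // Keep streaming until roughly 200 characters, then stop; STOP keeps the
        // prompt and generated tokens in history and in the KV cache.
        return printed < 200 ? ov::genai::StreamingStatus::RUNNING
                             : ov::genai::StreamingStatus::STOP;
    };

    pipe.generate("The Sun is yellow because", ov::genai::streamer(streamer), ov::genai::max_new_tokens(100));
}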
21 changes: 19 additions & 2 deletions src/cpp/include/openvino/genai/streamer_base.hpp
@@ -4,20 +4,37 @@
#pragma once

#include "openvino/genai/tokenizer.hpp"
#include <variant>

namespace ov {
namespace genai {

enum class StreamingStatus {
RUNNING = 0, // Continue running inference
STOP = 1, // Stop generation, keep history as is, KV cache includes the last request and generated tokens
CANCEL = 2 // Stop generation, drop the last prompt and all generated tokens from history, KV cache includes the history but not the last step
};

/**
* @brief Base class for streamers. In order to use it, inherit from this class and implement the put and end methods.
*
* @param m_tokenizer tokenizer
*/
class OPENVINO_GENAI_EXPORTS StreamerBase {
public:
/// @brief put is called every time new token is decoded,
/// @brief put is called every time a new token is decoded. Deprecated, please use write() instead.
/// @return bool flag to indicate whether generation should be stopped, if return true generation stops
[Resolved review discussion]

Contributor: Maybe we can add a new function put_token / write with a new return type and deprecate the current put with its binary status? IMO, it will be more future-proof and removes the ambiguity that authors of custom text streamers need to write functions like is_generation_complete. CC @Wovchena @sbalandi @as-suvorov, what do you think? BTW, if you are OK with a new method, note that we need to select a more or less generic name, which will allow putting a single token or multiple tokens (Whisper / Spec Dec cases).

Contributor Author: Have added the new method and called it write. put_token is good for one token, but for a chunk put_chunk is used now, and I'm not sure whether chunk is a common name or it's okay to change it to tokens? @as-suvorov, what do you think? Also, maybe post is possible?

Contributor: I think write is fine.

virtual bool put(int64_t token) = 0;
OPENVINO_DEPRECATED("Please, use `write()` instead of `put()`. Support will be removed in 2026.0.0 release.")
virtual bool put(int64_t token) {
OPENVINO_THROW("This method is deprecated and will be removed in 2026.0.0 release. Please, override write() insted.");
return true;
};

/// @brief write is called every time a new token is decoded
/// @return StreamingStatus flag to indicate whether generation should continue to run, be stopped, or be cancelled
virtual StreamingStatus write(int64_t token) {
return put(token) ? StreamingStatus::STOP : StreamingStatus::RUNNING;
};

/// @brief end is called at the end of generation. It can be used to flush cache if your own streamer has one
virtual void end() = 0;
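To make the new interface concrete, here is a hedged sketch of a custom streamer that overrides write() and end() as introduced above; the token budget and the "bad token" condition are invented for illustration, and a real streamer would typically detokenize and print text as the samples do:

#include <iostream>
#include "openvino/genai/streamer_base.hpp"

// Illustrative streamer using the new write()/end() interface: it stops after a
// fixed token budget and cancels if a (hypothetical) unwanted token id appears.
class BudgetStreamer : public ov::genai::StreamerBase {
public:
    explicit BudgetStreamer(size_t max_tokens) : m_max_tokens(max_tokens) {}

    ov::genai::StreamingStatus write(int64_t token) override {
        ++m_seen;
        if (token == m_bad_token_id) {
            // CANCEL: the last prompt and generated tokens are dropped from history.
            return ov::genai::StreamingStatus::CANCEL;
        }
        if (m_seen >= m_max_tokens) {
            // STOP: history and KV cache keep the last prompt and generated tokens.
            return ov::genai::StreamingStatus::STOP;
        }
        return ov::genai::StreamingStatus::RUNNING;
    }

    void end() override {
        std::cout << "\n[streamed " << m_seen << " tokens]\n";
    }

private:
    size_t m_max_tokens;
    size_t m_seen = 0;
    int64_t m_bad_token_id = -1;  // placeholder condition, purely illustrative
};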
15 changes: 12 additions & 3 deletions src/cpp/include/openvino/genai/whisper_pipeline.hpp
@@ -25,14 +25,23 @@ using RawSpeechInput = std::vector<float>;
*/
class OPENVINO_GENAI_EXPORTS ChunkStreamerBase : public StreamerBase {
public:
/// @brief put is called every time new token chunk is generated,
/// @brief put_chunk is called every time new token chunk is generated,
/// @return bool flag to indicate whether generation should be stopped, if return true generation stops
virtual bool put_chunk(std::vector<int64_t> tokens) = 0;
virtual bool put_chunk(std::vector<int64_t> tokens) {
OPENVINO_THROW("This method is deprecated and will be removed in 2026.0.0 release. Please, override write_chunk() insted.");
return true;
}

/// @brief write_chunk is called every time new token chunk is generated
/// @return StreamingStatus flag to indicate whether generation should be stopped
virtual StreamingStatus write_chunk(std::vector<int64_t> tokens) {
return put_chunk(tokens) ? StreamingStatus::STOP : StreamingStatus::RUNNING;
}
};

// Return flag corresponds whether generation should be stopped: false means continue generation, true means stop.
using ChunkStreamerVariant =
std::variant<std::function<bool(std::string)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;
std::variant<std::function<bool(std::string)>, std::function<StreamingStatus(std::string)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;

struct OPENVINO_GENAI_EXPORTS WhisperRawPerfMetrics {
/** @brief Duration for each features extraction call */
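A similar sketch for the Whisper case, assuming only the ChunkStreamerBase::write_chunk() method shown in this diff; the chunk budget is arbitrary:

#include <iostream>
#include <vector>
#include "openvino/genai/whisper_pipeline.hpp"

// Illustrative chunk streamer: counts generated tokens across chunks and stops
// once a budget is exceeded, letting earlier chunks through unchanged.
class ChunkBudgetStreamer : public ov::genai::ChunkStreamerBase {
public:
    explicit ChunkBudgetStreamer(size_t max_tokens) : m_max_tokens(max_tokens) {}

    ov::genai::StreamingStatus write_chunk(std::vector<int64_t> tokens) override {
        m_seen += tokens.size();
        std::cout << "received chunk of " << tokens.size() << " tokens" << std::endl;
        return m_seen < m_max_tokens ? ov::genai::StreamingStatus::RUNNING
                                     : ov::genai::StreamingStatus::STOP;
    }

    void end() override {
        std::cout << "transcription streaming finished" << std::endl;
    }

private:
    size_t m_max_tokens;
    size_t m_seen = 0;
};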
4 changes: 2 additions & 2 deletions src/cpp/src/continuous_batching_adapter.hpp
@@ -90,7 +90,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
std::vector<std::string> plain_replies;
std::vector<float> plain_scores;
for (GenerationResult& res : generated) {
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_replies));
std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
}
@@ -182,7 +182,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase {
std::vector<std::vector<int64_t>> plain_tokens;
std::vector<float> plain_scores;
for (EncodedGenerationResult& res : generated) {
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::DROPPED_BY_HANDLE, "Got unfinished GenerationStatus");
OPENVINO_ASSERT(res.m_status == GenerationStatus::FINISHED || res.m_status == GenerationStatus::STOP || res.m_status == GenerationStatus::CANCEL, "Got unfinished GenerationStatus");
std::move(res.m_generation_ids.begin(), res.m_generation_ids.end(), std::back_inserter(plain_tokens));
std::move(res.m_scores.begin(), res.m_scores.end(), std::back_inserter(plain_scores));
}
22 changes: 7 additions & 15 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -429,17 +429,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
}
set_adapters(sampling_params[0].adapters);

const std::shared_ptr<StreamerBase>& streamer_ptr = std::visit(overloaded{
[](std::monostate) -> std::shared_ptr<StreamerBase> {
return nullptr;
},
[](const std::shared_ptr<StreamerBase>& streamer) {
return streamer;
},
[this](const std::function<bool(std::string)>& streamer) -> std::shared_ptr<StreamerBase> {
return std::make_unique<TextCallbackStreamer>(m_tokenizer, streamer);
}
}, streamer);
const std::shared_ptr<StreamerBase>& streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer);

OPENVINO_ASSERT(streamer_ptr == nullptr || input_ids.size() == 1 && sampling_params[0].num_return_sequences == 1 &&
(sampling_params[0].is_greedy_decoding() || sampling_params[0].is_multinomial()),
@@ -476,8 +466,9 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
OPENVINO_ASSERT(generation_outputs.size() <= 1);
if (!generation_outputs.empty()) {
for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) {
if (streamer_ptr->put(generated_token_id)) {
generation->drop();
auto streaming_status = streamer_ptr->write(generated_token_id);
if (streaming_status != ov::genai::StreamingStatus::RUNNING) {
streaming_status == ov::genai::StreamingStatus::CANCEL ? generation->cancel() : generation->stop();
break;
}
}
@@ -537,6 +528,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<o
result.m_request_id = request_id;
result.m_generation_ids.resize(num_outputs);
result.m_scores.resize(num_outputs);
result.m_status = request->get_generation_stream()->get_status();

for (size_t i = 0; i < num_outputs; ++i) {
const auto & sequence = sequences[i];
@@ -571,7 +563,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_free_non_running_reque
std::vector<SequenceGroup::Ptr>::iterator requests_iterator = m_requests.begin();
while (requests_iterator != m_requests.end()) {
const auto& request = *requests_iterator;
if (request->has_finished() || request->handle_dropped()) {
if (request->has_finished() || request->handle_stopped() || request->handle_cancelled()) {
for (const auto& sequence: request->get_sequences()) {
if (m_scheduler->has_block_table(sequence->get_id())) {
m_scheduler->free_sequence(sequence->get_id());
@@ -589,7 +581,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_notify_requests_droppe
// Notify the last time by pushing empty output
// This causes read() to unblock by adding anything to the queue
for (SequenceGroup::Ptr& request : m_requests) {
if (request->handle_dropped())
if (request->handle_stopped() || request->handle_cancelled())
request->push_empty_outputs();
}
}
Expand Down