fix tests and whisper callback
sbalandi committed Feb 5, 2025
1 parent 09e4755 commit e75b81a
Showing 11 changed files with 79 additions and 26 deletions.
1 change: 0 additions & 1 deletion src/cpp/include/openvino/genai/generation_handle.hpp
@@ -103,7 +103,6 @@ GenerationHandleImpl {

void cancel();

GenerationOutputs back();
// Reads result of a generation for single iteration
GenerationOutputs read();
// Reads all generated tokens for all sequences
2 changes: 1 addition & 1 deletion src/cpp/include/openvino/genai/whisper_pipeline.hpp
@@ -40,7 +40,7 @@ class OPENVINO_GENAI_EXPORTS ChunkStreamerBase : public StreamerBase {

// Return flag corresponds to whether generation should be stopped: false means continue generation, true means stop.
using ChunkStreamerVariant =
std::variant<std::function<bool(std::string)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;
std::variant<std::function<bool(std::string)>, std::function<StreamingStatus(std::string)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;

struct OPENVINO_GENAI_EXPORTS WhisperRawPerfMetrics {
/** @brief Duration for each features extraction call */
2 changes: 1 addition & 1 deletion src/cpp/src/whisper/streamer.hpp
@@ -16,7 +16,7 @@ class ChunkTextCallbackStreamer : private TextCallbackStreamer, public ChunkStre
StreamingStatus write_chunk(std::vector<int64_t> tokens) override;
void end() override;

ChunkTextCallbackStreamer(const Tokenizer& tokenizer, std::function<bool(std::string)> callback)
ChunkTextCallbackStreamer(const Tokenizer& tokenizer, std::function<ov::genai::CallbackTypeVariant(std::string)> callback)
: TextCallbackStreamer(tokenizer, callback){};
};

2 changes: 1 addition & 1 deletion src/cpp/src/whisper/whisper.cpp
@@ -331,7 +331,7 @@ WhisperGenerateResult whisper_generate(const ov::genai::WhisperGenerationConfig&
extracted_segments.non_timestamp_tokens.begin(),
extracted_segments.non_timestamp_tokens.end());

if (streamer && streamer->put_chunk(extracted_segments.non_timestamp_tokens)) {
if (streamer && streamer->write_chunk(extracted_segments.non_timestamp_tokens) != ov::genai::StreamingStatus::RUNNING) {
cancelled = true;
break;
}
6 changes: 6 additions & 0 deletions src/cpp/src/whisper_pipeline.cpp
@@ -37,6 +37,8 @@ ov::genai::ChunkStreamerVariant get_chunk_streamer_from_map(const ov::AnyMap& co
streamer = any_val.as<std::shared_ptr<ov::genai::ChunkStreamerBase>>();
} else if (any_val.is<std::function<bool(std::string)>>()) {
streamer = any_val.as<std::function<bool(std::string)>>();
} else if (any_val.is<std::function<ov::genai::StreamingStatus(std::string)>>()) {
streamer = any_val.as<std::function<ov::genai::StreamingStatus(std::string)>>();
}
}
return streamer;
@@ -88,6 +90,8 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<ChunkTextCallbackStreamer>(m_tokenizer, *callback);
} else if (auto callback = std::get_if<std::function<StreamingStatus(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<ChunkTextCallbackStreamer>(m_tokenizer, *callback);
}

auto [context_tokens, tokenization_duration_microseconds] = prepare_context_tokens(config, m_tokenizer);
@@ -145,6 +149,8 @@ class WhisperPipeline::WhisperPipelineStatefulImpl : public WhisperPipeline::Whi
std::pair<std::string, Any> streamer(ChunkStreamerVariant func) {
if (auto streamer_obj = std::get_if<std::shared_ptr<ChunkStreamerBase>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::shared_ptr<ChunkStreamerBase>>(*streamer_obj)};
} else if (auto streamer_obj = std::get_if<std::function<StreamingStatus(std::string)>>(&func)) {
return {utils::STREAMER_ARG_NAME, Any::make<std::function<StreamingStatus(std::string)>>(*streamer_obj)};
} else {
auto callback = std::get<std::function<bool(std::string)>>(func);
return {utils::STREAMER_ARG_NAME, Any::make<std::function<bool(std::string)>>(callback)};
2 changes: 2 additions & 0 deletions src/cpp/src/whisper_pipeline_static.cpp
@@ -579,6 +579,8 @@ WhisperDecodedResults WhisperPipeline::StaticWhisperPipeline::generate(
streamer_ptr = *streamer_obj;
} else if (auto callback = std::get_if<std::function<bool(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<ChunkTextCallbackStreamer>(m_tokenizer, *callback);
} else if (auto callback = std::get_if<std::function<StreamingStatus(std::string)>>(&streamer)) {
streamer_ptr = std::make_shared<ChunkTextCallbackStreamer>(m_tokenizer, *callback);
}

size_t max_new_tokens = config.get_max_new_tokens();
12 changes: 10 additions & 2 deletions src/python/openvino_genai/py_openvino_genai.pyi
@@ -360,7 +360,15 @@ class ChunkStreamerBase:
"""
def put_chunk(self, tokens: list[int]) -> bool:
"""
Put is called every time new token chunk is generated. Returns a bool flag to indicate whether generation should be stopped, if return true generation stops
put_chunk is called every time new token chunk is generated. Returns a bool flag to indicate whether generation should be stopped, if return true generation stops
"""
def write(self, token: int) -> StreamingStatus:
"""
Write is called every time new token is generated. Returns a StreamingStatus flag to indicate whether generation should be stopped
"""
def write_chunk(self, tokens: list[int]) -> StreamingStatus:
"""
write_chunk is called every time new token chunk is generated. Returns a StreamingStatus flag to indicate whether generation should be stopped
"""
class ContinuousBatchingPipeline:
"""
@@ -2175,7 +2183,7 @@ class WhisperPipeline:
models_path (os.PathLike): Path to the model file.
device (str): Device to run the model on (e.g., CPU, GPU).
"""
def generate(self, raw_speech_input: list[float], generation_config: WhisperGenerationConfig | None = None, streamer: typing.Callable[[str], bool] | ChunkStreamerBase | None = None, **kwargs) -> WhisperDecodedResults:
def generate(self, raw_speech_input: list[float], generation_config: WhisperGenerationConfig | None = None, streamer: typing.Callable[[str], int | None] | ChunkStreamerBase | None = None, **kwargs) -> WhisperDecodedResults:
"""
High level generate that receives raw speech as a vector of floats and returns decoded output.
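For readers of the Python-facing change above, here is a minimal usage sketch (not part of this commit) of the new callback form: a callable returning StreamingStatus passed as the streamer argument of WhisperPipeline.generate. The "whisper-base" model directory and the silent raw_speech buffer are placeholders invented for the example; only WhisperPipeline, StreamingStatus, and the streamer parameter come from this diff.

import openvino_genai as ov_genai

def on_text(subword: str) -> ov_genai.StreamingStatus:
    # Print each decoded text chunk as it arrives. RUNNING continues generation;
    # STOP or CANCEL end it early under the new StreamingStatus-based contract.
    print(subword, end="", flush=True)
    return ov_genai.StreamingStatus.RUNNING

pipe = ov_genai.WhisperPipeline("whisper-base", "CPU")  # placeholder model dir
raw_speech = [0.0] * 16000  # placeholder: one second of 16 kHz silence
result = pipe.generate(raw_speech, streamer=on_text)
print(result)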
1 change: 0 additions & 1 deletion src/python/py_utils.cpp
@@ -339,7 +339,6 @@ ov::genai::StreamerVariant pystreamer_to_streamer(const PyBindStreamerVariant& p
streamer = callback_wrapped;
},
[&streamer](std::shared_ptr<StreamerBase> streamer_cls){
std::cout << "streamer_cls " << std::endl;
streamer = streamer_cls;
},
[](std::monostate none){ /*streamer is already a monostate */ }
62 changes: 49 additions & 13 deletions src/python/py_whisper_pipeline.cpp
@@ -22,6 +22,7 @@ using ov::genai::OptionalWhisperGenerationConfig;
using ov::genai::PerfMetrics;
using ov::genai::RawSpeechInput;
using ov::genai::StreamerBase;
using ov::genai::StreamingStatus;
using ov::genai::StreamerVariant;
using ov::genai::Tokenizer;
using ov::genai::WhisperDecodedResultChunk;
@@ -31,7 +32,7 @@ using ov::genai::WhisperPerfMetrics;
using ov::genai::WhisperPipeline;
using ov::genai::WhisperRawPerfMetrics;
using PyBindChunkStreamerVariant =
std::variant<std::function<bool(py::str)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;
std::variant<std::function<std::optional<uint16_t>(py::str)>, std::shared_ptr<ChunkStreamerBase>, std::monostate>;

namespace pyutils = ov::genai::pybind::utils;

Expand Down Expand Up @@ -227,17 +228,31 @@ OptionalWhisperGenerationConfig update_whisper_config_from_kwargs(const Optional

class ConstructableChunkStreamer : public ChunkStreamerBase {
bool put(int64_t token) override {
PYBIND11_OVERRIDE_PURE(bool, // Return type
ChunkStreamerBase, // Parent class
put, // Name of function in C++ (must match Python name)
token // Argument(s)
PYBIND11_OVERRIDE(bool, // Return type
ChunkStreamerBase, // Parent class
put, // Name of function in C++ (must match Python name)
token // Argument(s)
);
}
StreamingStatus write(int64_t token) override {
PYBIND11_OVERRIDE(StreamingStatus, // Return type
ChunkStreamerBase, // Parent class
write, // Name of function in C++ (must match Python name)
token // Argument(s)
);
}
bool put_chunk(std::vector<int64_t> tokens) override {
PYBIND11_OVERRIDE_PURE(bool, // Return type
ChunkStreamerBase, // Parent class
put_chunk, // Name of function in C++ (must match Python name)
tokens // Argument(s)
PYBIND11_OVERRIDE(bool, // Return type
ChunkStreamerBase, // Parent class
put_chunk, // Name of function in C++ (must match Python name)
tokens // Argument(s)
);
}
StreamingStatus write_chunk(std::vector<int64_t> tokens) override {
PYBIND11_OVERRIDE(StreamingStatus, // Return type
ChunkStreamerBase, // Parent class
write_chunk, // Name of function in C++ (must match Python name)
tokens // Argument(s)
);
}
void end() override {
Expand All @@ -247,13 +262,24 @@ class ConstructableChunkStreamer : public ChunkStreamerBase {

ChunkStreamerVariant pystreamer_to_chunk_streamer(const PyBindChunkStreamerVariant& py_streamer) {
return std::visit(
pyutils::overloaded{[](const std::function<bool(py::str)>& py_callback) {
pyutils::overloaded{[](const std::function<std::optional<uint16_t>(py::str)>& py_callback) {
// Wrap python streamer with manual utf-8 decoding. Do not rely
// on pybind automatic decoding since it raises exceptions on incomplete
// strings.
return static_cast<ChunkStreamerVariant>([py_callback](std::string subword) -> bool {
return static_cast<ChunkStreamerVariant>([py_callback](std::string subword) -> ov::genai::StreamingStatus {
auto py_str = PyUnicode_DecodeUTF8(subword.data(), subword.length(), "replace");
return py_callback(py::reinterpret_borrow<py::str>(py_str));
std::optional<uint16_t> callback_output = py_callback(py::reinterpret_borrow<py::str>(py_str));
auto result = StreamingStatus::RUNNING;
if (callback_output.has_value()) {
if (*callback_output == (uint16_t)StreamingStatus::RUNNING) {
result = StreamingStatus::RUNNING;
} else if (*callback_output == (uint16_t)StreamingStatus::CANCEL) {
result = StreamingStatus::CANCEL;
} else {
result = StreamingStatus::STOP;
}
}
return result;
});
},
[](std::shared_ptr<ChunkStreamerBase> streamer_cls) {
@@ -297,11 +323,21 @@ void init_whisper_pipeline(py::module_& m) {
"Put is called every time new token is generated. Returns a bool flag to indicate whether generation "
"should be stopped, if return true generation stops",
py::arg("token"))
.def("write",
&ChunkStreamerBase::write,
"Write is called every time new token is generated. Returns a StreamingStatus flag to indicate whether generation "
"should be stopped",
py::arg("token"))
.def("put_chunk",
&ChunkStreamerBase::put_chunk,
"Put is called every time new token chunk is generated. Returns a bool flag to indicate whether "
"put_chunk is called every time new token chunk is generated. Returns a bool flag to indicate whether "
"generation should be stopped, if return true generation stops",
py::arg("tokens"))
.def("write_chunk",
&ChunkStreamerBase::write_chunk,
"write_chunk is called every time new token chunk is generated. Returns a StreamingStatus flag to indicate whether "
"generation should be stopped",
py::arg("tokens"))
.def("end",
&ChunkStreamerBase::end,
"End is called at the end of generation. It can be used to flush cache if your own streamer has one");
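As a hedged illustration of the class-based path bound above (again, not part of this commit), a Python subclass of ChunkStreamerBase can override the new write_chunk rather than the older bool-returning put_chunk; the token-budget logic below is invented purely for the example.

import openvino_genai as ov_genai

class CountingStreamer(ov_genai.ChunkStreamerBase):
    # Stops transcription once a token budget is exhausted.
    def __init__(self, token_limit: int):
        super().__init__()
        self.seen = 0
        self.token_limit = token_limit

    def write_chunk(self, tokens: list[int]) -> ov_genai.StreamingStatus:
        # Called with each newly decoded chunk of token ids.
        self.seen += len(tokens)
        if self.seen >= self.token_limit:
            return ov_genai.StreamingStatus.STOP
        return ov_genai.StreamingStatus.RUNNING

    def end(self):
        # Flush any remaining state at the end of generation.
        print(f"\nstreamed {self.seen} tokens")

# Usage, with pipe and raw_speech prepared as in the earlier sketch:
# pipe.generate(raw_speech, streamer=CountingStreamer(token_limit=128))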
13 changes: 8 additions & 5 deletions tests/python_tests/common.py
@@ -194,12 +194,15 @@ def convert_to_hf(
kwargs['bos_token_id'] = default_generation_config.bos_token_id
kwargs['pad_token_id'] = default_generation_config.pad_token_id

if len(generation_config.stop_token_ids) > 0:
kwargs['eos_token_id'] = list(generation_config.stop_token_ids)
elif generation_config.eos_token_id != -1:
kwargs['eos_token_id'] = generation_config.eos_token_id
if (generation_config.ignore_eos):
kwargs['eos_token_id'] = []
else:
kwargs['eos_token_id'] = default_generation_config.eos_token_id
if len(generation_config.stop_token_ids) > 0:
kwargs['eos_token_id'] = list(generation_config.stop_token_ids)
elif generation_config.eos_token_id != -1:
kwargs['eos_token_id'] = generation_config.eos_token_id
else:
kwargs['eos_token_id'] = default_generation_config.eos_token_id

# copy penalties
kwargs['repetition_penalty'] = generation_config.repetition_penalty
2 changes: 1 addition & 1 deletion tests/python_tests/test_llm_pipeline.py
@@ -277,7 +277,7 @@ def test_chat_scenario_callback_cancel(model_descr):
'What was my first question?'
]

generation_config_kwargs = dict(max_new_tokens=20, ignore_eos=True)
generation_config_kwargs = dict(max_new_tokens=20)

chat_history_hf = []
chat_history_ov = []
