Support left and right strip for Added token (#887)
* Support left and right strip for Added token

* update the audio size data type

* quick fix

* add the unit test on data type verification

* fix the unit test

* Update comment and remove BOM in header
wenbingl authored Feb 7, 2025
1 parent c24b7ba commit 48176ac
Showing 13 changed files with 114 additions and 58 deletions.
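
For context, the lstrip/rstrip flags on an added token control whether whitespace adjacent to the token is swallowed when the input is split. Below is a minimal standalone sketch of the intended behavior (illustrative only, not code from this commit; the helper names StripTrailingWs and StripLeadingWs are made up):

// Sketch: what lstrip_/rstrip_ on an added token mean for the neighbouring text segments.
#include <cctype>
#include <iostream>
#include <string>

// Hypothetical helpers for illustration.
static std::string StripTrailingWs(const std::string& s) {
  size_t end = s.size();
  while (end > 0 && std::isspace(static_cast<unsigned char>(s[end - 1]))) --end;
  return s.substr(0, end);
}

static std::string StripLeadingWs(const std::string& s) {
  size_t begin = 0;
  while (begin < s.size() && std::isspace(static_cast<unsigned char>(s[begin]))) ++begin;
  return s.substr(begin);
}

int main() {
  // Text segments surrounding the added token "<|assistant|>".
  std::string prev = "answer: ";   // segment before the token
  std::string next = "  hello";    // segment after the token

  // lstrip_ == true: whitespace to the left of the token is removed,
  // i.e. the previous segment loses its trailing whitespace.
  std::cout << '[' << StripTrailingWs(prev) << "]\n";  // prints [answer:]

  // rstrip_ == true: whitespace to the right of the token is removed,
  // i.e. the next segment loses its leading whitespace.
  std::cout << '[' << StripLeadingWs(next) << "]\n";   // prints [hello]
  return 0;
}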
11 changes: 7 additions & 4 deletions operators/tokenizer/bpe_kernels.cc
@@ -261,7 +261,7 @@ std::vector<int64_t> KernelBpeTokenizer::Tokenize(ustring& input,
}

// Parse input
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input, added_tokens_);
bpe::PreTokenizerWithRegEx reg_splitter;
// NOTE: the pattern was already validated when loading the json file,
// so it is safe to ignore the return value here.
@@ -399,7 +399,7 @@ std::vector<int64_t> KernelBpeTokenizer::SpmTokenize(ustring& input,

size_t max_length = static_cast<size_t>(max_length_i64);
// Parse input
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input);
auto special_token_split_res = bbpe_tokenizer_->SplitByAddedAndSpecial(input, added_tokens_);
bool add_dummy_prefix = bpe_conf_.get().add_dummy_prefix_;

for (auto& seg_id : special_token_split_res) {
@@ -677,11 +677,13 @@ void JsonFastTokenizer::UpdateTokenizer(const TokenJsonConfig& config, const jso
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
added_tokens_.emplace_back(TokenJsonConfig::ParseAddedToken(token));
auto tok_extended = TokenJsonConfig::ParseAddedToken(token);
added_tokens_.emplace(ustring(tok_extended.content_), tok_extended);
}
}

for (const auto& added_token : added_tokens_) {
// iterate over the added_tokens_ map and set the special token ids
for (const auto& [key, added_token] : added_tokens_) {
if (added_token.content_ == config.bos_token_) {
bos_token_id_ = added_token.id_;
} else if (added_token.content_ == config.eos_token_) {
@@ -692,6 +694,7 @@ }
}

bbpe_tokenizer_->LoadAddedTokens(added_tokens_);

add_bos_token_ = config.add_bos_token_;
add_eos_token_ = config.add_eos_token_;

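
The change above switches added_tokens_ from a std::vector<AddedToken> to a map keyed by token content, so the bos/eos/pad resolution and later content lookups become plain map operations. A reduced sketch of the idea (simplified stand-in types, not the actual ort_extensions definitions):

#include <cstdint>
#include <string>
#include <unordered_map>

struct AddedToken {            // reduced stand-in for ort_extensions::AddedToken
  uint32_t id_{};
  std::u32string content_;
  bool lstrip_{false};
  bool rstrip_{false};
  bool special_{false};
};

using AddedTokenMap = std::unordered_map<std::u32string, AddedToken>;

int main() {
  AddedTokenMap added_tokens;
  added_tokens.emplace(U"<|end|>", AddedToken{32007, U"<|end|>", false, false, true});

  // Content-keyed lookup replaces a linear scan over a vector of tokens.
  auto it = added_tokens.find(U"<|end|>");
  return (it != added_tokens.end() && it->second.special_) ? 0 : 1;
}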
3 changes: 2 additions & 1 deletion operators/tokenizer/bpe_kernels.h
@@ -25,6 +25,7 @@ struct BpeModelConf {

struct KernelBpeTokenizer {
using json = nlohmann::json;
using AddedTokenMap = ort_extensions::AddedTokenMap;
KernelBpeTokenizer(const BpeModelConf& conf);
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info);

@@ -53,6 +54,7 @@ struct KernelBpeTokenizer {

protected:
std::string model_name_;
AddedTokenMap added_tokens_;
std::reference_wrapper<BpeModelConf const> bpe_conf_;
std::unique_ptr<ort_extensions::BpeModel> bbpe_tokenizer_;

@@ -131,5 +133,4 @@ class JsonFastTokenizer : public KernelBpeTokenizer {
void UpdateTokenizer(const ort_extensions::TokenJsonConfig& config, const json& tok_json);

BpeModelConf json_conf_;
std::vector<ort_extensions::AddedToken> added_tokens_;
};
2 changes: 1 addition & 1 deletion operators/tokenizer/bpe_streaming.hpp
@@ -26,7 +26,7 @@ class BpeStreamingDecoder : public KernelBpeDecoder {
spm_model_ = encoder.IsSpmModel();

const auto& a_toks = encoder.GetAddedTokens();
for (const auto& tok : a_toks) {
for (const auto& [key, tok] : a_toks) {
added_tokens_[tok.id_] = tok.content_;
if (tok.special_) {
all_special_ids_.insert(tok.id_);
51 changes: 46 additions & 5 deletions operators/tokenizer/bpe_tokenizer_model.hpp
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
@@ -310,8 +310,8 @@ class BpeModel {
return {};
}

void LoadAddedTokens(const std::vector<AddedToken>& added_tokens) {
for (const auto& token : added_tokens) {
void LoadAddedTokens(const AddedTokenMap& added_tokens) {
for (const auto& [key, token] : added_tokens) {
added_tokens_.Add(ustring(token.content_), 0, token.id_);
}
}
@@ -320,13 +320,54 @@

// REF:
// https://github.com/huggingface/transformers/blob/c9e72f55b2dc4b9be4edb986dce0552582b328f2/src/transformers/tokenization_utils.py#L52
bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input) const {
bpe::TokenPairs SplitByAddedAndSpecial(const ustring& input, const AddedTokenMap& t_map) const {
static const std::set<char32_t> ws_chars = {U' ', U'\n', U'\r', U'\t'};
// split by added tokens
bpe::TokenPairs added_result;
bpe::TokenPairs final_result;
added_tokens_.Split(input, added_result);
for (const auto& [token, id] : added_result) {

for (size_t n = 0; n < added_result.size(); ++n) {
auto& [token, id] = added_result[n];
bool has_left = n > 0;
bool has_right = n < added_result.size() - 1;

if (id != bpe::kInvalidTokenId) {
if (has_left || has_right) {
auto iter_tok_extend = t_map.find(std::u32string(token));
if (iter_tok_extend != t_map.end()) {
if (has_right && iter_tok_extend->second.rstrip_) {
auto& [next_token, next_id] = added_result[n + 1];
// rstrip removes the whitespace to the right of this token, which is equivalent to
// removing the leading whitespace from the next (non-added) segment
if (next_id == bpe::kInvalidTokenId) {
final_result.emplace_back(token, id);
size_t pos = 0;
while (pos < next_token.size() && ws_chars.count(next_token[pos])) {
pos++;
}
auto stripped_token = next_token.substr(pos);
final_result.emplace_back(stripped_token, next_id);
n += 1;
continue;
}
}
if (has_left && iter_tok_extend->second.lstrip_) {
auto& [prev_token, prev_id] = added_result[n - 1];
// lstrip removes the whitespace to the left of this token, which is equivalent to
// removing the trailing whitespace from the previous (non-added) segment
if (prev_id == bpe::kInvalidTokenId) {
size_t pos = prev_token.size();
while (pos > 0 && ws_chars.count(prev_token[pos - 1])) {
pos--;
}
auto stripped_token = prev_token.substr(0, pos);
final_result.back().first = stripped_token;
final_result.emplace_back(token, id);
continue;
}
}
}
}
// if no additional processing is needed, just add the token to the final result
final_result.emplace_back(token, id);
} else {
auto special_result = special_tokens_.SplitBySpecialTokens(token);
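
A standalone sketch of the rstrip_ branch added above, with the types reduced to standard containers: stripping the leading whitespace of the segment that follows an added token.

#include <iostream>
#include <set>
#include <string>

int main() {
  static const std::set<char32_t> ws_chars = {U' ', U'\n', U'\r', U'\t'};
  std::u32string next_token = U"  world";   // the plain-text segment after an added token

  size_t pos = 0;
  while (pos < next_token.size() && ws_chars.count(next_token[pos])) {
    ++pos;
  }
  std::u32string stripped = next_token.substr(pos);  // U"world"

  std::cout << stripped.size() << '\n';  // 5 code points remain after stripping
  return 0;
}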
3 changes: 3 additions & 0 deletions operators/tokenizer/tokenizer_common.h
@@ -4,6 +4,7 @@
#pragma once

#include <string>
#include <string_view>

#include "ortx_tokenizer.h"
#include "ext_status.h"
@@ -34,5 +35,7 @@ struct TokenizerDecodingState {
std::string incomplete_utf8_;
};

using AddedTokenMap = std::unordered_map<std::u32string, AddedToken>;

constexpr std::string_view spm_escaped_space = "\xE2\x96\x81";
} // namespace ort_extensions
8 changes: 5 additions & 3 deletions operators/tokenizer/tokenizer_jsconfig.hpp
@@ -4,6 +4,7 @@
#pragma once

#include <sstream>
#include <string_view>
#include "file_sys.h"
#include "nlohmann/json.hpp"

@@ -89,7 +90,6 @@ class TokenJsonConfig final {
parse_token(json_config, "eos_token", eos_token_);
parse_token(json_config, "unk_token", unk_token_);


auto pad_iter = json_config.find("pad_token");
if (pad_iter != json_config.end() && pad_iter->is_string()) {
pad_token_ = json_config.value("pad_token", "");
@@ -241,7 +241,7 @@ class TokenJsonConfig final {
std::string unk_token_;
std::string pad_token_;

std::vector<ort_extensions::AddedToken> added_tokens_;
AddedTokenMap added_tokens_;

static AddedToken ParseAddedToken(const json& token) {
AddedToken added_token;
@@ -275,7 +275,9 @@ class TokenJsonConfig final {
auto added_tokens = tok_json.find("added_tokens");
if (added_tokens != tok_json.end()) {
for (const auto& token : *added_tokens) {
added_tokens_.emplace_back(ParseAddedToken(token));
auto tok_extended = ParseAddedToken(token);
// insert the token into the unordered_map
added_tokens_.emplace(ustring(tok_extended.content_), tok_extended);
}
}
}
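
ParseAddedToken itself is not shown in this diff; the fields it reads follow the Hugging Face tokenizer.json "added_tokens" schema. A small sketch of parsing one such entry with nlohmann::json (the field values here are made up):

#include <iostream>
#include <string>
#include "nlohmann/json.hpp"

int main() {
  auto token = nlohmann::json::parse(R"({
    "id": 32007, "content": "<|end|>",
    "lstrip": false, "rstrip": true, "special": true
  })");

  int id = token.value("id", 0);
  std::string content = token.value("content", "");
  bool rstrip = token.value("rstrip", false);

  std::cout << id << " " << content << " rstrip=" << rstrip << "\n";
  return 0;
}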
9 changes: 0 additions & 9 deletions shared/api/image_transforms_phi_4.hpp
@@ -540,15 +540,8 @@ class Phi4VisionProcessor {
// Change image_size and image_mask data types to floating point to align with the original Python code
static OrtxStatus AlignOutputs(std::vector<TensorPtr>& img_result) {
assert(img_result.size() == 4);
auto image_sizes = std::move(img_result[1]);
auto image_attention_mask = std::move(img_result[2]);

auto new_image_sizes = std::make_unique<ortc::Tensor<float>>(&CppAllocator::Instance());
auto image_sizes_data = new_image_sizes->Allocate(image_sizes->Shape());
auto image_sizes_raw = reinterpret_cast<const int64_t*>(image_sizes->DataRaw());
for (int64_t i = 0; i < image_sizes->NumberOfElement(); ++i) {
image_sizes_data[i] = static_cast<float>(image_sizes_raw[i]);
}
auto new_image_attention_mask = std::make_unique<ortc::Tensor<float>>(&CppAllocator::Instance());
auto image_attention_mask_data = new_image_attention_mask->Allocate(image_attention_mask->Shape());
auto image_attention_mask_raw = reinterpret_cast<const int64_t*>(image_attention_mask->DataRaw());
@@ -585,8 +578,6 @@ class Phi4VisionProcessor {
}
}


img_result[1].reset(new_image_sizes.release());
img_result[2].reset(new_image_attention_mask.release());
return {};
}
4 changes: 2 additions & 2 deletions shared/api/runner.hpp
@@ -407,7 +407,7 @@ class ExecutionPlan {
return {};
}

OrtxStatus Excute(ortc::IAllocator* allocator, TensorArgs& input, TensorLookupTable& ts_lookup_table) const {
OrtxStatus Execute(ortc::IAllocator* allocator, TensorArgs& input, TensorLookupTable& ts_lookup_table) const {
for (auto& op : operations_) {
// add tensor references
auto spec = op->GetInputSpec();
@@ -501,7 +501,7 @@ class OrtxRunner {
OrtxStatus Run(std::vector<TensorArgs>& input_seq, std::vector<TensorArgs>& output_seq) {
for (size_t i = 0; i < input_seq.size(); ++i) {
auto& input = *(input_seq.begin() + i);
auto status = plan_.Excute(allocator_, input, tensor_lookup_table_);
auto status = plan_.Execute(allocator_, input, tensor_lookup_table_);
if (!status.IsOk()) {
return status;
}
10 changes: 7 additions & 3 deletions shared/api/speech_extractor.cc
@@ -96,9 +96,13 @@ OrtxStatus SpeechFeatureExtractor::DoCall(ort_extensions::span<AudioRawData> raw
OrtxStatus Phi4AudioEmbed::AlignOutputs(std::vector<TensorPtr>& audio_result) {
auto ts_embed_size = std::move(audio_result.back());
audio_result.pop_back();
auto new_ts_size = std::make_unique<ortc::Tensor<int64_t>>(&CppAllocator::Instance());
auto new_embed_size_data = new_ts_size->Allocate({ts_embed_size->Shape()[0]});
std::memcpy(new_embed_size_data, ts_embed_size->DataRaw(), ts_embed_size->SizeInBytes());
auto new_ts_size = std::make_unique<ortc::Tensor<float>>(&CppAllocator::Instance());
int64_t audio_count = ts_embed_size->Shape()[0];
auto new_embed_size_data = new_ts_size->Allocate({audio_count});
const int64_t* ts_embed_size_data = reinterpret_cast<const int64_t*>(ts_embed_size->DataRaw());
for (int64_t i = 0; i < audio_count; ++i) {
new_embed_size_data[i] = static_cast<float>(ts_embed_size_data[i]);
}
audio_result.emplace_back(std::move(new_ts_size));
return {};
}
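
The change above replaces a raw memcpy of int64 audio-size data into what is now a float tensor with an explicit per-element cast. A minimal illustration of the difference (standard containers in place of ortc::Tensor):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> embed_sizes = {138, 167, 168};
  std::vector<float> as_float(embed_sizes.size());

  // Correct: convert each value.
  for (size_t i = 0; i < embed_sizes.size(); ++i) {
    as_float[i] = static_cast<float>(embed_sizes[i]);
  }
  std::cout << as_float[0] << "\n";  // 138

  // Wrong: std::memcpy(as_float.data(), embed_sizes.data(), ...) would copy raw
  // int64 bits into float storage (and overrun the smaller float buffer).
  return 0;
}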
41 changes: 21 additions & 20 deletions test/pp_api_test/test_feature_extraction.cc
@@ -67,9 +67,10 @@ TEST(ExtractorTest, TestPhi4AudioFeatureExtraction) {
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&data), &shape, &num_dims);
ASSERT_EQ(num_dims, 1);
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3}));
ASSERT_EQ(std::vector<int64_t>(reinterpret_cast<const int64_t*>(data),
reinterpret_cast<const int64_t*>(data) + 3),
std::vector<int64_t>({138, 167, 168}));
const float* actual_output = reinterpret_cast<const float*>(data);
ASSERT_FLOAT_EQ(actual_output[0], 138.0f);
ASSERT_FLOAT_EQ(actual_output[1], 167.0f);
ASSERT_FLOAT_EQ(actual_output[2], 168.0f);
}

TEST(ExtractorTest, TestPhi4AudioFeatureExtraction8k) {
@@ -138,7 +139,7 @@ TEST(ExtractorTest, TestPhi4AudioOutput) {

// Define lambda for comparison
auto are_close = [](float a, float b, float rtol = 1e-03, float atol = 1e-02) -> bool {
return std::abs(a - b) <= atol || std::abs(a - b) <= rtol * std::abs(b);
return std::abs(a - b) <= atol || std::abs(a - b) <= rtol * std::abs(b);
};

size_t num_mismatched = 0;
@@ -147,23 +148,23 @@
size_t row_idx = 0;

while (std::getline(expected_audio_embed_output, line) && row_idx < num_rows) {
std::stringstream ss(line); // Stringstream to parse each line
std::string value_str;
size_t col_idx = 0;

while (std::getline(ss, value_str, ',') && col_idx < 10) { // Only read the first 10 columns
float expected_value = std::stof(value_str); // Convert string to float

// Compare values
const float* row_start = data + (row_idx * num_columns);
if (!are_close(row_start[col_idx], expected_value)) {
num_mismatched++; // Count mismatches
std::cout << "Mismatch at (" << row_idx << "," << col_idx << "): "
<< "Expected: " << expected_value << ", Got: " << row_start[col_idx] << std::endl;
}
col_idx++;
std::stringstream ss(line); // Stringstream to parse each line
std::string value_str;
size_t col_idx = 0;

while (std::getline(ss, value_str, ',') && col_idx < 10) { // Only read the first 10 columns
float expected_value = std::stof(value_str); // Convert string to float

// Compare values
const float* row_start = data + (row_idx * num_columns);
if (!are_close(row_start[col_idx], expected_value)) {
num_mismatched++; // Count mismatches
std::cout << "Mismatch at (" << row_idx << "," << col_idx << "): "
<< "Expected: " << expected_value << ", Got: " << row_start[col_idx] << std::endl;
}
row_idx++;
col_idx++;
}
row_idx++;
}

expected_audio_embed_output.close();
14 changes: 12 additions & 2 deletions test/pp_api_test/test_processor.cc
@@ -140,30 +140,40 @@ TEST(ProcessorTest, TestPhi4VisionProcessor) {
ASSERT_EQ(err, kOrtxOK);

OrtxObjectPtr<OrtxTensor> tensor;
// embedding data (float32)
err = OrtxTensorResultGetAt(result.get(), 0, tensor.ToBeAssigned());
ASSERT_EQ(err, kOrtxOK);
const float* data{};
const int64_t* int_data{};
const int64_t* shape{};
size_t num_dims;
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&data), &shape, &num_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3, 10, 3, 448, 448}));
EXPECT_TRUE((data[0] > -0.30f) && (data[0] < -0.29f));

// image sizes (int64_t)
err = OrtxTensorResultGetAt(result.get(), 1, tensor.ToBeAssigned());
ASSERT_EQ(err, kOrtxOK);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&data), &shape, &num_dims);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&int_data), &shape, &num_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3, 2}));
EXPECT_EQ(std::vector<int64_t>(int_data, int_data + 6),
std::vector<int64_t>({1344, 1344, 896, 1344, 448, 896}));

// mask data (float32)
err = OrtxTensorResultGetAt(result.get(), 2, tensor.ToBeAssigned());
ASSERT_EQ(err, kOrtxOK);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&data), &shape, &num_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3, 10, 32, 32}));
EXPECT_FLOAT_EQ(data[0], 1.0f);

// num tokens (int64_t)
err = OrtxTensorResultGetAt(result.get(), 3, tensor.ToBeAssigned());
ASSERT_EQ(err, kOrtxOK);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&data), &shape, &num_dims);
err = OrtxGetTensorData(tensor.get(), reinterpret_cast<const void**>(&int_data), &shape, &num_dims);
ASSERT_EQ(err, kOrtxOK);
ASSERT_EQ(std::vector<int64_t>(shape, shape + num_dims), std::vector<int64_t>({3, 1}));
EXPECT_EQ(std::vector<int64_t>(int_data, int_data + 3), std::vector<int64_t>({2625, 1841, 735}));
}
4 changes: 2 additions & 2 deletions test/pp_api_test/test_tokenizer.cc
@@ -225,11 +225,11 @@ TEST(OrtxTokenizerTest, Phi3Tokenizer) {
std::vector<std::string_view> input = {
"分析",
" こんにちは", // an extra space at the beginning
"<|user|>こんにちは。データ分析するにはなにをすればいい?<|end|><|assistant|>"};
"<|user|>\nこんにちは。データ分析するにはなにをすればいい?<|end|><|assistant|>"};
std::vector<extTokenId_t> EXPECTED_IDS_0 = {1, 29871, 30748, 233, 161, 147};
std::vector<extTokenId_t> EXPECTED_IDS_1 = {1, 259, 30589, 30389, 30353, 30644, 30449};
std::vector<extTokenId_t> EXPECTED_IDS_2 = {
1, 32010, 29871, 30589, 30389, 30353, 30644, 30449, 30267, 30597, 30185, 30369, 30748, 233, 161, 147,
1, 32010, 29871, 13, 30589, 30389, 30353, 30644, 30449, 30267, 30597, 30185, 30369, 30748, 233, 161, 147,
30427, 30332, 30353, 30449, 30371, 30353, 30396, 30427, 30553, 31254, 30298, 30298, 30882, 32007, 32001};

std::vector<std::vector<extTokenId_t>> token_ids;