Skip to content

Commit

Permalink
Fix C++ linter issues and audio sizes dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
kunal-vaishnavi committed Feb 26, 2025
1 parent 91cb451 commit f228cff
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6a
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;8ea1b6e0ece370af0121d081075e6ac20582b85c
onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;e441b4052ac41f164fb381ef318171a48a09c115
23 changes: 9 additions & 14 deletions src/models/multi_modal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,13 @@ int64_t GetNumAudioTokens(const std::vector<GeneratorParams::Input>& extra_input
assert(extra_inputs[i].tensor->ort_tensor_);
auto type_and_shape_info = extra_inputs[i].tensor->ort_tensor_->GetTensorTypeAndShapeInfo();
const auto element_count = type_and_shape_info->GetElementCount();
if (type_and_shape_info->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
const float* audio_sizes_data = extra_inputs[i].tensor->ort_tensor_->GetTensorData<float>();
if (type_and_shape_info->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
const int64_t* audio_sizes_data = extra_inputs[i].tensor->ort_tensor_->GetTensorData<int64_t>();
return std::accumulate(audio_sizes_data, audio_sizes_data + element_count, 0LL, [](int64_t a, float b) {
return a + static_cast<int64_t>(b + 0.5f);
});
} else if (type_and_shape_info->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
const uint16_t* audio_sizes_data = extra_inputs[i].tensor->ort_tensor_->GetTensorData<uint16_t>();
return std::accumulate(audio_sizes_data, audio_sizes_data + element_count, 0LL, [](int64_t a, uint16_t b) {
return a + static_cast<int64_t>(FastFloat16ToFloat32(b) + 0.5f);
});
} else {
throw std::runtime_error("Unsupported data type for audio_sizes tensor.");
throw std::runtime_error("Unsupported data type " + std::to_string(static_cast<int64_t>(type_and_shape_info->GetElementType())) + " for audio_sizes tensor. Only int64 is supported.");
}
}
}
Expand Down Expand Up @@ -180,12 +175,12 @@ MultiModalPipelineState::MultiModalPipelineState(const MultiModalLanguageModel&

if (vision_state_ != nullptr && model_.config_->model.vision.adapter_filename.has_value() && num_image_tokens_ > 0) {
const auto lora_adapter = (model_.config_->config_path / fs::path(*model_.config_->model.vision.adapter_filename));
std::string lora_adapter_str = lora_adapter.string(); // Returns UTF-8 encoded string on Windows
std::string lora_adapter_str = lora_adapter.string(); // Returns UTF-8 encoded string on Windows
adapters_->LoadAdapter(lora_adapter_str.c_str(), vision_adapter_name_);
decoder_state_->SetActiveAdapter(adapters_.get(), vision_adapter_name_);
} else if (speech_state_ != nullptr && model_.config_->model.speech.adapter_filename.has_value() && num_audio_tokens_ > 0) {
const auto lora_adapter = (model_.config_->config_path / fs::path(*model_.config_->model.speech.adapter_filename));
std::string lora_adapter_str = lora_adapter.string(); // Returns UTF-8 encoded string on Windows
std::string lora_adapter_str = lora_adapter.string(); // Returns UTF-8 encoded string on Windows
adapters_->LoadAdapter(lora_adapter_str.c_str(), speech_adapter_name_);
decoder_state_->SetActiveAdapter(adapters_.get(), speech_adapter_name_);
}
Expand All @@ -206,14 +201,14 @@ DeviceSpan<float> MultiModalPipelineState::Run(int current_length, DeviceSpan<in
decoder_state_->UpdateInputsOutputs(next_tokens, current_length, next_indices);

if (is_prompt_) {
if (num_image_tokens_ > 0 && vision_state_ ) {
if (num_image_tokens_ > 0 && vision_state_) {
vision_state_->Run(current_length, next_tokens, next_indices);
}
if (num_audio_tokens_ > 0 && speech_state_ ) {
if (num_audio_tokens_ > 0 && speech_state_) {
speech_state_->Run(current_length, next_tokens, next_indices);
}
if (vision_state_ ) embedding_state_->image_features_.ReuseFeaturesBuffer(vision_state_->image_features_);
if (speech_state_ ) embedding_state_->audio_features_.ReuseFeaturesBuffer(speech_state_->audio_features_);
if (vision_state_) embedding_state_->image_features_.ReuseFeaturesBuffer(vision_state_->image_features_);
if (speech_state_) embedding_state_->audio_features_.ReuseFeaturesBuffer(speech_state_->audio_features_);
embedding_state_->inputs_embeds_.ReuseEmbeddingsBuffer(decoder_state_->inputs_embeds_);
embedding_state_->Run(current_length, next_tokens, next_indices);

Expand Down

0 comments on commit f228cff

Please sign in to comment.