Implement Core functionality for Lora Adapters #679

Closed
wants to merge 72 commits into from
Changes from 23 commits
Commits
72 commits
28e86c6
Lora begins
yuslepukhin Jun 26, 2024
52fa32e
Rework span
yuslepukhin Jun 27, 2024
fb0670e
Lora begins
yuslepukhin Jun 26, 2024
beb630f
Start public API
yuslepukhin Jun 27, 2024
d006037
Merge branch 'yuslepukhin/implement_lora_adapters' of https://github.…
yuslepukhin Jun 27, 2024
5ad527b
Add param
yuslepukhin Jun 28, 2024
22d65cb
Add C++ API
yuslepukhin Jun 28, 2024
94e2c83
Add LoraManagement unit tests
yuslepukhin Jul 1, 2024
ad00a14
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 1, 2024
0a712a9
More tests
yuslepukhin Jul 1, 2024
11dc3f2
Add DeactiveAdapters
yuslepukhin Jul 2, 2024
7a98784
Add C and C++ API test
yuslepukhin Jul 2, 2024
db9b5fa
Adjust ExtraInputs.
yuslepukhin Jul 2, 2024
ed29925
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 3, 2024
dbdc369
Add thread-safety, setup for potential caching
yuslepukhin Jul 3, 2024
fa42f95
Added device copy to LoraAdapter
yuslepukhin Jul 3, 2024
a104b20
Add input addition verification
yuslepukhin Jul 3, 2024
7348901
Introduce utilities
yuslepukhin Jul 4, 2024
9164a96
Merge branch 'yuslepukhin/implement_lora_adapters' of https://github.…
yuslepukhin Jul 5, 2024
be1e6b0
Add automatic span construction from an array
yuslepukhin Jul 5, 2024
754d3e7
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 5, 2024
52b8aa5
Fix warnings
yuslepukhin Jul 5, 2024
1247461
Code issues
yuslepukhin Jul 5, 2024
40d939a
Flatbuffers begin
yuslepukhin Jul 10, 2024
89c00b3
Make flatbuffers tests run
yuslepukhin Jul 11, 2024
92bc527
Add python save_array_as_lora_parameter
yuslepukhin Jul 11, 2024
4e43b51
Add python test and readback in CXX
yuslepukhin Jul 11, 2024
24dc3c6
Add assert and cleanup
yuslepukhin Jul 12, 2024
297f894
Fix test issues
yuslepukhin Jul 12, 2024
f2b42bc
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 12, 2024
4513168
Address review comments
yuslepukhin Jul 12, 2024
a100cde
include mutex
yuslepukhin Jul 12, 2024
de5fab0
Build errors
yuslepukhin Jul 15, 2024
ea905fc
Address review comments
yuslepukhin Jul 15, 2024
c21e540
Introduce saving multiple params into the same fb file
yuslepukhin Jul 16, 2024
82a99de
Add conversion utility
yuslepukhin Jul 16, 2024
bc84e1d
Add import for lora helpers
yuslepukhin Jul 16, 2024
2fe6e6c
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 16, 2024
455ffdb
Merge branch 'yuslepukhin/implement_lora_adapters' into yuslepukhin/l…
yuslepukhin Jul 16, 2024
abf61b4
Add a tool to modify genai config and add adapters section
yuslepukhin Jul 16, 2024
5e58286
Work on config driven load
yuslepukhin Jul 17, 2024
41e3136
Fix up CUDA build
yuslepukhin Jul 18, 2024
eeba0c4
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 18, 2024
9cdd53f
Merge branch 'yuslepukhin/lora_params_ondisk' into yuslepukhin/implem…
yuslepukhin Jul 18, 2024
7905942
Fix merge
yuslepukhin Jul 18, 2024
d0f1346
Address security warnings
yuslepukhin Jul 18, 2024
07617c0
Fix stray include
yuslepukhin Jul 18, 2024
7dd7274
Address build shortcomings
yuslepukhin Jul 18, 2024
2992140
Clang format
yuslepukhin Jul 18, 2024
c77fb44
Clang format
yuslepukhin Jul 18, 2024
287693f
Remove redundant methods
yuslepukhin Jul 19, 2024
f8de77e
Run test coverage, remove some dead code. Cover base case.
yuslepukhin Jul 19, 2024
79bad70
Add missing checks
yuslepukhin Jul 19, 2024
08059ff
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 19, 2024
6a6cc3b
Address review comments
yuslepukhin Jul 22, 2024
e0088e3
Address build issues, refresh the test model
yuslepukhin Jul 22, 2024
4252975
Adjust file paths
yuslepukhin Jul 23, 2024
8814b47
Make it work end to end
yuslepukhin Jul 23, 2024
175ea54
Make FlatBuffers linkage public
yuslepukhin Jul 23, 2024
0bc4e3f
Move new test subfolder
yuslepukhin Jul 23, 2024
a77ab66
Add model
yuslepukhin Jul 23, 2024
ee6dcb9
Add fp16 model
yuslepukhin Jul 23, 2024
67bebd2
Adjust src
yuslepukhin Jul 23, 2024
57d4ef1
Adjust for ARM
yuslepukhin Jul 23, 2024
2318427
Create separate model copy and config to run on DML
yuslepukhin Jul 23, 2024
b902b70
Disable DML
yuslepukhin Jul 24, 2024
33a3ec2
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 24, 2024
5dc65d7
Clang format
yuslepukhin Jul 24, 2024
c765a16
Address python related comments, correct faulty formatting for public…
yuslepukhin Jul 25, 2024
36f569c
Remove redundant linkage and includes
yuslepukhin Jul 25, 2024
3526f91
Rename python interface
yuslepukhin Jul 25, 2024
b936657
Correct function name
yuslepukhin Jul 25, 2024
2 changes: 1 addition & 1 deletion .clang-format
@@ -4,7 +4,7 @@ BasedOnStyle: Google

# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained.
# Developers are responsible for adhering to the 120 character maximum.
ColumnLimit: 0
ColumnLimit: 120
SortIncludes: Never
DerivePointerAlignment: false

12 changes: 8 additions & 4 deletions src/models/captured_graph_pool.cpp
@@ -26,7 +26,8 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
// Multiple generators can reserve graphs in parallel, so we need to make it thread safe
std::unique_lock lock(captured_graph_mutex_);

auto key = std::make_unique<CapturedGraphKey>(params.max_batch_size, params.search.max_length, params.search.num_beams, params.extra_inputs);
auto key = std::make_unique<CapturedGraphKey>(params.max_batch_size, params.search.max_length,
params.search.num_beams, params.extra_inputs);
auto& captured_graphs = captured_graphs_map_[*key];

// If no graphs are available, create a graph with a new ID
@@ -59,7 +60,8 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
new_captured_graph->sb_kv_caches_.reserve(layer_count * 2);

for (int i = 0; i < layer_count * 2; ++i) {
new_captured_graph->sb_kv_caches_.push_back(std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size));
new_captured_graph->sb_kv_caches_.push_back(
std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size));
}

// Create the static buffer for the position ids, if needed
@@ -74,7 +76,8 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
#if USE_DML
// DML currently needs an additional static buffer for the mask
if (model.device_type_ == DeviceType::DML) {
new_captured_graph->sb_attention_mask_next_ = std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
new_captured_graph->sb_attention_mask_next_ =
std::make_unique<StaticBuffer>(allocator_device_, max_beam_batch_size);
}
#endif
}
@@ -92,7 +95,8 @@ CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model,
// Create the extra inputs
for (const auto& extra_input : params.extra_inputs) {
auto first_dim = extra_input.tensor->ort_tensor_->GetTensorTypeAndShapeInfo()->GetShape()[0];
new_captured_graph->sb_extra_inputs_[extra_input.name] = std::make_unique<StaticBuffer>(allocator_device_, first_dim);
new_captured_graph->sb_extra_inputs_[extra_input.name] =
std::make_unique<StaticBuffer>(allocator_device_, first_dim);
}

// Create the input embeddings if needed
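
The code above reserves captured graphs from a pool keyed by the generation parameters (max batch size, max length, beam count, and the extra inputs), with the whole reservation path guarded by a mutex so multiple generators can reserve in parallel; the hunks themselves only rewrap those lines for the new 120-column limit. The following is a minimal standalone sketch of that keying pattern, not the actual CapturedGraphPool implementation; GraphKey, CapturedGraph, and GraphPool (and the string-based extra-input field) are simplified stand-ins introduced for illustration, and returning graphs to the pool is omitted.

#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <vector>

// Simplified stand-in for CapturedGraphKey: the extra inputs take part in the lookup key.
struct GraphKey {
  int max_batch_size;
  int max_length;
  int num_beams;
  std::vector<std::string> extra_input_names;

  bool operator<(const GraphKey& other) const {
    return std::tie(max_batch_size, max_length, num_beams, extra_input_names) <
           std::tie(other.max_batch_size, other.max_length, other.num_beams, other.extra_input_names);
  }
};

struct CapturedGraph {
  int id;
};

class GraphPool {
 public:
  // Multiple generators may reserve graphs in parallel, so the map is mutex-guarded.
  std::shared_ptr<CapturedGraph> Reserve(const GraphKey& key) {
    std::unique_lock lock(mutex_);
    auto& graphs = graphs_map_[key];
    if (graphs.empty()) {
      // No cached graph exists for this parameter combination: create one with a new id.
      return std::make_shared<CapturedGraph>(CapturedGraph{next_id_++});
    }
    // Otherwise hand out a previously captured graph for the same key.
    auto graph = graphs.back();
    graphs.pop_back();
    return graph;
  }

 private:
  std::mutex mutex_;
  std::map<GraphKey, std::vector<std::shared_ptr<CapturedGraph>>> graphs_map_;
  int next_id_ = 0;
};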
18 changes: 9 additions & 9 deletions src/models/debugging.cpp
@@ -79,47 +79,47 @@ void DumpValues(std::ostream& stream, ONNXTensorElementDataType type, const void
break;

case Ort::TypeToTensorType<int8_t>::type:
DumpSpan(stream, std::span<const int8_t>{reinterpret_cast<const int8_t*>(p_values_raw), count});
DumpSpan(stream, std::span<const int8_t>(reinterpret_cast<const int8_t*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<uint8_t>::type:
DumpSpan(stream, std::span<const uint8_t>{reinterpret_cast<const uint8_t*>(p_values_raw), count});
DumpSpan(stream, std::span<const uint8_t>(reinterpret_cast<const uint8_t*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<int16_t>::type:
DumpSpan(stream, std::span<const int16_t>{reinterpret_cast<const int16_t*>(p_values_raw), count});
DumpSpan(stream, std::span<const int16_t>(reinterpret_cast<const int16_t*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<uint16_t>::type:
DumpSpan(stream, std::span<const uint16_t>{reinterpret_cast<const uint16_t*>(p_values_raw), count});
DumpSpan(stream, std::span<const uint16_t>(reinterpret_cast<const uint16_t*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<int32_t>::type:
DumpSpan(stream, std::span<const int32_t>{reinterpret_cast<const int32_t*>(p_values_raw), count});
break;

case Ort::TypeToTensorType<uint32_t>::type:
DumpSpan(stream, std::span<const uint32_t>{reinterpret_cast<const uint32_t*>(p_values_raw), count});
DumpSpan(stream, std::span<const uint32_t>(reinterpret_cast<const uint32_t*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<int64_t>::type:
DumpSpan(stream, std::span<const int64_t>{reinterpret_cast<const int64_t*>(p_values_raw), count});
DumpSpan(stream, std::span<const int64_t>(reinterpret_cast<const int64_t*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<uint64_t>::type:
DumpSpan(stream, std::span<const uint64_t>{reinterpret_cast<const uint64_t*>(p_values_raw), count});
break;

case Ort::TypeToTensorType<Ort::Float16_t>::type:
DumpSpan(stream, std::span<const Ort::Float16_t>{reinterpret_cast<const Ort::Float16_t*>(p_values_raw), count});
DumpSpan(stream, std::span<const Ort::Float16_t>(reinterpret_cast<const Ort::Float16_t*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<float>::type:
DumpSpan(stream, std::span<const float>{reinterpret_cast<const float*>(p_values_raw), count});
DumpSpan(stream, std::span<const float>(reinterpret_cast<const float*>(p_values_raw), count));
break;

case Ort::TypeToTensorType<double>::type:
DumpSpan(stream, std::span<const double>{reinterpret_cast<const double*>(p_values_raw), count});
DumpSpan(stream, std::span<const double>(reinterpret_cast<const double*>(p_values_raw), count));
break;

default:
92 changes: 29 additions & 63 deletions src/models/extra_inputs.cpp
@@ -1,86 +1,52 @@
#include "../generators.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "lora_adapter.h"
#include "model.h"
#include "extra_inputs.h"
#include "kernels.h"

namespace Generators {

ExtraInputs::ExtraInputs(const Model& model, State& state)
: model_{model},
state_{state} {
extra_inputs_.reserve(state_.params_->extra_inputs.size());
ExtraInputs::ExtraInputs(const Model& model, State& state) : model_{model}, state_{state} {
auto& lora_management = model_.GetLoraAdapterManagement();
const auto total_inputs = state_.params_->extra_inputs.size() + lora_management.GetParamNum();
extra_input_names_.reserve(total_inputs);
extra_inputs_.reserve(total_inputs);

if (state_.GetCapturedGraphInfo()) {
owned_extra_inputs_.reserve(state_.params_->extra_inputs.size());

for (int i = 0; i < state_.params_->extra_inputs.size(); ++i) {
auto type_and_shape_info = state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorTypeAndShapeInfo();
const auto& input_name = state_.params_->extra_inputs[i].name;

sb_extra_inputs_.emplace(input_name, state_.GetCapturedGraphInfo()->sb_extra_inputs_.at(input_name).get());
owned_extra_inputs_.push_back(sb_extra_inputs_.at(input_name)->CreateTensorOnStaticBuffer(type_and_shape_info->GetShape(), type_and_shape_info->GetElementType()));
extra_inputs_.push_back(owned_extra_inputs_.back().get());
auto* sb_extra = state_.GetCapturedGraphInfo()->sb_extra_inputs_.at(input_name).get();
auto ort_value =
sb_extra->CreateTensorOnStaticBuffer(type_and_shape_info->GetShape(), type_and_shape_info->GetElementType());

// Copy to value created on top of the StaticBuffer
CopyToDevice(model_, *state_.params_->extra_inputs[i].tensor->ort_tensor_, *ort_value);

extra_input_names_.push_back(input_name);
extra_inputs_.push_back(std::move(ort_value));
}
} else {
// We don't use graph capture, so simply use the existing pointers
// We don't use graph capture
for (auto& extra_input : state_.params_->extra_inputs) {
extra_inputs_.push_back(extra_input.tensor->ort_tensor_.get());
extra_input_names_.push_back(extra_input.name);
auto ort_value = DuplicateOrtValue(*extra_input.tensor->ort_tensor_);
extra_inputs_.push_back(std::move(ort_value));
}
}
}

#pragma warning(push)
#pragma warning(disable : 4065) // switch statement contains 'default' but no 'case' labels
#pragma warning(disable : 4189) // local variable is initialized but not referenced
#pragma warning(disable : 4702) // unreachable code
// Add Lora Parameters
lora_management.OutputAdaptersParameters(std::back_inserter(extra_input_names_), std::back_inserter(extra_inputs_));
}

void ExtraInputs::Add() {
// Add extra user inputs
for (int i = 0; i < state_.params_->extra_inputs.size(); ++i) {
state_.input_names_.push_back(state_.params_->extra_inputs[i].name.c_str());
state_.inputs_.push_back(extra_inputs_[i]);
}

// Copy the data from the CPU-backed ORT value to the static buffers
for (int i = 0; i < sb_extra_inputs_.size(); ++i) {
auto type_and_shape_info = extra_inputs_[i]->GetTensorTypeAndShapeInfo();
auto shape = type_and_shape_info->GetShape();
auto element_count = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
auto copy_size_in_bytes = element_count * SizeOf(type_and_shape_info->GetElementType());

switch (model_.device_type_) {
#if USE_DML
case DeviceType::DML: {
ComPtr<ID3D12Resource> target_resource;
Ort::ThrowOnError(model_.GetOrtDmlApi()->GetD3D12ResourceFromAllocation(model_.allocator_device_, extra_inputs_[i]->GetTensorMutableRawData(), &target_resource));

auto source = std::span(state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorData<const uint8_t>(), copy_size_in_bytes);

model_.GetDmlUploadHeap()->BeginUploadToGpu(
target_resource.Get(),
0,
D3D12_RESOURCE_STATE_UNORDERED_ACCESS,
source);
} break;
#endif

#if USE_CUDA
case DeviceType::CUDA: {
cudaMemcpyAsync(
extra_inputs_[i]->GetTensorMutableRawData(),
state_.params_->extra_inputs[i].tensor->ort_tensor_->GetTensorMutableRawData(),
copy_size_in_bytes,
cudaMemcpyHostToDevice,
model_.cuda_stream_);
} break;
#endif

default:
throw std::runtime_error("Unsupported device for graph capture");
}
// Add extra user inputs to the state
for (size_t i = 0, lim = extra_input_names_.size(); i < lim; ++i) {
state_.input_names_.push_back(extra_input_names_[i].c_str());
state_.inputs_.push_back(extra_inputs_[i].get());
}
}

#pragma warning(pop)

} // namespace Generators
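
For reference, the rewritten constructor keeps two index-aligned vectors, extra_input_names_ and extra_inputs_, and appends the LoRA parameters to both through std::back_inserter, so the later Add() loop can push names and values onto the state in lockstep. The sketch below illustrates only that pairing pattern; Value, AdapterManager, EmitParameters, and the parameter names are hypothetical stand-ins for OrtValue and LoraAdapterManagement::OutputAdaptersParameters, not the real API.

#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an owned tensor value (the real code holds OrtValue).
struct Value {
  std::string data;
};

// Hypothetical adapter manager: emits each active parameter as a (name, value) pair
// through two output iterators, mirroring the role of OutputAdaptersParameters above.
struct AdapterManager {
  template <typename NameOut, typename ValueOut>
  void EmitParameters(NameOut names_out, ValueOut values_out) const {
    for (const auto& [name, value] : params_) {
      *names_out++ = name;
      *values_out++ = std::make_shared<Value>(value);
    }
  }
  std::vector<std::pair<std::string, Value>> params_{{"model.layers.0.lora_A", Value{"A"}},
                                                     {"model.layers.0.lora_B", Value{"B"}}};
};

int main() {
  std::vector<std::string> input_names;            // parallel vectors, as in ExtraInputs
  std::vector<std::shared_ptr<Value>> inputs;

  // User-supplied extra inputs would be appended here first...

  // ...then the adapter parameters are appended through back_inserters,
  // keeping both vectors index-aligned for the later Add() loop.
  AdapterManager manager;
  manager.EmitParameters(std::back_inserter(input_names), std::back_inserter(inputs));

  return input_names.size() == inputs.size() ? 0 : 1;
}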
20 changes: 16 additions & 4 deletions src/models/extra_inputs.h
@@ -1,19 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "static_buffer.h"
#include "onnxruntime_api.h"

#include <memory>
#include <string>
#include <vector>

namespace Generators {

struct Model;
struct State;

struct ExtraInputs {
ExtraInputs(const Model& model, State& state);
ExtraInputs(const ExtraInputs&) = delete;
ExtraInputs& operator=(const ExtraInputs&) = delete;

void Add();

private:
const Model& model_;
State& state_;
std::vector<OrtValue*> extra_inputs_;
std::vector<std::unique_ptr<OrtValue>> owned_extra_inputs_;
std::unordered_map<std::string, StaticBuffer*> sb_extra_inputs_;
std::vector<std::string> extra_input_names_;
std::vector<std::shared_ptr<OrtValue>> extra_inputs_;
};

} // namespace Generators