
Implement Core functionality for Lora Adapters #679

Closed · wants to merge 72 commits

Commits (72)
28e86c6
Lora begins
yuslepukhin Jun 26, 2024
52fa32e
Rework span
yuslepukhin Jun 27, 2024
fb0670e
Lora begins
yuslepukhin Jun 26, 2024
beb630f
Start public API
yuslepukhin Jun 27, 2024
d006037
Merge branch 'yuslepukhin/implement_lora_adapters' of https://github.…
yuslepukhin Jun 27, 2024
5ad527b
Add param
yuslepukhin Jun 28, 2024
22d65cb
Add C++ API
yuslepukhin Jun 28, 2024
94e2c83
Add LoraManagement unit tests
yuslepukhin Jul 1, 2024
ad00a14
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 1, 2024
0a712a9
More tests
yuslepukhin Jul 1, 2024
11dc3f2
Add DeactiveAdapters
yuslepukhin Jul 2, 2024
7a98784
Add C and c++ API test
yuslepukhin Jul 2, 2024
db9b5fa
Adjust ExtraInputs.
yuslepukhin Jul 2, 2024
ed29925
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 3, 2024
dbdc369
Add thread-safety, setup for potential caching
yuslepukhin Jul 3, 2024
fa42f95
Added device copy to LoraAdapter
yuslepukhin Jul 3, 2024
a104b20
Add input addition verification
yuslepukhin Jul 3, 2024
7348901
Introduce utilities
yuslepukhin Jul 4, 2024
9164a96
Merge branch 'yuslepukhin/implement_lora_adapters' of https://github.…
yuslepukhin Jul 5, 2024
be1e6b0
Add automatic span construction from an array
yuslepukhin Jul 5, 2024
754d3e7
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 5, 2024
52b8aa5
Fix warnings
yuslepukhin Jul 5, 2024
1247461
Code issues
yuslepukhin Jul 5, 2024
40d939a
Flatbuffers begin
yuslepukhin Jul 10, 2024
89c00b3
Make flatbuffers tests run
yuslepukhin Jul 11, 2024
92bc527
Add python save_array_as_lora_parameter
yuslepukhin Jul 11, 2024
4e43b51
Add python test and readback in CXX
yuslepukhin Jul 11, 2024
24dc3c6
Add assert and cleanup
yuslepukhin Jul 12, 2024
297f894
Fix test issues
yuslepukhin Jul 12, 2024
f2b42bc
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 12, 2024
4513168
Address review comments
yuslepukhin Jul 12, 2024
a100cde
include mutex
yuslepukhin Jul 12, 2024
de5fab0
Build errors
yuslepukhin Jul 15, 2024
ea905fc
Address review comments
yuslepukhin Jul 15, 2024
c21e540
Introduce saving multiple params into the same fb file
yuslepukhin Jul 16, 2024
82a99de
Add convertion utility
yuslepukhin Jul 16, 2024
bc84e1d
Add import for lora helpers
yuslepukhin Jul 16, 2024
2fe6e6c
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 16, 2024
455ffdb
Merge branch 'yuslepukhin/implement_lora_adapters' into yuslepukhin/l…
yuslepukhin Jul 16, 2024
abf61b4
Add a tool to modify genai config and add adapters section
yuslepukhin Jul 16, 2024
5e58286
Work on config driven load
yuslepukhin Jul 17, 2024
41e3136
Fix up CUDA build
yuslepukhin Jul 18, 2024
eeba0c4
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 18, 2024
9cdd53f
Merge branch 'yuslepukhin/lora_params_ondisk' into yuslepukhin/implem…
yuslepukhin Jul 18, 2024
7905942
Fix merge
yuslepukhin Jul 18, 2024
d0f1346
Address security warnings
yuslepukhin Jul 18, 2024
07617c0
Fix stray include
yuslepukhin Jul 18, 2024
7dd7274
Address build shortcomings
yuslepukhin Jul 18, 2024
2992140
Clang format
yuslepukhin Jul 18, 2024
c77fb44
Clag format
yuslepukhin Jul 18, 2024
287693f
Remove redundant methods
yuslepukhin Jul 19, 2024
f8de77e
Run test coverage, remove some dead code. Cover base case.
yuslepukhin Jul 19, 2024
79bad70
Add missing checks
yuslepukhin Jul 19, 2024
08059ff
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 19, 2024
6a6cc3b
Address review comments
yuslepukhin Jul 22, 2024
e0088e3
Address build issues, refresh the test model
yuslepukhin Jul 22, 2024
4252975
Adjust file paths
yuslepukhin Jul 23, 2024
8814b47
Make it work end to end
yuslepukhin Jul 23, 2024
175ea54
Make FlatBuffers linkage public
yuslepukhin Jul 23, 2024
0bc4e3f
Move new test subfolder
yuslepukhin Jul 23, 2024
a77ab66
Add model
yuslepukhin Jul 23, 2024
ee6dcb9
Add fp16 model
yuslepukhin Jul 23, 2024
67bebd2
Adjust src
yuslepukhin Jul 23, 2024
57d4ef1
Adjust for ARM
yuslepukhin Jul 23, 2024
2318427
Create separate model copy and config to run on DML
yuslepukhin Jul 23, 2024
b902b70
Disable DML
yuslepukhin Jul 24, 2024
33a3ec2
Merge branch 'main' into yuslepukhin/implement_lora_adapters
yuslepukhin Jul 24, 2024
5dc65d7
Clang format
yuslepukhin Jul 24, 2024
c765a16
Address python related comments, correct faulty formatting for public…
yuslepukhin Jul 25, 2024
36f569c
Remove redundant linkage and includes
yuslepukhin Jul 25, 2024
3526f91
Rename python interface
yuslepukhin Jul 25, 2024
b936657
Correct function name
yuslepukhin Jul 25, 2024
25 changes: 25 additions & 0 deletions CMakeLists.txt
@@ -20,7 +20,26 @@ if(MSVC)
add_compile_definitions(_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR)
endif()

find_package(Patch)
if (WIN32 AND NOT Patch_FOUND)
# work around CI machines missing patch from the git install by falling back to the binary in this repo.
# replicate what happens in https://github.com/Kitware/CMake/blob/master/Modules/FindPatch.cmake but without
# the hardcoded suffixes in the path to the patch binary.
find_program(Patch_EXECUTABLE NAMES patch PATHS ${PROJECT_SOURCE_DIR}/external/git.Win32.2.41.03.patch)
if(Patch_EXECUTABLE)
set(Patch_FOUND 1)
if (NOT TARGET Patch::patch)
add_executable(Patch::patch IMPORTED)
set_property(TARGET Patch::patch PROPERTY IMPORTED_LOCATION ${Patch_EXECUTABLE})
endif()
endif()
endif()
if(Patch_FOUND)
message("Patch found: ${Patch_EXECUTABLE}")
endif()

include(cmake/external/onnxruntime_external_deps.cmake)

# All Global variables, including GLOB, for the top level CMakeLists.txt should be defined here
include(cmake/global_variables.cmake)
# Checking if CUDA is supported
@@ -32,6 +51,8 @@ include(cmake/check_dml.cmake)

include(cmake/cxx_standard.cmake)

include(cmake/genai_flatbuffers.cmake)

add_compile_definitions(BUILDING_ORT_GENAI_C)
if(MSVC)
# set updated value for __cplusplus macro instead of 199711L
@@ -72,6 +93,10 @@ else()
add_library(onnxruntime-genai-static STATIC ${generator_srcs})
endif()

target_link_libraries(onnxruntime-genai PRIVATE genai_flatbuffers)
target_link_libraries(onnxruntime-genai-static PRIVATE genai_flatbuffers)


target_include_directories(onnxruntime-genai PRIVATE ${ORT_HEADER_DIR})
target_include_directories(onnxruntime-genai-static PRIVATE ${ORT_HEADER_DIR})
target_include_directories(onnxruntime-genai PRIVATE ${onnxruntime_extensions_SOURCE_DIR}/include)
1 change: 1 addition & 0 deletions cmake/deps.txt
@@ -10,6 +10,7 @@
#not affect built binaries.
#
# NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI.
flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip;59422c3b5e573dd192fead2834d25951f1c1670c
pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.10.1.zip;769b6aa67a77f17a770960f604b727645b6f6a13
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
Binary file added cmake/external/git.Win32.2.41.03.patch/patch.exe (binary contents not shown)
32 changes: 32 additions & 0 deletions cmake/external/onnxruntime_external_deps.cmake
@@ -19,6 +19,38 @@ endforeach()

message("Loading Dependencies ...")


# Flatbuffers
# We do not need to build flatc for iOS or Android Cross Compile
if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE)
endif()
set(FLATBUFFERS_BUILD_TESTS OFF CACHE BOOL "FLATBUFFERS_BUILD_TESTS" FORCE)
set(FLATBUFFERS_INSTALL OFF CACHE BOOL "FLATBUFFERS_INSTALL" FORCE)
set(FLATBUFFERS_BUILD_FLATHASH OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATHASH" FORCE)
set(FLATBUFFERS_BUILD_FLATLIB ON CACHE BOOL "FLATBUFFERS_BUILD_FLATLIB" FORCE)

if(NOT WIN32)
if(Patch_FOUND)
set(GENAI_FLATBUFFERS_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 <
${CMAKE_SOURCE_DIR}/cmake/patches/flatbuffers/flatbuffers.patch)
else()
set(GENAI_FLATBUFFERS_PATCH_COMMAND "")
endif()
else()
set(GENAI_FLATBUFFERS_PATCH_COMMAND "")
endif()

FetchContent_Declare(
flatbuffers
URL ${DEP_URL_flatbuffers}
URL_HASH SHA1=${DEP_SHA1_flatbuffers}
PATCH_COMMAND ${GENAI_FLATBUFFERS_PATCH_COMMAND}
FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers
)

onnxruntime_fetchcontent_makeavailable(flatbuffers)

if(ENABLE_PYTHON)
FetchContent_Declare(
pybind11_project
19 changes: 19 additions & 0 deletions cmake/genai_flatbuffers.cmake
@@ -0,0 +1,19 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

file(GLOB genai_flatbuffers_srcs CONFIGURE_DEPENDS
"${CMAKE_SOURCE_DIR}/src/flatbuffers/*.h"
"${CMAKE_SOURCE_DIR}/src/flatbuffers/*.cc"
)

add_library(genai_flatbuffers STATIC ${genai_flatbuffers_srcs})
target_link_libraries(genai_flatbuffers PUBLIC FlatBuffers::FlatBuffers)

target_include_directories(genai_flatbuffers PRIVATE ${ORT_HEADER_DIR})
target_link_directories(genai_flatbuffers PRIVATE ${ORT_LIB_DIR})

# Add dependency so the flatbuffers compiler is built if enabled
if (FLATBUFFERS_BUILD_FLATC)
add_dependencies(genai_flatbuffers flatc)
endif()

12 changes: 12 additions & 0 deletions cmake/patches/flatbuffers/flatbuffers.patch
@@ -0,0 +1,12 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3987eac9..5e5462f1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -279,5 +279,5 @@
# Append FLATBUFFERS_CXX_FLAGS to CMAKE_CXX_FLAGS.
if(DEFINED FLATBUFFERS_CXX_FLAGS)
message(STATUS "extend CXX_FLAGS with ${FLATBUFFERS_CXX_FLAGS}")
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS}")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLATBUFFERS_CXX_FLAGS} -Wno-error=stringop-overflow")
endif()
message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
30 changes: 30 additions & 0 deletions src/config.cpp
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "generators.h"
#include "json.h"

#include <fstream>
#include <sstream>

@@ -460,6 +461,31 @@ struct Search_Element : JSON::Element {
Config::Search& v_;
};

class LoraAdapters_Element : public JSON::Element {
public:
explicit LoraAdapters_Element(Config::LoraAdapters& v) noexcept : v_{v} {}

private:
JSON::Element& OnObject(std::string_view name) override {
if (current_adapter_ != name) {
current_adapter_ = name;
return *this;
}
throw JSON::unknown_value_error{};
}

void OnString(std::string_view name, std::string_view path) override {
if (name == "weights") {
v_.adapters.emplace(current_adapter_, path);
} else {
throw JSON::unknown_value_error{};
}
}

Config::LoraAdapters& v_;
std::string current_adapter_;
};

void SetSearchNumber(Config::Search& search, std::string_view name, double value) {
Search_Element(search).OnNumber(name, value);
}
@@ -499,12 +525,16 @@ struct Root_Element : JSON::Element {
if (name == "search") {
return search_element_;
}
if (name == "adapters") {
return lora_adapters_element_;
}
throw JSON::unknown_value_error{};
}

Config& config_;
Model_Element model_element_{config_.model};
Search_Element search_element_{config_.search};
LoraAdapters_Element lora_adapters_element_{config_.lora_adapters};
};

struct RootObject_Element : JSON::Element {
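For reference, here is a minimal genai_config.json fragment in the shape this parser accepts, inferred from LoraAdapters_Element above (one nested object per adapter, each carrying a "weights" path). The adapter names and file paths below are hypothetical:

```json
{
  "adapters": {
    "guanaco": { "weights": "adapters/guanaco.lora.fb" },
    "summarizer": { "weights": "adapters/summarizer.lora.fb" }
  }
}
```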
5 changes: 5 additions & 0 deletions src/config.h
@@ -126,6 +126,11 @@ struct Config {
int random_seed{-1}; // -1 = Seed with random device, otherwise use value to seed RNG
} search;

struct LoraAdapters {
// Stores adapter name to file name mapping
std::unordered_map<std::string, std::string> adapters;
} lora_adapters;

void AddMapping(const std::string& nominal_name, const std::string& graph_name);
// Returns graph name and true if the nominal name is found in the mapping
// otherwise returns the nominal name and false
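A minimal consumption sketch for the new struct — assuming Config lives in the Generators namespace as elsewhere in this codebase, and that a populated instance is at hand:

```cpp
#include <cstdio>

#include "config.h"

// Hypothetical helper: list the adapter-name -> weights-file mapping
// parsed from the "adapters" section of genai_config.json.
void PrintAdapters(const Generators::Config& config) {
  for (const auto& [adapter_name, weights_file] : config.lora_adapters.adapters) {
    std::printf("adapter %s -> %s\n", adapter_name.c_str(), weights_file.c_str());
  }
}
```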
17 changes: 17 additions & 0 deletions src/flatbuffers.h
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#if defined(__GNUC__)
#pragma GCC diagnostic push

#ifdef HAS_SHORTEN_64_TO_32
#pragma GCC diagnostic ignored "-Wshorten-64-to-32"
#endif
#endif

#include "flatbuffers/flatbuffers.h"

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
63 changes: 63 additions & 0 deletions src/flatbuffers/flatbuffers_utils.cc
@@ -0,0 +1,63 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "flatbuffers_utils.h"
#include "schema/genai_lora.fbs.h"
#include "../../src/models/onnxruntime_api.h"

#include "../models/onnxruntime_api.h"

namespace Generators {
namespace lora_parameters {
namespace utils {

bool IsGenAiLoraFormatModelBytes(const void* bytes, size_t num_bytes) {
return num_bytes > 8 && // check buffer is large enough to contain identifier so we don't read random memory
ParametersBufferHasIdentifier(bytes);
}

flatbuffers::Offset<flatbuffers::String> SaveStringToLoraFormat(flatbuffers::FlatBufferBuilder& builder,
bool has_string, const std::string& src) {
if (has_string) return builder.CreateString(src);

// If the string does not exist, return 0 (the string does not exist in flatbuffer)
return 0;
}

void LoadStringFromLoraFormat(std::string& dst, const flatbuffers::String* fbs_string) {
if (fbs_string) {
dst = fbs_string->str();
}
}

void SaveLoraParameter(flatbuffers::FlatBufferBuilder& flat_builder, std::string_view name,
Generators::lora_parameters::TensorDataType data_type, std::span<const int64_t> shape,
std::span<const uint8_t> data,
flatbuffers::Offset<Generators::lora_parameters::Param>& fbs_tensor) {
auto name_str = (name.empty()) ? 0 : flat_builder.CreateString(name.data(), name.size());
auto shape_vec = flat_builder.CreateVector(shape.data(), shape.size());
auto data_vec = flat_builder.CreateVector(data.data(), data.size());

fbs_tensor = CreateParam(flat_builder, name_str, shape_vec, data_type, data_vec);
}

std::pair<std::string, std::unique_ptr<OrtValue>> CreateOrtValueOverFlatBufferLoraParameter(
const Generators::lora_parameters::Param& tensor) {
std::string name;
LoadStringFromLoraFormat(name, tensor.name());

const auto data_type = tensor.data_type();

std::span<const int64_t> shape_span(tensor.dims()->data(), tensor.dims()->size());

auto mem_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
auto ort_value =
OrtValue::CreateTensor(*mem_info, const_cast<uint8_t*>(tensor.raw_data()->data()),
static_cast<size_t>(tensor.raw_data()->size()), shape_span,
static_cast<ONNXTensorElementDataType>(data_type));
return std::make_pair(std::move(name), std::move(ort_value));
}

} // namespace utils
} // namespace lora_parameters
} // namespace Generators
59 changes: 59 additions & 0 deletions src/flatbuffers/flatbuffers_utils.h
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "../flatbuffers.h"
#include "../span.h"

#include "schema/genai_lora.fbs.h"

#include <string>
#include <string_view>
#include <unordered_map>

struct OrtValue;

namespace Generators {
namespace lora_parameters {
namespace utils {

// Will only create string in flatbuffers when has_string is true
flatbuffers::Offset<flatbuffers::String> SaveStringToLoraFormat(flatbuffers::FlatBufferBuilder& builder,
bool has_string, const std::string& src);

void LoadStringFromLoraFormat(std::string& dst, const flatbuffers::String* fbs_string);

/// <summary>
/// Serializes tensor data into flatbuffer
/// </summary>
/// <param name="flat_builder"></param>
/// <param name="name">parameter name</param>
/// <param name="doc">doc, optional</param>
/// <param name="data_type"></param>
/// <param name="shape"></param>
/// <param name="data"></param>
/// <param name="fbs_tensor">output offset</param>
void SaveLoraParameter(flatbuffers::FlatBufferBuilder& flat_builder, std::string_view name,
Generators::lora_parameters::TensorDataType data_type,
std::span<const int64_t> shape, std::span<const uint8_t> data,
flatbuffers::Offset<Generators::lora_parameters::Param>& fbs_tensor);

/// <summary>
/// Create an OrtValue on top of the flatbuffer tensor
/// No copying of data is done here. The caller is responsible for managing the lifetime of flatbuffer
/// structures.
///
/// In this scenario, one can memory map the entire flatbuffer tensor data into OrtValue without copying.
/// </summary>
/// <param name="tensor"></param>
/// <returns></returns>
std::pair<std::string, std::unique_ptr<OrtValue>> CreateOrtValueOverFlatBufferLoraParameter(
const Generators::lora_parameters::Param& tensor);

// Check whether the given bytes carry the file identifier for lora parameters
bool IsGenAiLoraFormatModelBytes(const void* bytes, size_t num_bytes);

} // namespace utils
} // namespace lora_parameters
} // namespace Generators
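To show how these helpers compose, here is a hedged round-trip sketch. It assumes the generated genai_lora.fbs API exposes a TensorDataType::FLOAT enumerator, and it finishes a lone Param as the buffer root purely for demonstration (in the real format the root is a Parameters table that owns a vector of Param entries):

```cpp
#include <cstdint>
#include <vector>

#include "flatbuffers_utils.h"

void RoundTripSketch() {
  namespace lp = Generators::lora_parameters;
  flatbuffers::FlatBufferBuilder builder;

  const std::vector<int64_t> shape{2, 4};
  const std::vector<uint8_t> data(2 * 4 * sizeof(float), 0);  // zero-filled fp32 payload

  // Serialize one named parameter into the builder.
  flatbuffers::Offset<lp::Param> fbs_param;
  lp::utils::SaveLoraParameter(builder, "model.layers.0.attn.lora_A",
                               lp::TensorDataType::FLOAT,  // assumed enumerator name
                               shape, data, fbs_param);
  builder.Finish(fbs_param);  // demo only: a single Param as root

  // Map an OrtValue over the serialized bytes without copying.
  const auto* param = flatbuffers::GetRoot<lp::Param>(builder.GetBufferPointer());
  auto [name, ort_value] = lp::utils::CreateOrtValueOverFlatBufferLoraParameter(*param);
  // `builder` must outlive `ort_value`: the tensor aliases the flatbuffer memory.
}
```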
33 changes: 33 additions & 0 deletions src/flatbuffers/lora_format_version.h
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <algorithm>
#include <array>

namespace Generators {
namespace lora_parameters {

// The current format version for saving lora parameters in flatbuffers.
// Once this version is updated, kSupportedLoraFormatVersions in IsLoraFormatVersionSupported
// below will also need to be updated.
// See src/flatbuffers/schema/README.md for more details on versioning.
// Version 1 - history begins
constexpr const int kLoraFormatVersion = 1;

// Check if the given lora format version is supported in this build
inline bool IsLoraFormatVersionSupported(const int lora_format_version) {
// The lora format versions we will support in this build
// This may contain more versions than the kLoraFormatVersion, based on the compatibilities
constexpr std::array<int, 1U> kSupportedLoraFormatVersions{
kLoraFormatVersion,
};

const auto it =
std::find(kSupportedLoraFormatVersions.begin(), kSupportedLoraFormatVersions.end(), lora_format_version);
return it != kSupportedLoraFormatVersions.cend();
}

} // namespace lora_parameters
} // namespace Generators
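A hedged sketch of how a loader might combine this check with the identifier test from flatbuffers_utils.h — GetParameters and its version() accessor are assumed names from the generated genai_lora.fbs API and do not appear in this diff:

```cpp
#include <cstddef>
#include <stdexcept>

void ValidateLoraBuffer(const void* bytes, size_t num_bytes) {
  using namespace Generators::lora_parameters;
  if (!utils::IsGenAiLoraFormatModelBytes(bytes, num_bytes))
    throw std::runtime_error("buffer lacks the GenAI LoRA file identifier");
  const auto* params = GetParameters(bytes);  // assumed generated root accessor
  if (!IsLoraFormatVersionSupported(params->version()))  // version() assumed
    throw std::runtime_error("unsupported LoRA parameters format version");
}
```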