support for llguidance grammars #10224

Merged · 34 commits · Feb 2, 2025
Changes from 27 commits

Commits
76290d9  initial porting of previous LLG patch (mmoskal, Jan 25, 2025)
f19655c  update for new APIs (mmoskal, Jan 25, 2025)
f4dc4b8  build: integrate llguidance as an external project (mmoskal, Jan 25, 2025)
afb6cac  use '%llguidance' as marker to enable llg lark syntax (mmoskal, Jan 26, 2025)
b5399d4  add some docs (mmoskal, Jan 26, 2025)
adc4aed  clarify docs (mmoskal, Jan 26, 2025)
2a92bfb  code style fixes (mmoskal, Jan 26, 2025)
8cb12d4  remove llguidance.h from .gitignore (mmoskal, Jan 26, 2025)
de269a1  fix tests when llg is enabled (mmoskal, Jan 26, 2025)
a7be666  pass vocab not model to llama_sampler_init_llg() (mmoskal, Jan 26, 2025)
3675050  copy test-grammar-integration.cpp to test-llguidance.cpp (mmoskal, Jan 26, 2025)
58006dd  clang fmt (mmoskal, Jan 26, 2025)
036b91f  fix ref-count bug (mmoskal, Jan 26, 2025)
f245ca2  build and run test (mmoskal, Jan 26, 2025)
16a5484  gbnf -> lark syntax (mmoskal, Jan 26, 2025)
2937537  conditionally include llguidance test based on LLAMA_LLGUIDANCE flag (mmoskal, Jan 26, 2025)
c7ebf57  rename llguidance test file to test-grammar-llguidance.cpp (mmoskal, Jan 26, 2025)
0a211fc  add gh action for llg test (mmoskal, Jan 26, 2025)
8e027f8  align tests with LLG grammar syntax and JSON Schema spec (mmoskal, Jan 26, 2025)
ca88ce7  llama_tokenizer() in fact requires valid utf8 (mmoskal, Jan 26, 2025)
44e1973  update llg (mmoskal, Jan 26, 2025)
c9e9853  format file (mmoskal, Jan 26, 2025)
efc36c9  add $LLGUIDANCE_LOG_LEVEL support (mmoskal, Jan 26, 2025)
08fefd1  fix whitespace (mmoskal, Jan 26, 2025)
1afc53a  fix warning (mmoskal, Jan 26, 2025)
00fcd98  include <cmath> for INFINITY (mmoskal, Jan 26, 2025)
437ff31  add final newline (mmoskal, Jan 26, 2025)
5475357  fail llama_sampler_init_llg() at runtime (mmoskal, Jan 29, 2025)
d06448a  Link gbnf_to_lark.py script; fix links; refer to llg docs for lexemes (mmoskal, Jan 29, 2025)
59da969  simplify #includes (mmoskal, Jan 30, 2025)
d59d939  improve doc string for LLAMA_LLGUIDANCE (mmoskal, Jan 30, 2025)
6b2de55  Merge branch 'master' into llg (mmoskal, Jan 31, 2025)
a049afb  typo in merge (mmoskal, Jan 31, 2025)
7057589  bump llguidance to 0.6.12 (mmoskal, Jan 31, 2025)
30 changes: 30 additions & 0 deletions .github/workflows/build.yml
@@ -286,6 +286,36 @@ jobs:
cd build
ctest -L main --verbose --timeout 900

ubuntu-latest-llguidance:
runs-on: ubuntu-latest

steps:
- name: Clone
id: checkout
uses: actions/checkout@v4

- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get install build-essential

- name: Build
id: cmake_build
run: |
mkdir build
cd build
cmake .. \
-DLLAMA_FATAL_WARNINGS=ON \
-DLLAMA_LLGUIDANCE=ON
cmake --build . --config Release -j $(nproc)

- name: Test
id: cmake_test
run: |
cd build
ctest -L main --verbose --timeout 900

ubuntu-latest-cmake-rpc:
runs-on: ubuntu-latest

1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -79,6 +79,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})

# 3rd party libs
option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF)
option(LLAMA_LLGUIDANCE "llama: build LLGuidance library for structured output" OFF)
Member:

It should somehow be indicated that this option is used by the common library. I am not sure what the best way would be; maybe rename it to LLAMA_COMMON_LLGUIDANCE? But even if we leave it like this, it's OK, just making a note.

Collaborator (Author):

I re-worded the help string; maybe that is enough for now?


# Required for relocatable CMake package
include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake)
27 changes: 27 additions & 0 deletions common/CMakeLists.txt
@@ -63,6 +63,7 @@ add_library(${TARGET} STATIC
console.h
json-schema-to-grammar.cpp
json.hpp
llguidance.cpp
log.cpp
log.h
minja.hpp
@@ -89,6 +90,32 @@ if (LLAMA_CURL)
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
endif ()

if (LLAMA_LLGUIDANCE)
include(ExternalProject)
set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source)
set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release)
ExternalProject_Add(llguidance_ext
GIT_REPOSITORY https://github.com/guidance-ai/llguidance
GIT_TAG 7c96a46ba52a3929fa1e57355e6fb12c5abb78db
PREFIX ${CMAKE_BINARY_DIR}/llguidance
SOURCE_DIR ${LLGUIDANCE_SRC}
BUILD_IN_SOURCE TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND cargo build --release
INSTALL_COMMAND ""
BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/libllguidance.a ${LLGUIDANCE_PATH}/llguidance.h
UPDATE_COMMAND ""
)
target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_LLGUIDANCE)

add_library(llguidance STATIC IMPORTED)
set_target_properties(llguidance PROPERTIES IMPORTED_LOCATION ${LLGUIDANCE_PATH}/libllguidance.a)
add_dependencies(llguidance llguidance_ext)

target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH})
set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance)
endif ()

target_include_directories(${TARGET} PUBLIC .)
target_compile_features (${TARGET} PUBLIC cxx_std_17)
target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads)
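
Note: the BUILD_COMMAND above invokes cargo, so configuring with LLAMA_LLGUIDANCE=ON requires a Rust toolchain on the build machine. A minimal local build, mirroring the CI job added earlier in this PR:

    mkdir build && cd build
    cmake .. -DLLAMA_LLGUIDANCE=ON
    cmake --build . --config Release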
9 changes: 8 additions & 1 deletion common/json-schema-to-grammar.cpp
@@ -990,7 +990,14 @@ class SchemaConverter {
}
};

std::string json_schema_to_grammar(const json & schema) {
std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
#ifdef LLAMA_USE_LLGUIDANCE
if (!force_gbnf) {
return "%llguidance {}\nstart: %json " + schema.dump();
}
#else
(void)force_gbnf;
#endif // LLAMA_USE_LLGUIDANCE
return build_grammar([&](const llama_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
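For illustration, when LLAMA_USE_LLGUIDANCE is defined the function now skips GBNF generation entirely and emits an llguidance grammar that delegates to its built-in %json handler. A sketch of the string produced for a small schema (the schema text is simply the output of schema.dump()):

    %llguidance {}
    start: %json {"type": "object", "properties": {"age": {"type": "integer"}}}
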
3 changes: 2 additions & 1 deletion common/json-schema-to-grammar.h
@@ -5,7 +5,8 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"

std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema,
bool force_gbnf = false);

struct llama_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
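Callers that need the classic GBNF output even in llguidance-enabled builds can opt out via the new parameter; a hypothetical call site:

    std::string gbnf = json_schema_to_grammar(schema, /* force_gbnf = */ true);
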
266 changes: 266 additions & 0 deletions common/llguidance.cpp
@@ -0,0 +1,266 @@
#ifdef LLAMA_USE_LLGUIDANCE

# include "common.h"
# include "sampling.h"
# include "log.h"
# include "llama.h"

# include "llguidance.h"

# include <cmath>

struct llama_sampler_llg {
const llama_vocab * vocab;
std::string grammar_kind;
std::string grammar_data;
LlgTokenizer * tokenizer;
LlgConstraint * grammar;
LlgMaskResult llg_res;
bool has_llg_res;
};

static LlgConstraint * llama_sampler_llg_new(LlgTokenizer * tokenizer, const char * grammar_kind,
const char * grammar_data) {
LlgConstraintInit cinit;
llg_constraint_init_set_defaults(&cinit, tokenizer);
const char * log_level = getenv("LLGUIDANCE_LOG_LEVEL");
if (log_level && *log_level) {
cinit.log_stderr_level = atoi(log_level);
}
auto c = llg_new_constraint_any(&cinit, grammar_kind, grammar_data);
if (llg_get_error(c)) {
LOG_ERR("llg error: %s\n", llg_get_error(c));
llg_free_constraint(c);
return nullptr;
}
return c;
}

static const char * llama_sampler_llg_name(const llama_sampler * /*smpl*/) {
return "llguidance";
}

static void llama_sampler_llg_accept_impl(llama_sampler * smpl, llama_token token) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
LlgCommitResult res;
llg_commit_token(ctx->grammar, token, &res);
ctx->has_llg_res = false;
}
}

static void llama_sampler_llg_apply(llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (ctx->grammar) {
if (!ctx->has_llg_res) {
if (llg_compute_mask(ctx->grammar, &ctx->llg_res) == 0) {
ctx->has_llg_res = true;
} else {
LOG_ERR("llg error: %s\n", llg_get_error(ctx->grammar));
llg_free_constraint(ctx->grammar);
ctx->grammar = nullptr;
}
}
if (ctx->has_llg_res) {
if (ctx->llg_res.is_stop) {
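// the grammar has reached an accepting state: allow only end-of-generation tokens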
for (size_t i = 0; i < cur_p->size; ++i) {
if (!llama_vocab_is_eog(ctx->vocab, cur_p->data[i].id)) {
cur_p->data[i].logit = -INFINITY;
}
}
} else {
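// sample_mask is a bitmask over token ids: token t is allowed iff bit (t % 32) of word (t / 32) is set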
const uint32_t * mask = ctx->llg_res.sample_mask;
for (size_t i = 0; i < cur_p->size; ++i) {
auto token = cur_p->data[i].id;
if ((mask[token / 32] & (1 << (token % 32))) == 0) {
cur_p->data[i].logit = -INFINITY;
}
}
}
}
}
}

static void llama_sampler_llg_reset(llama_sampler * smpl) {
auto * ctx = (llama_sampler_llg *) smpl->ctx;
if (!ctx->grammar) {
return;
}

auto * grammar_new = llama_sampler_llg_new(ctx->tokenizer, ctx->grammar_kind.c_str(), ctx->grammar_data.c_str());
llg_free_constraint(ctx->grammar);
ctx->grammar = grammar_new;
ctx->has_llg_res = false;
}

static llama_sampler * llama_sampler_llg_clone(const llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_llg *) smpl->ctx;

auto * result = llama_sampler_init_llg(ctx->vocab, nullptr, nullptr);

// copy the state
{
auto * result_ctx = (llama_sampler_llg *) result->ctx;

if (ctx->grammar) {
result_ctx->grammar_kind = ctx->grammar_kind;
result_ctx->grammar_data = ctx->grammar_data;
result_ctx->grammar = llg_clone_constraint(ctx->grammar);
result_ctx->tokenizer = llg_clone_tokenizer(ctx->tokenizer);
}
}

return result;
}

static void llama_sampler_llg_free(llama_sampler * smpl) {
const auto * ctx = (llama_sampler_llg *) smpl->ctx;

if (ctx->grammar) {
llg_free_constraint(ctx->grammar);
llg_free_tokenizer(ctx->tokenizer);
}

delete ctx;
}

static llama_sampler_i llama_sampler_llg_i = {
/* .name = */ llama_sampler_llg_name,
/* .accept = */ llama_sampler_llg_accept_impl,
/* .apply = */ llama_sampler_llg_apply,
/* .reset = */ llama_sampler_llg_reset,
/* .clone = */ llama_sampler_llg_clone,
/* .free = */ llama_sampler_llg_free,
};

static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
uint32_t * output_tokens, size_t output_tokens_len) {
const llama_vocab * vocab = (const llama_vocab *) user_data;
int r = 0;
try {
r = llama_tokenize(vocab, (const char *) bytes, bytes_len, (int32_t *) output_tokens, output_tokens_len, false,
true);
} catch (const std::exception & e) {
GGML_ABORT("llama_tokenize failed: %s\n", e.what());
}
if (r < 0) {
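// a negative result means the output buffer was too small; -r is the required token count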
return -r;
}
return r;
}

static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) {
// TODO store the tokenizer in the vocab somehow
static const llama_vocab * vocab_cache;
static LlgTokenizer * tokenizer_cache;

if (vocab_cache == vocab) {
return llg_clone_tokenizer(tokenizer_cache);
}

auto tok_eos = llama_vocab_eot(vocab);
if (tok_eos == LLAMA_TOKEN_NULL) {
tok_eos = llama_vocab_eos(vocab);
}

size_t vocab_size = llama_vocab_n_tokens(vocab);

auto token_lens = new uint32_t[vocab_size];
// we typically have ~7 bytes per token; let's go on the safe side here
auto token_bytes_size = vocab_size * 16 + 1024 * 1024;
auto token_bytes = new uint8_t[token_bytes_size];

size_t offset = 0;
for (size_t i = 0; i < vocab_size; i++) {
size_t max_token = 1024;
if (token_bytes_size - offset < max_token) {
GGML_ABORT("token_bytes buffer too small\n");
}

llama_token token = i;
auto dp = (char *) token_bytes + offset;
auto size = llama_detokenize(vocab, &token, 1, dp, max_token, false, false);
if (size < 0) {
GGML_ABORT("llama_detokenize failed\n");
}
if (size == 0) {
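// an empty piece usually means a special token; re-render it with special=true and tag it with a 0xff prefix byte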
size = llama_detokenize(vocab, &token, 1, dp + 1, max_token - 1, false, true);
if (size < 0) {
GGML_ABORT("llama_detokenize failed\n");
}
if (size != 0) {
*dp = '\xff'; // special token prefix marker
size += 1;
}
}

token_lens[i] = size;
offset += size;
}

LlgTokenizerInit tinit = {
/* .vocab_size = */ (uint32_t) vocab_size,
/* .tok_eos = */ (uint32_t) tok_eos,
/* .token_lens = */ token_lens,
/* .token_bytes = */ token_bytes,
/* .tokenizer_json = */ nullptr,
/* .tokenize_assumes_string = */ true,
/* .tokenize_fn = */ llama_sampler_llg_tokenize_fn,
/* .use_approximate_greedy_tokenize_fn = */ false,
/* .tokenize_user_data = */ vocab,
};

char error_buffer[1024];
LlgTokenizer * tokenizer = llg_new_tokenizer(&tinit, error_buffer, sizeof(error_buffer));

delete[] token_bytes;
delete[] token_lens;

if (tokenizer == nullptr) {
LOG_ERR("llg tokenizer error: %s\n", error_buffer);
return tokenizer;
}

if (tokenizer_cache) {
llg_free_tokenizer(tokenizer_cache);
}
vocab_cache = vocab;
tokenizer_cache = tokenizer;

return llg_clone_tokenizer(tokenizer_cache);
}

llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind,
const char * grammar_data) {
auto * ctx = new llama_sampler_llg;

if (grammar_kind != nullptr && grammar_kind[0] != '\0') {
auto tokenizer = llama_sampler_llg_new_tokenizer(vocab);
*ctx = {
/* .vocab = */ vocab,
/* .grammar_kind = */ grammar_kind,
/* .grammar_data = */ grammar_data,
/* .tokenizer = */ tokenizer,
/* .grammar = */ llama_sampler_llg_new(tokenizer, grammar_kind, grammar_data),
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
} else {
*ctx = {
/* .vocab = */ vocab,
/* .grammar_kind = */ {},
/* .grammar_data = */ {},
/* .tokenizer = */ nullptr,
/* .grammar = */ nullptr,
/* .llg_res = */ {},
/* .has_llg_res = */ false,
};
}

return new llama_sampler{
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx,
};
}

#endif // LLAMA_USE_LLGUIDANCE
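
A minimal usage sketch of the new sampler entry point, assuming vocab comes from an already loaded model; the "lark" grammar-kind string and the toy grammar are illustrative assumptions, not taken from this diff:

    // hypothetical example: constrain generation to a run of digits
    const char * grammar = "start: /[0-9]+/";
    llama_sampler * smpl = llama_sampler_init_llg(vocab, "lark", grammar);
    // ... attach to a sampler chain; llama_sampler_apply() masks disallowed tokens each step ...
    llama_sampler_free(smpl);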