Skip to content

Commit

Permalink
common: utils to split / join / repeat strings (from json converter) (
Browse files Browse the repository at this point in the history
ggerganov#11342)

* Factor string_join, string_split, string_repeat into common

* json: refactor to surface a versatile builder

* Update common.cpp
  • Loading branch information
ochafik authored Jan 22, 2025
1 parent 3e3357f commit a94f3b2
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 64 deletions.
42 changes: 42 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,48 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder);
}

std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) {
if (i > 0) {
result << separator;
}
result << values[i];
}
return result.str();
}

std::vector<std::string> string_split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> parts;
size_t start = 0;
size_t end = str.find(delimiter);

while (end != std::string::npos) {
parts.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}

parts.push_back(str.substr(start));

return parts;
}

std::string string_repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}

std::string result;
result.reserve(str.length() * n);

for (size_t i = 0; i < n; ++i) {
result += str;
}

return result;
}

std::string string_from(bool value) {
return value ? "true" : "false";
}
Expand Down
4 changes: 4 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,10 @@ std::string string_format(const char * fmt, ...);
std::string string_strip(const std::string & str);
std::string string_get_sortable_timestamp();

std::string string_join(const std::vector<std::string> & values, const std::string & separator);
std::vector<std::string> string_split(const std::string & str, const std::string & delimiter);
std::string string_repeat(const std::string & str, size_t n);

void string_replace_all(std::string & s, const std::string & search, const std::string & replace);

template<class T>
Expand Down
98 changes: 35 additions & 63 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "json-schema-to-grammar.h"
#include "common.h"

#include <algorithm>
#include <fstream>
#include <map>
Expand All @@ -11,11 +13,6 @@

using json = nlohmann::ordered_json;

template <typename Iterator>
static std::string join(Iterator begin, Iterator end, const std::string & separator);

static std::string repeat(const std::string & str, size_t n);

static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max();

Expand Down Expand Up @@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
if (sub_len > 0) {
auto from_sub = from.substr(i + 1);
auto to_sub = to.substr(i + 1);
auto sub_zeros = repeat("0", sub_len);
auto sub_nines = repeat("9", sub_len);
auto sub_zeros = string_repeat("0", sub_len);
auto sub_nines = string_repeat("9", sub_len);

auto to_reached = false;
out << "(";
Expand Down Expand Up @@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
auto max_digits = max_s.length();

for (auto digits = min_digits; digits < max_digits; digits++) {
uniform_range(min_s, repeat("9", digits));
min_s = "1" + repeat("0", digits);
uniform_range(min_s, string_repeat("9", digits));
min_s = "1" + string_repeat("0", digits);
out << " | ";
}
uniform_range(min_s, max_s);
Expand Down Expand Up @@ -318,49 +315,6 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};

template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
std::ostringstream result;
if (begin != end) {
result << *begin;
for (Iterator it = begin + 1; it != end; ++it) {
result << separator << *it;
}
}
return result.str();
}

static std::vector<std::string> split(const std::string & str, const std::string & delimiter) {
std::vector<std::string> tokens;
size_t start = 0;
size_t end = str.find(delimiter);

while (end != std::string::npos) {
tokens.push_back(str.substr(start, end - start));
start = end + delimiter.length();
end = str.find(delimiter, start);
}

tokens.push_back(str.substr(start));

return tokens;
}

static std::string repeat(const std::string & str, size_t n) {
if (n == 0) {
return "";
}

std::string result;
result.reserve(str.length() * n);

for (size_t i = 0; i < n; ++i) {
result += str;
}

return result;
}

static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function<std::string(const std::smatch &)> & replacement) {
std::smatch match;
std::string result;
Expand Down Expand Up @@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) {

class SchemaConverter {
private:
friend std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);
std::function<json(const std::string &)> _fetch_json;
bool _dotall;
std::map<std::string, std::string> _rules;
Expand Down Expand Up @@ -418,7 +373,7 @@ class SchemaConverter {
for (size_t i = 0; i < alt_schemas.size(); i++) {
rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i)));
}
return join(rules.begin(), rules.end(), " | ");
return string_join(rules, " | ");
}

std::string _visit_pattern(const std::string & pattern, const std::string & name) {
Expand Down Expand Up @@ -481,7 +436,7 @@ class SchemaConverter {
for (const auto & item : ret) {
results.push_back(to_rule(item));
}
return std::make_pair(join(results.begin(), results.end(), " "), false);
return std::make_pair(string_join(results, " "), false);
};

while (i < length) {
Expand Down Expand Up @@ -539,7 +494,7 @@ class SchemaConverter {
}
curly_brackets += '}';
i++;
auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ",");
int min_times = 0;
int max_times = std::numeric_limits<int>::max();
try {
Expand Down Expand Up @@ -854,7 +809,7 @@ class SchemaConverter {
return;
}
std::string pointer = ref.substr(ref.find('#') + 1);
std::vector<std::string> tokens = split(pointer, "/");
std::vector<std::string> tokens = string_split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
Expand Down Expand Up @@ -905,7 +860,7 @@ class SchemaConverter {
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
Expand Down Expand Up @@ -1019,10 +974,10 @@ class SchemaConverter {

void check_errors() {
if (!_errors.empty()) {
throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n"));
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
}
if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str());
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());
}
}

Expand All @@ -1036,10 +991,27 @@ class SchemaConverter {
};

std::string json_schema_to_grammar(const json & schema) {
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
auto copy = schema;
converter.resolve_refs(copy, "input");
converter.visit(copy, "");
return build_grammar([&](const llama_grammar_builder & callbacks) {
auto copy = schema;
callbacks.resolve_refs(copy);
callbacks.add_schema("", copy);
});
}

std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb) {
SchemaConverter converter([&](const std::string &) { return json(); }, /* dotall= */ false);
llama_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule);
},
/* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) {
return converter.visit(schema, name == "root" ? "" : name);
},
/* .resolve_refs = */ [&](nlohmann::ordered_json & schema) {
converter.resolve_refs(schema, "");
}
};
cb(builder);
converter.check_errors();
return converter.format_grammar();
}
10 changes: 9 additions & 1 deletion common/json-schema-to-grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,12 @@
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"

std::string json_schema_to_grammar(const nlohmann::ordered_json& schema);
std::string json_schema_to_grammar(const nlohmann::ordered_json & schema);

struct llama_grammar_builder {
std::function<std::string(const std::string &, const std::string &)> add_rule;
std::function<std::string(const std::string &, const nlohmann::ordered_json &)> add_schema;
std::function<void(nlohmann::ordered_json &)> resolve_refs;
};

std::string build_grammar(const std::function<void(const llama_grammar_builder &)> & cb);

0 comments on commit a94f3b2

Please sign in to comment.