Skip to content

Commit

Permalink
[python-package][R-package] load parameters from model file (fixes #2613
Browse files Browse the repository at this point in the history
) (#5424)
  • Loading branch information
jmoralez authored Oct 11, 2022
1 parent c134d3d commit 8b72084
Show file tree
Hide file tree
Showing 13 changed files with 368 additions and 6 deletions.
15 changes: 15 additions & 0 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ Booster <- R6::R6Class(
LGBM_BoosterCreateFromModelfile_R
, modelfile
)
params <- private$get_loaded_param(handle)

} else if (!is.null(model_str)) {

Expand Down Expand Up @@ -727,6 +728,20 @@ Booster <- R6::R6Class(

},

get_loaded_param = function(handle) {
params_str <- .Call(
LGBM_BoosterGetLoadedParam_R
, handle
)
params <- jsonlite::fromJSON(params_str)
if ("interaction_constraints" %in% names(params)) {
params[["interaction_constraints"]] <- lapply(params[["interaction_constraints"]], function(x) x + 1L)
}

return(params)

},

inner_eval = function(data_name, data_idx, feval = NULL) {

# Check for unknown dataset (over the maximum provided range)
Expand Down
22 changes: 22 additions & 0 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,27 @@ SEXP LGBM_DumpParamAliases_R() {
R_API_END();
}

SEXP LGBM_BoosterGetLoadedParam_R(SEXP handle) {
SEXP cont_token = PROTECT(R_MakeUnwindCont());
R_API_BEGIN();
_AssertBoosterHandleNotNull(handle);
SEXP params_str;
int64_t out_len = 0;
int64_t buf_len = 1024 * 1024;
std::vector<char> inner_char_buf(buf_len);
CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), buf_len, &out_len, inner_char_buf.data()));
// if aliases string was larger than the initial buffer, allocate a bigger buffer and try again
if (out_len > buf_len) {
inner_char_buf.resize(out_len);
CHECK_CALL(LGBM_BoosterGetLoadedParam(R_ExternalPtrAddr(handle), out_len, &out_len, inner_char_buf.data()));
}
params_str = PROTECT(safe_R_string(static_cast<R_xlen_t>(1), &cont_token));
SET_STRING_ELT(params_str, 0, safe_R_mkChar(inner_char_buf.data(), &cont_token));
UNPROTECT(2);
return params_str;
R_API_END();
}

// .Call() calls
static const R_CallMethodDef CallEntries[] = {
{"LGBM_HandleIsNull_R" , (DL_FUNC) &LGBM_HandleIsNull_R , 1},
Expand Down Expand Up @@ -1211,6 +1232,7 @@ static const R_CallMethodDef CallEntries[] = {
{"LGBM_BoosterResetParameter_R" , (DL_FUNC) &LGBM_BoosterResetParameter_R , 2},
{"LGBM_BoosterGetNumClasses_R" , (DL_FUNC) &LGBM_BoosterGetNumClasses_R , 2},
{"LGBM_BoosterGetNumFeature_R" , (DL_FUNC) &LGBM_BoosterGetNumFeature_R , 1},
{"LGBM_BoosterGetLoadedParam_R" , (DL_FUNC) &LGBM_BoosterGetLoadedParam_R , 1},
{"LGBM_BoosterUpdateOneIter_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIter_R , 1},
{"LGBM_BoosterUpdateOneIterCustom_R" , (DL_FUNC) &LGBM_BoosterUpdateOneIterCustom_R , 4},
{"LGBM_BoosterRollbackOneIter_R" , (DL_FUNC) &LGBM_BoosterRollbackOneIter_R , 1},
Expand Down
9 changes: 9 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ LIGHTGBM_C_EXPORT SEXP LGBM_BoosterLoadModelFromString_R(
SEXP model_str
);

/*!
* \brief Get parameters as JSON string.
* \param handle Booster handle
* \return R character vector (length=1) with parameters in JSON format
*/
LIGHTGBM_C_EXPORT SEXP LGBM_BoosterGetLoadedParam_R(
SEXP handle
);

/*!
* \brief Merge model in two Boosters to first handle
* \param handle handle primary Booster handle, will merge other handle to this
Expand Down
24 changes: 18 additions & 6 deletions R-package/tests/testthat/test_lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,24 @@ test_that("Loading a Booster from a text file works", {
data(agaricus.test, package = "lightgbm")
train <- agaricus.train
test <- agaricus.test
params <- list(
num_leaves = 4L
, boosting = "rf"
, bagging_fraction = 0.8
, bagging_freq = 1L
, boost_from_average = FALSE
, categorical_feature = c(1L, 2L)
, interaction_constraints = list(c(1L, 2L), 1L)
, feature_contri = rep(0.5, ncol(train$data))
, metric = c("mape", "average_precision")
, learning_rate = 1.0
, objective = "binary"
, verbosity = VERBOSITY
)
bst <- lightgbm(
data = as.matrix(train$data)
, label = train$label
, params = list(
num_leaves = 4L
, learning_rate = 1.0
, objective = "binary"
, verbose = VERBOSITY
)
, params = params
, nrounds = 2L
)
expect_true(lgb.is.Booster(bst))
Expand All @@ -199,6 +208,9 @@ test_that("Loading a Booster from a text file works", {
)
pred2 <- predict(bst2, test$data)
expect_identical(pred, pred2)

# check that the parameters are loaded correctly
expect_equal(bst2$params[names(params)], params)
})

test_that("boosters with linear models at leaves can be written to text file and re-loaded successfully", {
Expand Down
27 changes: 27 additions & 0 deletions helpers/parameter_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
along with parameters description in LightGBM/docs/Parameters.rst file
from the information in LightGBM/include/LightGBM/config.h file.
"""
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
Expand Down Expand Up @@ -373,6 +374,32 @@ def gen_parameter_code(
}
"""
str_to_write += """const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
static std::unordered_map<std::string, std::string> map({"""
int_t_pat = re.compile(r'int\d+_t')
# the following are stored as comma separated strings but are arrays in the wrappers
overrides = {
'categorical_feature': 'vector<int>',
'ignore_column': 'vector<int>',
'interaction_constraints': 'vector<vector<int>>',
}
for x in infos:
for y in x:
name = y["name"][0]
if name == 'task':
continue
if name in overrides:
param_type = overrides[name]
else:
param_type = int_t_pat.sub('int', y["inner_type"][0]).replace('std::', '')
str_to_write += '\n {"' + name + '", "' + param_type + '"},'
str_to_write += """
});
return map;
}
"""

str_to_write += "} // namespace LightGBM\n"
with open(config_out_cpp, "w") as config_out_cpp_file:
config_out_cpp_file.write(str_to_write)
Expand Down
2 changes: 2 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ class LIGHTGBM_EXPORT Boosting {
*/
static Boosting* CreateBoosting(const std::string& type, const char* filename);

virtual std::string GetLoadedParam() const = 0;

virtual bool IsLinear() const { return false; }

virtual std::string ParserConfigStr() const = 0;
Expand Down
14 changes: 14 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,20 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterLoadModelFromString(const char* model_str,
int* out_num_iterations,
BoosterHandle* out);

/*!
* \brief Get parameters as JSON string.
* \param handle Handle of booster.
* \param buffer_len Allocated space for string.
* \param[out] out_len Actual size of string.
* \param[out] out_str JSON string containing parameters.
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterGetLoadedParam(BoosterHandle handle,
int64_t buffer_len,
int64_t* out_len,
char* out_str);


/*!
* \brief Free space for booster.
* \param handle Handle of booster to be freed
Expand Down
1 change: 1 addition & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,7 @@ struct Config {
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector;
static const std::unordered_map<std::string, std::string>& ParameterTypes();
static const std::string DumpAliases();

private:
Expand Down
25 changes: 25 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2816,6 +2816,9 @@ def __init__(
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
if params:
_log_warning('Ignoring params argument, using parameters from model file.')
params = self._get_loaded_param()
elif model_str is not None:
self.model_from_string(model_str)
else:
Expand Down Expand Up @@ -2864,6 +2867,28 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
state['handle'] = handle
self.__dict__.update(state)

def _get_loaded_param(self) -> Dict[str, Any]:
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterGetLoadedParam(
self.handle,
ctypes.c_int64(buffer_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
actual_len = tmp_out_len.value
# if buffer length is not long enough, re-allocate a buffer
if actual_len > buffer_len:
string_buffer = ctypes.create_string_buffer(actual_len)
ptr_string_buffer = ctypes.c_char_p(*[ctypes.addressof(string_buffer)])
_safe_call(_LIB.LGBM_BoosterGetLoadedParam(
self.handle,
ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len),
ptr_string_buffer))
return json.loads(string_buffer.value.decode('utf-8'))

def free_dataset(self) -> "Booster":
"""Free Booster's Datasets.
Expand Down
54 changes: 54 additions & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,60 @@ class GBDT : public GBDTBase {
*/
int GetCurrentIteration() const override { return static_cast<int>(models_.size()) / num_tree_per_iteration_; }

/*!
* \brief Get parameters as a JSON string
*/
std::string GetLoadedParam() const override {
if (loaded_parameter_.empty()) {
return std::string("{}");
}
const auto param_types = Config::ParameterTypes();
const auto lines = Common::Split(loaded_parameter_.c_str(), "\n");
bool first = true;
std::stringstream str_buf;
str_buf << "{";
for (const auto& line : lines) {
const auto pair = Common::Split(line.c_str(), ":");
if (pair[1] == " ]")
continue;
if (first) {
first = false;
str_buf << "\"";
} else {
str_buf << ",\"";
}
const auto param = pair[0].substr(1);
const auto value_str = pair[1].substr(1, pair[1].size() - 2);
const auto param_type = param_types.at(param);
str_buf << param << "\": ";
if (param_type == "string") {
str_buf << "\"" << value_str << "\"";
} else if (param_type == "int") {
int value;
Common::Atoi(value_str.c_str(), &value);
str_buf << value;
} else if (param_type == "double") {
double value;
Common::Atof(value_str.c_str(), &value);
str_buf << value;
} else if (param_type == "bool") {
bool value = value_str == "1";
str_buf << std::boolalpha << value;
} else if (param_type.substr(0, 6) == "vector") {
str_buf << "[";
if (param_type.substr(7, 6) == "string") {
const auto parts = Common::Split(value_str.c_str(), ",");
str_buf << "\"" << Common::Join(parts, "\",\"") << "\"";
} else {
str_buf << value_str;
}
str_buf << "]";
}
}
str_buf << "}";
return str_buf.str();
}

/*!
* \brief Can use early stopping for prediction or not
* \return True if cannot use early stopping for prediction
Expand Down
15 changes: 15 additions & 0 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1748,6 +1748,21 @@ int LGBM_BoosterLoadModelFromString(
API_END();
}

int LGBM_BoosterGetLoadedParam(
BoosterHandle handle,
int64_t buffer_len,
int64_t* out_len,
char* out_str) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
std::string params = ref_booster->GetBoosting()->GetLoadedParam();
*out_len = static_cast<int64_t>(params.size()) + 1;
if (*out_len <= buffer_len) {
std::memcpy(out_str, params.c_str(), *out_len);
}
API_END();
}

#ifdef _MSC_VER
#pragma warning(disable : 4702)
#endif
Expand Down
Loading

0 comments on commit 8b72084

Please sign in to comment.