From 27670cdb58a877112a7e3cc7d0eb76c135a1e8b6 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Mon, 1 Mar 2021 19:18:04 +0800 Subject: [PATCH 01/70] [python-package] create Dataset from sampled data. --- include/LightGBM/c_api.h | 11 +++ include/LightGBM/dataset_loader.h | 2 +- python-package/lightgbm/basic.py | 115 ++++++++++++++++++++++++++++++ src/c_api.cpp | 42 +++++++++-- src/io/dataset_loader.cpp | 46 +++++++++++- 5 files changed, 209 insertions(+), 7 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 025633c97bde..2b0294ec37c1 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -214,6 +214,17 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, const DatasetHandle reference, DatasetHandle* out); +/*! + * \brief Create sample indices for total nrow. + * \param total_nrow Number of all data rows + * \param parameters Additional parameters, specify sample count and random seed in parameter + * \param[out] out Created indices, type is int32_t, caller should insure out contains enough space to hold indices + * \return 0 when succeed, -1 when failure happens + */ +LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t total_nrow, + const char* parameters, + void* out); + /*! * \brief Create dataset from dense matrix. * \param data Pointer to the data space diff --git a/include/LightGBM/dataset_loader.h b/include/LightGBM/dataset_loader.h index e72dd4910804..cfcb716e4eb0 100644 --- a/include/LightGBM/dataset_loader.h +++ b/include/LightGBM/dataset_loader.h @@ -29,7 +29,7 @@ class DatasetLoader { LIGHTGBM_EXPORT Dataset* ConstructFromSampleData(double** sample_values, int** sample_indices, int num_col, const int* num_per_col, - size_t total_sample_size, data_size_t num_data); + size_t total_sample_size, data_size_t num_data, const std::string& dump_filename); /*! \brief Disable copy */ DatasetLoader& operator=(const DatasetLoader&) = delete; diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 121b371459d1..60980a6e1df5 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1111,6 +1111,7 @@ def __init__(self, data, label=None, reference=None, self.feature_penalty = None self.monotone_constraints = None self.version = 0 + self._start_row = 0 # Used when pushing rows one by one. def __del__(self): try: @@ -1118,6 +1119,120 @@ def __del__(self): except AttributeError: pass + # TODO how to keep the default value the same with C++ config.h + DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000 + + def create_sample_indices(self, total_nrow): + """Create sample indices for the given parameter of the Dataset. + + Parameters + ---------- + total_nrow: int + Total number of rows to sample from. + If Dataset has multiple input data, this should be the sum of rows of every file. + + Returns + ------- + indices: numpy array + Indices for sampled data. + """ + param_str = param_dict_to_str(self.params) + sample_cnt = self.params.get("bin_construct_sample_cnt", + self.DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT) + indices = np.zeros(sample_cnt, dtype=np.int32) + ptr_data, _, _ = c_int_array(indices) + + _safe_call(_LIB.LGBM_SampleIndices( + ctypes.c_int(total_nrow), + c_str(param_str), + ptr_data, + )) + return indices + + def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): + """Get the used parameters in the Dataset. + + Parameters + ---------- + sample_data: 2d numpy array (dtype must be double, in F order) + Sample data value in row major order. 
+ Note: each column contains len(sample_indices[i]) number of values. + sample_indices: List[List[int]] + Sample data row index for each column. + sample_cnt: int + Number of samples. + total_nrow: int + Total number of rows for all input file. + + Returns + ------- + self : Dataset + Constructed Dataset object. + """ + if len(sample_data.shape) != 2: + raise ValueError('sample_data numpy.ndarray must be 2 dimensional') + assert sample_data.dtype == np.double, "sample data type {} is not double".format(sample_data.dtype) + assert sample_data.shape[1] == len(sample_indices), "#sample data column != #column indices" + + ncol = len(sample_indices) + + for i in range(ncol): + if sample_indices[i].dtype != np.int32: + raise ValueError("sample_indices[{}] type {} is not int32".format(i, sample_indices[i].dtype)) + + # c type: double** + # each double* element points to start of each column of sample data. + sample_col_ptr = (ctypes.POINTER(ctypes.c_double) * ncol)() + # c type int** + # each int* points to start of indices for each column + indices_col_ptr = (ctypes.POINTER(ctypes.c_int32) * ncol)() + for i in range(ncol): + sample_col_ptr[i] = c_float_array(sample_data[:, i])[0] + indices_col_ptr[i] = c_int_array(sample_indices[i])[0] + + num_per_col = np.array([len(d) for d in sample_indices], dtype=np.int32) + num_per_col_ptr, _, _ = c_int_array(num_per_col) + + self.handle = ctypes.c_void_p() + params_str = param_dict_to_str(self.params) + _safe_call(_LIB.LGBM_DatasetCreateFromSampledColumn( + ctypes.cast(sample_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))), + ctypes.cast(indices_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_int32))), + ctypes.c_int32(ncol), + num_per_col_ptr, + ctypes.c_int32(sample_cnt), + ctypes.c_int32(total_nrow), + c_str(params_str), + ctypes.byref(self.handle), + )) + return self + + def push_rows(self, data): + """Add rows to Dataset. + + Args: + data: numpy 1-D array + + Returns + ------- + self : Dataset + Dataset object. + """ + nrow, ncol = data.shape + data = np.array(data.reshape(data.size), dtype=data.dtype, copy=False) + data_ptr, data_type, _ = c_float_array(data) + + _safe_call(_LIB.LGBM_DatasetPushRows( + self.handle, + data_ptr, + data_type, + ctypes.c_int32(nrow), + ctypes.c_int32(ncol), + ctypes.c_int32(self._start_row), + )) + self._start_row += nrow + return self + def get_params(self): """Get the used parameters in the Dataset. 
diff --git a/src/c_api.cpp b/src/c_api.cpp index 3d20d92da70d..fc1881ebd303 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -938,7 +938,7 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data, DatasetLoader loader(config, nullptr, 1, nullptr); *out = loader.ConstructFromSampleData(sample_data, sample_indices, ncol, num_per_col, num_sample_row, - static_cast(num_total_row)); + static_cast(num_total_row), ""); API_END(); } @@ -960,6 +960,7 @@ int LGBM_DatasetPushRows(DatasetHandle dataset, int32_t nrow, int32_t ncol, int32_t start_row) { + Log::Info("start_row %d nrow: %d", start_row, nrow); API_BEGIN(); auto p_dataset = reinterpret_cast(dataset); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1); @@ -977,6 +978,7 @@ int LGBM_DatasetPushRows(DatasetHandle dataset, } OMP_THROW_EX(); if (start_row + nrow == p_dataset->num_data()) { + Log::Info("Dataset PushRows FinishLoad"); p_dataset->FinishLoad(); } API_END(); @@ -1015,6 +1017,7 @@ int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, API_END(); } + int LGBM_DatasetCreateFromMat(const void* data, int data_type, int32_t nrow, @@ -1035,6 +1038,35 @@ int LGBM_DatasetCreateFromMat(const void* data, } +static inline std::vector CreateSampleIndices(const Config& config, int32_t total_nrow) { + Random rand(config.data_random_seed); + int sample_cnt = static_cast(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt); + return rand.Sample(total_nrow, sample_cnt); +} + + +int LGBM_SampleIndices(int32_t total_nrow, + const char* parameters, + void* out) { + // This API is to keep python binding's behavior the same with C++ implementation. + // Sample count, random seed etc. should be provided in parameters. + API_BEGIN(); + if (out == nullptr) { + Log::Fatal("sample indicies output is nullptr"); + } + auto param = Config::Str2Map(parameters); + Config config; + config.Set(param); + + auto sample_indices = CreateSampleIndices(config, total_nrow); + + static_assert (sizeof(int) == 4, "int size is not 4"); + memcpy(out, sample_indices.data(), sizeof(int32_t) * sample_indices.size()); + + API_END(); +} + + int LGBM_DatasetCreateFromMats(int32_t nmat, const void** data, int data_type, @@ -1093,7 +1125,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, Vector2Ptr(&sample_idx).data(), ncol, VectorSize(sample_values).data(), - sample_cnt, total_nrow)); + sample_cnt, total_nrow, "")); } else { ret.reset(new Dataset(total_nrow)); ret->CreateValid( @@ -1172,7 +1204,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, Vector2Ptr(&sample_idx).data(), static_cast(num_col), VectorSize(sample_values).data(), - sample_cnt, nrow)); + sample_cnt, nrow, "")); } else { ret.reset(new Dataset(nrow)); ret->CreateValid( @@ -1243,7 +1275,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, Vector2Ptr(&sample_idx).data(), static_cast(num_col), VectorSize(sample_values).data(), - sample_cnt, nrow)); + sample_cnt, nrow, "")); } else { ret.reset(new Dataset(nrow)); ret->CreateValid( @@ -1319,7 +1351,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, Vector2Ptr(&sample_idx).data(), static_cast(sample_values.size()), VectorSize(sample_values).data(), - sample_cnt, nrow)); + sample_cnt, nrow, "")); } else { ret.reset(new Dataset(nrow)); ret->CreateValid( diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index 2d2c4d622b1c..87aac784533d 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -617,10 +617,38 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* 
data_filename, const char* b return dataset.release(); } +// To help verify whether sample data is changed when using different language bindings. +static void DumpSampleData(const std::string& sample_filename, double** sample_values, + int** sample_indices, int num_col, const int* num_per_col, + size_t total_sample_size, data_size_t num_data) { + Log::Info("dump sample data to %s", sample_filename.c_str()); + std::ofstream out(sample_filename); + + out << "num_col: " << num_col << "\n"; + out << "total_sample_size: " << total_sample_size << "\n"; + out << "num_data: " << num_data << "\n"; + + out << "num_per_col:\n"; + for (int i = 0; i < num_col; ++i) { + out << " c:" << i << "=" << num_per_col[i] << "\n"; + } + + out << "sample data:\n"; + for (int i = 0; i < num_col; ++i) { + out << " c:" << i << "\n"; + for (int j = 0; j < num_per_col[i]; ++j) { + out << " r:" << sample_indices[i][j] << "=" << sample_values[i][j] << "\n"; + } + } + out << "\n"; +} Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values, int** sample_indices, int num_col, const int* num_per_col, - size_t total_sample_size, data_size_t num_data) { + size_t total_sample_size, data_size_t num_data, const std::string& dump_filename) { + if (dump_filename != "") { + DumpSampleData(dump_filename, sample_values, sample_indices, num_col, num_per_col, total_sample_size, num_data); + } CheckSampleSize(total_sample_size, static_cast(num_data)); int num_total_features = num_col; if (Network::num_machines() > 1) { @@ -976,6 +1004,22 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, dataset->feature_groups_.clear(); dataset->num_total_features_ = std::max(static_cast(sample_values.size()), parser->NumFeatures()); + + /* + int num_col = static_cast(sample_values.size()); + std::vector sample_values_ptr(num_col); + std::vector sample_indices_ptr(num_col); + std::vector num_per_col(num_col); + for (size_t i = 0; i < sample_values.size(); ++i) { + sample_values_ptr[i] = sample_values[i].data(); + sample_indices_ptr[i] = sample_indices[i].data(); + num_per_col[i] = static_cast(sample_indices[i].size()); + } + DumpSampleData("lgbm_sample_from_text.txt", sample_values_ptr.data(), + sample_indices_ptr.data(), num_col, num_per_col.data(), + sample_data.size(), dataset->num_data()); + */ + if (num_machines > 1) { dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_); } From f16dd011c4d64464764807df55c79d118c420e48 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Tue, 16 Mar 2021 15:19:00 +0800 Subject: [PATCH 02/70] [python-package] create Dataset from List[Sequence]. 1. Use random access for data sampling 2. Support read data from multiple input files 3. 
Read data in batch so no need to hold all data in memory --- python-package/lightgbm/__init__.py | 4 +- python-package/lightgbm/basic.py | 165 +++++++++++++++++++++--- src/c_api.cpp | 4 +- tests/python_package_test/test_basic.py | 90 ++++++++++++- 4 files changed, 241 insertions(+), 22 deletions(-) diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index c7a1e0303dcd..e429cea80091 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -5,7 +5,7 @@ """ import os -from .basic import Booster, Dataset, register_logger +from .basic import Booster, Dataset, register_logger, Sequence from .callback import early_stopping, print_evaluation, record_evaluation, reset_parameter from .engine import CVBooster, cv, train @@ -29,7 +29,7 @@ with open(os.path.join(dir_path, 'VERSION.txt')) as version_file: __version__ = version_file.read().strip() -__all__ = ['Dataset', 'Booster', 'CVBooster', +__all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence', 'register_logger', 'train', 'cv', 'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker', diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 60980a6e1df5..4f449fc2ff23 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1,5 +1,6 @@ # coding: utf-8 """Wrapper for C API of LightGBM.""" +import abc import ctypes import json import os @@ -9,7 +10,7 @@ from functools import wraps from logging import Logger from tempfile import NamedTemporaryFile -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, List, Set, Union, Iterable import numpy as np import scipy.sparse @@ -18,6 +19,11 @@ from .libpath import find_lib_path +# TODO: how to keep the default values the same with C++ config.h +ZERO_THRESHOLD = 1e-35 +DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000 + + class _DummyLogger: def info(self, msg): print(msg) @@ -591,6 +597,61 @@ def _load_pandas_categorical(file_name=None, model_str=None): return None +class Sequence(object): + """Generic data access interface. + + Object should support the following operations: + + # Get total row number. + >>> len(seq) + # Random access by row index. Use for data sampling. + >>> seq[10] + # Range data access. Use to read data in batch when constructing Dataset. + >>> seq[0:100] + # Optionally specify batch_size to control range data read size. + >>> seq.batch_size + + With random access, data sampling does not need to go through all data. + With range data access, there's no need to read all data into memory thus + reduce memory usage. + """ + __metaclass__ = abc.ABCMeta + + batch_size = 4096 # Defaults to read 4K rows in each batch. + + @abc.abstractmethod + def __getitem__(self, idx): # type: (Union[int, slice]) -> np.ndarray + """Return data for given row index. + + A basic implementation should look like this: + + .. code-block:: python + + if isinstance(idx, numbers.Integral): + return self.__get_one_line__(idx) + elif isinstance(idx, slice): + return np.stack(self.__get_one_line__(i) for i in range(idx.start, idx.stop)) + else: + raise TypeError("Sequence index must be integer or slice, got {}".format(type(idx))) + + Parameters + ---------- + idx : int, slice[int] + Item index. 
+ + Returns + ------- + result : numpy 1-D array, numpy 2-D array + 1-D array if idx is int, 2-D array if idx is slice + """ + raise NotImplementedError("remove this line if subclassing") + + @abc.abstractmethod + def __len__(self): # type: () -> int + """Return row count of this sequence """ + raise NotImplementedError + + class _InnerPredictor: """_InnerPredictor of LightGBM. @@ -1119,9 +1180,6 @@ def __del__(self): except AttributeError: pass - # TODO how to keep the default value the same with C++ config.h - DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000 - def create_sample_indices(self, total_nrow): """Create sample indices for the given parameter of the Dataset. @@ -1137,8 +1195,8 @@ def create_sample_indices(self, total_nrow): Indices for sampled data. """ param_str = param_dict_to_str(self.params) - sample_cnt = self.params.get("bin_construct_sample_cnt", - self.DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT) + sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT + sample_cnt = min(sample_cnt, total_nrow) indices = np.zeros(sample_cnt, dtype=np.int32) ptr_data, _, _ = c_int_array(indices) @@ -1154,10 +1212,9 @@ def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): Parameters ---------- - sample_data: 2d numpy array (dtype must be double, in F order) - Sample data value in row major order. - Note: each column contains len(sample_indices[i]) number of values. - sample_indices: List[List[int]] + sample_data: List[np.array[float64]] + Sample data for each column + sample_indices: List[np.array[int]] Sample data row index for each column. sample_cnt: int Number of samples. @@ -1169,14 +1226,13 @@ def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): self : Dataset Constructed Dataset object. 
""" - if len(sample_data.shape) != 2: - raise ValueError('sample_data numpy.ndarray must be 2 dimensional') - assert sample_data.dtype == np.double, "sample data type {} is not double".format(sample_data.dtype) - assert sample_data.shape[1] == len(sample_indices), "#sample data column != #column indices" + assert len(sample_data) == len(sample_indices), "#sample data column != #column indices" ncol = len(sample_indices) for i in range(ncol): + if sample_data[i].dtype != np.double: + raise ValueError("sample data type {} is not double".format(sample_data.dtype)) if sample_indices[i].dtype != np.int32: raise ValueError("sample_indices[{}] type {} is not int32".format(i, sample_indices[i].dtype)) @@ -1187,7 +1243,7 @@ def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): # each int* points to start of indices for each column indices_col_ptr = (ctypes.POINTER(ctypes.c_int32) * ncol)() for i in range(ncol): - sample_col_ptr[i] = c_float_array(sample_data[:, i])[0] + sample_col_ptr[i] = c_float_array(sample_data[i])[0] indices_col_ptr[i] = c_int_array(sample_indices[i])[0] num_per_col = np.array([len(d) for d in sample_indices], dtype=np.int32) @@ -1377,8 +1433,15 @@ def _lazy_init(self, data, label=None, reference=None, self.__init_from_csc(data, params_str, ref_dataset) elif isinstance(data, np.ndarray): self.__init_from_np2d(data, params_str, ref_dataset) - elif isinstance(data, list) and len(data) > 0 and all(isinstance(x, np.ndarray) for x in data): - self.__init_from_list_np2d(data, params_str, ref_dataset) + elif isinstance(data, Sequence): + self.__init_from_seqs([data], params_str, ref_dataset) + elif isinstance(data, list) and len(data) > 0: + if all(isinstance(x, np.ndarray) for x in data): + self.__init_from_list_np2d(data, params_str, ref_dataset) + elif all(isinstance(x, Sequence) for x in data): + self.__init_from_seqs(data, params_str, ref_dataset) + else: + raise TypeError('Data list can only be of ndarray or Sequence') elif isinstance(data, dt_DataTable): self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset) else: @@ -1406,6 +1469,74 @@ def _lazy_init(self, data, label=None, reference=None, # set feature names return self.set_feature_name(feature_name) + def __yield_row_from(self, seqs, indices): + # type: (List[Sequence], Iterable[int]) -> ... + offset = 0 + seq_id = 0 + seq = seqs[seq_id] + for row_id in indices: + assert row_id >= offset, "sample indices are expected to be monotonic" + while row_id >= offset + len(seq): + offset += len(seq) + seq_id += 1 + seq = seqs[seq_id] + id_in_seq = row_id - offset + row = seq[int(id_in_seq)] + yield row + + def __sample(self, seqs, total_nrow): + # type: (List[Sequence], int, int) -> (np.ndarray, List[np.ndarray]) + """Data Sampling. + + Mimics behavior in c_api.cpp:LGBM_DatasetCreateFromMats() + + Returns + ------- + sampled_rows, sampled_row_indices + """ + indices = self.create_sample_indices(total_nrow) + + # Select sampled rows, transpose to column order. 
+ sampled = [row for row in self.__yield_row_from(seqs, indices)] + sampled = np.array(sampled) + sampled = sampled.T + + filtered = [] + filtered_idx = [] + sampled_row_range = np.arange(len(indices), dtype=np.int32) + for col in sampled: + col_predicate = (np.abs(col) > ZERO_THRESHOLD) | np.isnan(col) + filtered_col = col[col_predicate] + filtered_row_idx = sampled_row_range[col_predicate] + + filtered.append(filtered_col) + filtered_idx.append(filtered_row_idx) + + return filtered, filtered_idx + + def __init_from_seqs(self, seqs, params_str, ref_dataset): + # type: (List[Sequence], str, Dataset) -> None + """ + Initialize data from a Sequence object. + Sequence: Generic Data Access Object + Supports random access and access by batch if properly defined by user + """ + total_nrow = sum(len(seq) for seq in seqs) + ncol = len(seqs[0][0]) + sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT + sample_cnt = min(sample_cnt, total_nrow) + + sample_data, col_indices = self.__sample(seqs, total_nrow) + self.init_from_sample(sample_data, col_indices, sample_cnt, total_nrow) + + for seq in seqs: + nrow = len(seq) + batch_size = seq.batch_size or Sequence.batch_size + for start in range(0, nrow, batch_size): + end = min(start+batch_size, nrow) + self.push_rows(seq[start:end]) + + def __init_from_np2d(self, mat, params_str, ref_dataset): """Initialize data from a 2-D numpy matrix.""" if len(mat.shape) != 2: diff --git a/src/c_api.cpp b/src/c_api.cpp index fc1881ebd303..28bb7c31e944 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -960,7 +960,7 @@ int LGBM_DatasetPushRows(DatasetHandle dataset, int32_t nrow, int32_t ncol, int32_t start_row) { - Log::Info("start_row %d nrow: %d", start_row, nrow); + Log::Debug("DatasetPushRows start_row: %d nrow: %d", start_row, nrow); API_BEGIN(); auto p_dataset = reinterpret_cast(dataset); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1); @@ -978,7 +978,7 @@ int LGBM_DatasetPushRows(DatasetHandle dataset, } OMP_THROW_EX(); if (start_row + nrow == p_dataset->num_data()) { - Log::Info("Dataset PushRows FinishLoad"); + Log::Debug("DatasetPushRows FinishLoad"); p_dataset->FinishLoad(); } API_END(); diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index b92c7998e554..9faab0483037 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -1,6 +1,7 @@ # coding: utf-8 import os +import filecmp import numpy as np import pytest from scipy import sparse @@ -9,9 +10,14 @@ import lightgbm as lgb from lightgbm.compat import PANDAS_INSTALLED, pd_Series - from .utils import load_breast_cancer +def rm_files(files): + if not isinstance(files, list): + files = [files] + for file in files: + if os.path.exists(file): + os.remove(file) def test_basic(tmp_path): X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), @@ -89,6 +95,88 @@ def test_basic(tmp_path): bst.predict, tname) +class NumpySequence(lgb.Sequence): + def __init__(self, ndarray): + self.ndarray = ndarray + + def __getitem__(self, idx): + # The simple implementation is just a single "return self.ndarray[idx]" + # The following is for demo and testing purpose. 
+ if isinstance(idx, int): + return self.__get_one_line__(idx) + elif isinstance(idx, slice): + if not (idx.step is None or idx.step is 1): + raise NotImplementedError("No need to implement, caller will not set step by now") + return self.ndarray[idx.start: idx.stop] + else: + raise TypeError("Sequence Index must be an integer/list/slice, got {}".format(type(idx))) + + def __get_one_line__(self, idx): + return self.ndarray[idx] + + def __len__(self): + return len(self.ndarray) + + +@pytest.mark.parametrize('sample_count', [2, 5]) +@pytest.mark.parametrize('batch_size', [3, 20, None]) +@pytest.mark.parametrize('include_0', [False, True]) +@pytest.mark.parametrize('include_nan', [False, True]) +@pytest.mark.parametrize('num_seq', [1, 3]) +def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq): + rm_files(["seq.truth.bin", "seq.seq.bin"]) + + params = { + "bin_construct_sample_cnt": sample_count, + } + + nrow = 31 + half_nrow = nrow//2 + ncol = 11 + data = np.arange(nrow*ncol).reshape((nrow, ncol)).astype('float64') + + # total col + if include_0: + data[:, 0] = 0 + if include_nan: + data[:, 1] = np.nan + + # half col + if include_nan: + # nan col + data[0:half_nrow, 2] = np.nan + if include_0: + # 0 col + data[0:half_nrow, 3] = 0 + + # nan + 0 col + if include_nan: + data[0:half_nrow, 4] = np.nan + if include_0: + data[half_nrow:-2, 4] = 0 + + # X, Y split + X = data[:, :-1] + Y = data[:, -1] + # truth + ds = lgb.Dataset(X, label=Y, params=params) + ds.save_binary(str(tmpdir/"seq.truth.bin")) + # seq + if num_seq == 1: + seqs = NumpySequence(X) + else: + seqs = [] + seq_size = nrow//num_seq + for start in range(0, nrow, seq_size): + end = min(start + seq_size, nrow) + seq = NumpySequence(X[start:end]) + seq.batch_size = batch_size + seqs.append(seq) + ds = lgb.Dataset(seqs, label=Y, params=params) + ds.save_binary(str(tmpdir/"seq.seq.bin")) + assert filecmp.cmp(tmpdir/"seq.truth.bin", tmpdir/"seq.seq.bin") + + def test_chunked_dataset(): X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2) From 7406a5848b1908918508a18314ec3439241be7f7 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Sun, 21 Mar 2021 17:18:04 +0800 Subject: [PATCH 03/70] [python-package] example: create Dataset from multiple HDF5 file. 
--- examples/python-guide/README.md | 3 + .../python-guide/dataset_from_multi_hdf5.py | 102 ++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 examples/python-guide/dataset_from_multi_hdf5.py diff --git a/examples/python-guide/README.md b/examples/python-guide/README.md index 08ded17ab559..a938749bd4c3 100644 --- a/examples/python-guide/README.md +++ b/examples/python-guide/README.md @@ -61,3 +61,6 @@ Examples include: - Plot split value histogram - Plot one specified tree - Plot one specified tree with Graphviz +- [dataset_from_multi_hdf5.py](https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/dataset_from_multi_hdf5.py) + - Construct Dataset from multiple HDF5 file + - Avoids loading all data into memory diff --git a/examples/python-guide/dataset_from_multi_hdf5.py b/examples/python-guide/dataset_from_multi_hdf5.py new file mode 100644 index 000000000000..8f24ad29b6b3 --- /dev/null +++ b/examples/python-guide/dataset_from_multi_hdf5.py @@ -0,0 +1,102 @@ +import h5py +import numpy as np +import pandas as pd + +import lightgbm as lgb + + +class HDFSequence(lgb.Sequence): + def __init__(self, hdf_dataset, batch_size): + """ + Parameters + ---------- + hdf_dataset: h5py.Dataset + dataset in HDF5 file + batch_size: int + when reading data to construct lightgbm Dataset, each read reads batch_size rows + """ + # We can also open HDF5 file once and get access to + self.data = hdf_dataset + self.batch_size = batch_size + + def __getitem__(self, idx): + return self.data[idx] + + def __len__(self): + return len(self.data) + + +def create_dataset_from_multiple_hdf(input_flist, batch_size): + data = [] + ylist = [] + for f in input_flist: + f = h5py.File(f, 'r') + data.append(HDFSequence(f['X'], batch_size)) + ylist.append(f['Y'][:]) + + # params = { + # 'bin_construct_sample_cnt': 200000, + # 'max_bin': 255, + # } + params = None + y = np.concatenate(ylist) + dataset = lgb.Dataset(data, label=y, params=params) + # With binary dataset created, we can use either Python API or cmdline version to train. + # + # Note: in order to create exactly the same dataset with the one created in simple_example.py, we need + # to modify simple_example.py to pass numpy array instead of pandas DataFrame to Dataset constructor. + # The reason is that DataFrame column names will be used in Dataset. For a DataFrame with Int64Index + # as columns, Dataset will use column names like ["0", "1", "2", ...]. While for numpy array, column names + # are using the default one assigned in C++ code (dataset_loader.cpp), like ["Column_0", "Column_1", ...]. + dataset.save_binary('regression.train.from_hdf.bin') + + +def save2hdf(input_data, fname, batch_size): + """Store numpy array to HDF5 file. + + Please note chunk size settings in the implementation for I/O performance optimization. + """ + with h5py.File(fname, 'w') as f: + for name, data in input_data.items(): + nrow, ncol = data.shape + if ncol == 1: + # Y has a single column and we read it in single shot. So store it as an 1-d array. + chunk = (nrow, ) + data = data.values.flatten() + else: + # We use random access for data sampling when creating LightGBM Dataset from Sequence. + # When accessing any element in a HDF5 chunk, it's read entirely. + # To save I/O for sampling, we should keep number of total chunks much larger than sample count. + # Here we are just creating a chunk size that matches with batch_size. + # + # Also note that the data is stored in row major order to avoid extra copy when passing to + # lightgbm Dataset. 
+ chunk = (batch_size, ncol) + f.create_dataset(name, data=data, chunks=chunk, compression='lzf') + + +def generate_hdf(input_fname, output_basename, batch_size): + # Save to 2 HDF5 files for demonstration. + df = pd.read_csv(input_fname, header=None, sep='\t') + + mid = len(df) // 2 + df1 = df.iloc[:mid] + df2 = df.iloc[mid:] + + # We can store multiple dataset inside a single HDF5 file. + # Separating X and Y for choosing best chunk size for data loading. + save2hdf({'Y': df1.iloc[:, :1], 'X': df1.iloc[:, 1:]}, '{}1.h5'.format(output_basename), batch_size) + save2hdf({'Y': df2.iloc[:, :1], 'X': df2.iloc[:, 1:]}, '{}2.h5'.format(output_basename), batch_size) + + +def main(): + batch_size = 64 + generate_hdf('../regression/regression.train', 'regression', 64) + create_dataset_from_multiple_hdf( + ['regression1.h5', 'regression2.h5'], + batch_size=batch_size, + ) + + +if __name__ == '__main__': + main() From 1abaa47519d265f2f964fb848fad8b70a49db360 Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 11:00:16 +0800 Subject: [PATCH 04/70] fix: revert is_class implementation for seq --- python-package/lightgbm/basic.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 4f449fc2ff23..2e4456bdbb0a 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -619,6 +619,12 @@ class Sequence(object): batch_size = 4096 # Defaults to read 4K rows in each batch. + @staticmethod + def is_class(obj): + if isinstance(obj, list): + return False + return hasattr(obj, "__getitem__") and hasattr(obj, "__len__") + @abc.abstractmethod def __getitem__(self, idx): # type: (Union[int, slice]) -> np.ndarray """Return data for given row index. 
@@ -1433,15 +1439,15 @@ def _lazy_init(self, data, label=None, reference=None, self.__init_from_csc(data, params_str, ref_dataset) elif isinstance(data, np.ndarray): self.__init_from_np2d(data, params_str, ref_dataset) - elif isinstance(data, Sequence): - self.__init_from_seqs([data], params_str, ref_dataset) elif isinstance(data, list) and len(data) > 0: if all(isinstance(x, np.ndarray) for x in data): self.__init_from_list_np2d(data, params_str, ref_dataset) - elif all(isinstance(x, Sequence) for x in data): + elif all(Sequence.is_class(x) for x in data): self.__init_from_seqs(data, params_str, ref_dataset) else: raise TypeError('Data list can only be of ndarray or Sequence') + elif Sequence.is_class(data): + self.__init_from_seqs([data], params_str, ref_dataset) elif isinstance(data, dt_DataTable): self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset) else: From 3150ebbc4f46ca500772b0836843322220f6cecc Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 11:07:11 +0800 Subject: [PATCH 05/70] fix: unwanted memory view reference for seq --- python-package/lightgbm/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 2e4456bdbb0a..887bec421812 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1488,7 +1488,7 @@ def __yield_row_from(self, seqs, indices): seq = seqs[seq_id] id_in_seq = row_id - offset row = seq[int(id_in_seq)] - yield row + yield row if row.flags['OWNDATA'] else row.copy() def __sample(self, seqs, total_nrow): # type: (List[Sequence], int, int) -> (np.ndarray, List[np.ndarray]) From 54628f4ef516d4ba4c51f1957645bca82e6246a2 Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 12:20:18 +0800 Subject: [PATCH 06/70] fix: seq is_class accepts sklearn matrices --- python-package/lightgbm/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 887bec421812..05301975541e 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -621,7 +621,7 @@ class Sequence(object): @staticmethod def is_class(obj): - if isinstance(obj, list): + if isinstance(obj, list) or hasattr(obj, "getformat"): return False return hasattr(obj, "__getitem__") and hasattr(obj, "__len__") From da538d175c78c1e9e92804b24f43c0cb5ca4b08d Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 14:24:45 +0800 Subject: [PATCH 07/70] fix: requirements for example --- .ci/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.ci/test.sh b/.ci/test.sh index db4427e6dbaa..f9cac622f181 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -226,6 +226,7 @@ import matplotlib\ matplotlib.use\(\"Agg\"\)\ ' plot_example.py # prevent interactive window mode sed -i'.bak' 's/graph.render(view=True)/graph.render(view=False)/' plot_example.py + conda install -q -y -n $CONDA_ENV h5py # requirements for example for f in *.py **/*.py; do python $f || exit -1; done # run all examples cd $BUILD_DIRECTORY/examples/python-guide/notebooks conda install -q -y -n $CONDA_ENV ipywidgets notebook From 0c086b9642c4cc7ad2b4dd893cd565f20c842959 Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 14:32:55 +0800 Subject: [PATCH 08/70] fix: pycode --- python-package/lightgbm/basic.py | 5 ++--- tests/python_package_test/test_basic.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python-package/lightgbm/basic.py 
b/python-package/lightgbm/basic.py index 05301975541e..c4a4c1e67389 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1178,7 +1178,7 @@ def __init__(self, data, label=None, reference=None, self.feature_penalty = None self.monotone_constraints = None self.version = 0 - self._start_row = 0 # Used when pushing rows one by one. + self._start_row = 0 # Used when pushing rows one by one. def __del__(self): try: @@ -1539,10 +1539,9 @@ def __init_from_seqs(self, seqs, params_str, ref_dataset): nrow = len(seq) batch_size = seq.batch_size or Sequence.batch_size for start in range(0, nrow, batch_size): - end = min(start+batch_size, nrow) + end = min(start + batch_size, nrow) self.push_rows(seq[start:end]) - def __init_from_np2d(self, mat, params_str, ref_dataset): """Initialize data from a 2-D numpy matrix.""" if len(mat.shape) != 2: diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 9faab0483037..585feb0855b3 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -12,6 +12,7 @@ from lightgbm.compat import PANDAS_INSTALLED, pd_Series from .utils import load_breast_cancer + def rm_files(files): if not isinstance(files, list): files = [files] @@ -19,6 +20,7 @@ def rm_files(files): if os.path.exists(file): os.remove(file) + def test_basic(tmp_path): X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2) @@ -131,9 +133,9 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ } nrow = 31 - half_nrow = nrow//2 + half_nrow = nrow // 2 ncol = 11 - data = np.arange(nrow*ncol).reshape((nrow, ncol)).astype('float64') + data = np.arange(nrow * ncol).reshape((nrow, ncol)).astype('float64') # total col if include_0: @@ -160,21 +162,21 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ Y = data[:, -1] # truth ds = lgb.Dataset(X, label=Y, params=params) - ds.save_binary(str(tmpdir/"seq.truth.bin")) + ds.save_binary(str(tmpdir / "seq.truth.bin")) # seq if num_seq == 1: seqs = NumpySequence(X) else: seqs = [] - seq_size = nrow//num_seq + seq_size = nrow // num_seq for start in range(0, nrow, seq_size): end = min(start + seq_size, nrow) seq = NumpySequence(X[start:end]) seq.batch_size = batch_size seqs.append(seq) ds = lgb.Dataset(seqs, label=Y, params=params) - ds.save_binary(str(tmpdir/"seq.seq.bin")) - assert filecmp.cmp(tmpdir/"seq.truth.bin", tmpdir/"seq.seq.bin") + ds.save_binary(str(tmpdir / "seq.seq.bin")) + assert filecmp.cmp(tmpdir / "seq.truth.bin", tmpdir / "seq.seq.bin") def test_chunked_dataset(): From e41980e241ba3b29f2217c5fab77106be21d366e Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 14:47:15 +0800 Subject: [PATCH 09/70] feat: print static code linting stage --- .ci/test.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.ci/test.sh b/.ci/test.sh index f9cac622f181..b5c91686ed46 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -65,9 +65,13 @@ if [[ $TASK == "lint" ]]; then "r-lintr>=2.0" pip install --user cpplint isort mypy echo "Linting Python code" + echo "..pycodestyle" pycodestyle --ignore=E501,W503 --exclude=./.nuget,./external_libs . || exit -1 + echo "..pydocstyle" pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1 + echo "..isort" isort . 
--check-only || exit -1 + echo "..mypy" mypy --ignore-missing-imports python-package/ || true echo "Linting R code" Rscript ${BUILD_DIRECTORY}/.ci/lint_r_code.R ${BUILD_DIRECTORY} || exit -1 From 9cbb292b670a37ae8d7acbe7779149aa3ce25eac Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 14:52:24 +0800 Subject: [PATCH 10/70] fix: linting: avoid shell str regex conversion --- .ci/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/test.sh b/.ci/test.sh index b5c91686ed46..713312940e1c 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -68,7 +68,7 @@ if [[ $TASK == "lint" ]]; then echo "..pycodestyle" pycodestyle --ignore=E501,W503 --exclude=./.nuget,./external_libs . || exit -1 echo "..pydocstyle" - pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1 + pydocstyle --convention=numpy --add-ignore=D105 --match-dir='^(?!^external_libs|test|example).*' --match='(?!^test_|setup).*\.py' . || exit -1 echo "..isort" isort . --check-only || exit -1 echo "..mypy" From afe78577cf2ad1b37c11b5edfac105fae0f41dd0 Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 14:58:09 +0800 Subject: [PATCH 11/70] code style: doc style --- python-package/lightgbm/basic.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index c4a4c1e67389..fcfb9262f1e6 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -598,7 +598,8 @@ def _load_pandas_categorical(file_name=None, model_str=None): class Sequence(object): - """Generic data access interface. + """ + Generic data access interface. Object should support the following operations: @@ -615,12 +616,23 @@ class Sequence(object): With range data access, there's no need to read all data into memory thus reduce memory usage. """ + __metaclass__ = abc.ABCMeta batch_size = 4096 # Defaults to read 4K rows in each batch. @staticmethod def is_class(obj): + """Check if objection is instance of class Sequence. + + Args: + ------- + obj ([any]): object to be checked + + Returns + ------- + [bool]: is Sequence class + """ if isinstance(obj, list) or hasattr(obj, "getformat"): return False return hasattr(obj, "__getitem__") and hasattr(obj, "__len__") @@ -654,7 +666,7 @@ def __getitem__(self, idx): # type: (Union[int, slice]) -> np.ndarray @abc.abstractmethod def __len__(self): # type: () -> int - """Return row count of this sequence """ + """Return row count of this sequence.""" raise NotImplementedError @@ -1492,7 +1504,7 @@ def __yield_row_from(self, seqs, indices): def __sample(self, seqs, total_nrow): # type: (List[Sequence], int, int) -> (np.ndarray, List[np.ndarray]) - """Data Sampling. + """Sample data from seqs. Mimics behavior in c_api.cpp:LGBM_DatasetCreateFromMats() @@ -1524,6 +1536,7 @@ def __init_from_seqs(self, seqs, params_str, ref_dataset): # type: (List[Sequence], str, Dataset) -> None """ Initialize data from a Sequence object. 
+ Sequence: Generic Data Access Object Supports random access and access by batch if properly defined by user """ From edc5a3b0dba1856b9a7150bd9218a2c5bf13510e Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 15:08:10 +0800 Subject: [PATCH 12/70] code style: isort --- python-package/lightgbm/__init__.py | 2 +- python-package/lightgbm/basic.py | 3 +-- tests/python_package_test/test_basic.py | 3 ++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py index e429cea80091..dc2143674be2 100644 --- a/python-package/lightgbm/__init__.py +++ b/python-package/lightgbm/__init__.py @@ -5,7 +5,7 @@ """ import os -from .basic import Booster, Dataset, register_logger, Sequence +from .basic import Booster, Dataset, Sequence, register_logger from .callback import early_stopping, print_evaluation, record_evaluation, reset_parameter from .engine import CVBooster, cv, train diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index fcfb9262f1e6..accf4ad2b68f 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -10,7 +10,7 @@ from functools import wraps from logging import Logger from tempfile import NamedTemporaryFile -from typing import Any, Dict, List, Set, Union, Iterable +from typing import Any, Dict, Iterable, List, Set, Union import numpy as np import scipy.sparse @@ -18,7 +18,6 @@ from .compat import PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_DataFrame, pd_Series from .libpath import find_lib_path - # TODO: how to keep the default values the same with C++ config.h ZERO_THRESHOLD = 1e-35 DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000 diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 585feb0855b3..5fe88fd5c451 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -1,7 +1,7 @@ # coding: utf-8 +import filecmp import os -import filecmp import numpy as np import pytest from scipy import sparse @@ -10,6 +10,7 @@ import lightgbm as lgb from lightgbm.compat import PANDAS_INSTALLED, pd_Series + from .utils import load_breast_cancer From 45122fa71567f0be2375cedea5b0327c1fcdaba0 Mon Sep 17 00:00:00 2001 From: Willian Z Date: Mon, 29 Mar 2021 15:44:49 +0800 Subject: [PATCH 13/70] fix ci dependency: h5py on windows --- .ci/test_windows.ps1 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1 index 20e7bf654a68..0254f5f33980 100644 --- a/.ci/test_windows.ps1 +++ b/.ci/test_windows.ps1 @@ -49,7 +49,7 @@ if ($env:TASK -eq "swig") { Exit 0 } -conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy ; Check-Output $? +conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy h5py ; Check-Output $? 
if ($env:TASK -eq "regular") { mkdir $env:BUILD_SOURCESDIRECTORY/build; cd $env:BUILD_SOURCESDIRECTORY/build From bb3e73d4e99d2109ce9cdf30a2d2ff35f290bb1b Mon Sep 17 00:00:00 2001 From: Willian Z Date: Fri, 16 Apr 2021 10:29:44 +0800 Subject: [PATCH 14/70] [py] remove rm files in test seq https://github.com/microsoft/LightGBM/pull/4089#discussion_r612929623 --- tests/python_package_test/test_basic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 5fe88fd5c451..322c6ac94db1 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -127,8 +127,6 @@ def __len__(self): @pytest.mark.parametrize('include_nan', [False, True]) @pytest.mark.parametrize('num_seq', [1, 3]) def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq): - rm_files(["seq.truth.bin", "seq.seq.bin"]) - params = { "bin_construct_sample_cnt": sample_count, } From 0332d9da35aefe2be44659db148ef980863f1fd3 Mon Sep 17 00:00:00 2001 From: Willian Z Date: Fri, 16 Apr 2021 10:56:43 +0800 Subject: [PATCH 15/70] docs(python): init_from_sample summary https://github.com/microsoft/LightGBM/pull/4089#discussion_r612903389 --- python-package/lightgbm/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index accf4ad2b68f..c176607f5ed0 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1225,7 +1225,7 @@ def create_sample_indices(self, total_nrow): return indices def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): - """Get the used parameters in the Dataset. + """Create Dataset from sampled data structures. Parameters ---------- From 34f3d74d27e0a57bfda0201ba4bb5701b1d94c99 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 22 Apr 2021 11:12:49 +0800 Subject: [PATCH 16/70] remove dataset dump sample data debugging code. --- include/LightGBM/c_api.h | 4 +-- include/LightGBM/dataset_loader.h | 2 +- src/c_api.cpp | 13 ++++----- src/io/dataset_loader.cpp | 46 +------------------------------ 4 files changed, 9 insertions(+), 56 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 2b0294ec37c1..fd01123c6824 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -222,8 +222,8 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, * \return 0 when succeed, -1 when failure happens */ LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t total_nrow, - const char* parameters, - void* out); + const char* parameters, + void* out); /*! * \brief Create dataset from dense matrix. diff --git a/include/LightGBM/dataset_loader.h b/include/LightGBM/dataset_loader.h index cfcb716e4eb0..e72dd4910804 100644 --- a/include/LightGBM/dataset_loader.h +++ b/include/LightGBM/dataset_loader.h @@ -29,7 +29,7 @@ class DatasetLoader { LIGHTGBM_EXPORT Dataset* ConstructFromSampleData(double** sample_values, int** sample_indices, int num_col, const int* num_per_col, - size_t total_sample_size, data_size_t num_data, const std::string& dump_filename); + size_t total_sample_size, data_size_t num_data); /*! 
\brief Disable copy */ DatasetLoader& operator=(const DatasetLoader&) = delete; diff --git a/src/c_api.cpp b/src/c_api.cpp index 28bb7c31e944..d65c52a1d59d 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -938,7 +938,7 @@ int LGBM_DatasetCreateFromSampledColumn(double** sample_data, DatasetLoader loader(config, nullptr, 1, nullptr); *out = loader.ConstructFromSampleData(sample_data, sample_indices, ncol, num_per_col, num_sample_row, - static_cast(num_total_row), ""); + static_cast(num_total_row)); API_END(); } @@ -960,7 +960,6 @@ int LGBM_DatasetPushRows(DatasetHandle dataset, int32_t nrow, int32_t ncol, int32_t start_row) { - Log::Debug("DatasetPushRows start_row: %d nrow: %d", start_row, nrow); API_BEGIN(); auto p_dataset = reinterpret_cast(dataset); auto get_row_fun = RowFunctionFromDenseMatric(data, nrow, ncol, data_type, 1); @@ -978,7 +977,6 @@ int LGBM_DatasetPushRows(DatasetHandle dataset, } OMP_THROW_EX(); if (start_row + nrow == p_dataset->num_data()) { - Log::Debug("DatasetPushRows FinishLoad"); p_dataset->FinishLoad(); } API_END(); @@ -1017,7 +1015,6 @@ int LGBM_DatasetPushRowsByCSR(DatasetHandle dataset, API_END(); } - int LGBM_DatasetCreateFromMat(const void* data, int data_type, int32_t nrow, @@ -1125,7 +1122,7 @@ int LGBM_DatasetCreateFromMats(int32_t nmat, Vector2Ptr(&sample_idx).data(), ncol, VectorSize(sample_values).data(), - sample_cnt, total_nrow, "")); + sample_cnt, total_nrow)); } else { ret.reset(new Dataset(total_nrow)); ret->CreateValid( @@ -1204,7 +1201,7 @@ int LGBM_DatasetCreateFromCSR(const void* indptr, Vector2Ptr(&sample_idx).data(), static_cast(num_col), VectorSize(sample_values).data(), - sample_cnt, nrow, "")); + sample_cnt, nrow)); } else { ret.reset(new Dataset(nrow)); ret->CreateValid( @@ -1275,7 +1272,7 @@ int LGBM_DatasetCreateFromCSRFunc(void* get_row_funptr, Vector2Ptr(&sample_idx).data(), static_cast(num_col), VectorSize(sample_values).data(), - sample_cnt, nrow, "")); + sample_cnt, nrow)); } else { ret.reset(new Dataset(nrow)); ret->CreateValid( @@ -1351,7 +1348,7 @@ int LGBM_DatasetCreateFromCSC(const void* col_ptr, Vector2Ptr(&sample_idx).data(), static_cast(sample_values.size()), VectorSize(sample_values).data(), - sample_cnt, nrow, "")); + sample_cnt, nrow)); } else { ret.reset(new Dataset(nrow)); ret->CreateValid( diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index 87aac784533d..2d2c4d622b1c 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -617,38 +617,10 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b return dataset.release(); } -// To help verify whether sample data is changed when using different language bindings. 
-static void DumpSampleData(const std::string& sample_filename, double** sample_values, - int** sample_indices, int num_col, const int* num_per_col, - size_t total_sample_size, data_size_t num_data) { - Log::Info("dump sample data to %s", sample_filename.c_str()); - std::ofstream out(sample_filename); - - out << "num_col: " << num_col << "\n"; - out << "total_sample_size: " << total_sample_size << "\n"; - out << "num_data: " << num_data << "\n"; - - out << "num_per_col:\n"; - for (int i = 0; i < num_col; ++i) { - out << " c:" << i << "=" << num_per_col[i] << "\n"; - } - - out << "sample data:\n"; - for (int i = 0; i < num_col; ++i) { - out << " c:" << i << "\n"; - for (int j = 0; j < num_per_col[i]; ++j) { - out << " r:" << sample_indices[i][j] << "=" << sample_values[i][j] << "\n"; - } - } - out << "\n"; -} Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values, int** sample_indices, int num_col, const int* num_per_col, - size_t total_sample_size, data_size_t num_data, const std::string& dump_filename) { - if (dump_filename != "") { - DumpSampleData(dump_filename, sample_values, sample_indices, num_col, num_per_col, total_sample_size, num_data); - } + size_t total_sample_size, data_size_t num_data) { CheckSampleSize(total_sample_size, static_cast(num_data)); int num_total_features = num_col; if (Network::num_machines() > 1) { @@ -1004,22 +976,6 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, dataset->feature_groups_.clear(); dataset->num_total_features_ = std::max(static_cast(sample_values.size()), parser->NumFeatures()); - - /* - int num_col = static_cast(sample_values.size()); - std::vector sample_values_ptr(num_col); - std::vector sample_indices_ptr(num_col); - std::vector num_per_col(num_col); - for (size_t i = 0; i < sample_values.size(); ++i) { - sample_values_ptr[i] = sample_values[i].data(); - sample_indices_ptr[i] = sample_indices[i].data(); - num_per_col[i] = static_cast(sample_indices[i].size()); - } - DumpSampleData("lgbm_sample_from_text.txt", sample_values_ptr.data(), - sample_indices_ptr.data(), num_col, num_per_col.data(), - sample_data.size(), dataset->num_data()); - */ - if (num_machines > 1) { dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_); } From 2ba0261c9e25a0dce62a3402759e068c98194e01 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Tue, 27 Apr 2021 08:41:08 +0800 Subject: [PATCH 17/70] remove typo fix. Create separate PR for this. 
--- src/io/dataset.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index e5cabe682caa..51879351178e 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -932,7 +932,7 @@ bool Dataset::GetIntField(const char* field_name, data_size_t* out_len, void Dataset::SaveBinaryFile(const char* bin_filename) { if (bin_filename != nullptr && std::string(bin_filename) == data_filename_) { - Log::Warning("Binary file %s already exists", bin_filename); + Log::Warning("Bianry file %s already exists", bin_filename); return; } // if not pass a filename, just append ".bin" of original file From 51ba5b0307ef684d4f743455a2d328bc14ffb635 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 29 Apr 2021 08:02:55 +0800 Subject: [PATCH 18/70] fix typo in src/c_api.cpp Co-authored-by: James Lamb --- src/c_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index d65c52a1d59d..c2a7ccebdbeb 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1049,7 +1049,7 @@ int LGBM_SampleIndices(int32_t total_nrow, // Sample count, random seed etc. should be provided in parameters. API_BEGIN(); if (out == nullptr) { - Log::Fatal("sample indicies output is nullptr"); + Log::Fatal("sample indices output is nullptr"); } auto param = Config::Str2Map(parameters); Config config; From 83db0a74bd132c665cd4aa01bd321aedcac22dc5 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Fri, 30 Apr 2021 16:42:59 +0800 Subject: [PATCH 19/70] style(linting): py3 type hint for seq --- python-package/lightgbm/basic.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index c176607f5ed0..e6d5b6ce144b 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -10,7 +10,7 @@ from functools import wraps from logging import Logger from tempfile import NamedTemporaryFile -from typing import Any, Dict, Iterable, List, Set, Union +from typing import Any, Dict, Iterable, List, Set, Union, Tuple import numpy as np import scipy.sparse @@ -637,7 +637,7 @@ def is_class(obj): return hasattr(obj, "__getitem__") and hasattr(obj, "__len__") @abc.abstractmethod - def __getitem__(self, idx): # type: (Union[int, slice]) -> np.ndarray + def __getitem__(self, idx: Union[int, slice]) -> np.ndarray: """Return data for given row index. A basic implementation should look like this: @@ -664,7 +664,7 @@ def __getitem__(self, idx): # type: (Union[int, slice]) -> np.ndarray raise NotImplementedError("remove this line if subclassing") @abc.abstractmethod - def __len__(self): # type: () -> int + def __len__(self) -> int: """Return row count of this sequence.""" raise NotImplementedError @@ -1486,8 +1486,7 @@ def _lazy_init(self, data, label=None, reference=None, # set feature names return self.set_feature_name(feature_name) - def __yield_row_from(self, seqs, indices): - # type: (List[Sequence], Iterable[int]) -> ... + def __yield_row_from(self, seqs: List[Sequence], indices: Iterable[int]): offset = 0 seq_id = 0 seq = seqs[seq_id] @@ -1501,8 +1500,7 @@ def __yield_row_from(self, seqs, indices): row = seq[int(id_in_seq)] yield row if row.flags['OWNDATA'] else row.copy() - def __sample(self, seqs, total_nrow): - # type: (List[Sequence], int, int) -> (np.ndarray, List[np.ndarray]) + def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[np.ndarray, List[np.ndarray]]: """Sample data from seqs. 
Mimics behavior in c_api.cpp:LGBM_DatasetCreateFromMats() @@ -1531,8 +1529,7 @@ def __sample(self, seqs, total_nrow): return filtered, filtered_idx - def __init_from_seqs(self, seqs, params_str, ref_dataset): - # type: (List[Sequence], str, Dataset) -> None + def __init_from_seqs(self, seqs: List[Sequence], params_str: str, ref_dataset: 'Dataset'): """ Initialize data from a Sequence object. From 989188195674dff99433917c642d54b7f1220039 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Fri, 30 Apr 2021 17:15:15 +0800 Subject: [PATCH 20/70] test(basic): os.path style path handling --- tests/python_package_test/test_basic.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 322c6ac94db1..0266a876ae83 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -161,7 +161,9 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ Y = data[:, -1] # truth ds = lgb.Dataset(X, label=Y, params=params) - ds.save_binary(str(tmpdir / "seq.truth.bin")) + + ds.save_binary(os.path.join(tmpdir, "seq.truth.bin")) + # seq if num_seq == 1: seqs = NumpySequence(X) @@ -174,8 +176,9 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ seq.batch_size = batch_size seqs.append(seq) ds = lgb.Dataset(seqs, label=Y, params=params) - ds.save_binary(str(tmpdir / "seq.seq.bin")) - assert filecmp.cmp(tmpdir / "seq.truth.bin", tmpdir / "seq.seq.bin") + ds.save_binary(os.path.join(tmpdir, "seq.seq.bin")) + + assert filecmp.cmp(os.path.join(tmpdir, "seq.truth.bin"), os.path.join(tmpdir, "seq.seq.bin")) def test_chunked_dataset(): From 675ed62f299459c370d4c23d4955844b9095a190 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Fri, 30 Apr 2021 17:20:09 +0800 Subject: [PATCH 21/70] Revert "feat: print static code linting stage" This reverts commit 10bd79f7f8258bea8e61c3abb8c9c7e4456a916d. --- .ci/test.sh | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 713312940e1c..f9cac622f181 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -65,13 +65,9 @@ if [[ $TASK == "lint" ]]; then "r-lintr>=2.0" pip install --user cpplint isort mypy echo "Linting Python code" - echo "..pycodestyle" pycodestyle --ignore=E501,W503 --exclude=./.nuget,./external_libs . || exit -1 - echo "..pydocstyle" - pydocstyle --convention=numpy --add-ignore=D105 --match-dir='^(?!^external_libs|test|example).*' --match='(?!^test_|setup).*\.py' . || exit -1 - echo "..isort" + pydocstyle --convention=numpy --add-ignore=D105 --match-dir="^(?!^external_libs|test|example).*" --match="(?!^test_|setup).*\.py" . || exit -1 isort . 
--check-only || exit -1 - echo "..mypy" mypy --ignore-missing-imports python-package/ || true echo "Linting R code" Rscript ${BUILD_DIRECTORY}/.ci/lint_r_code.R ${BUILD_DIRECTORY} || exit -1 From 1d53568c3db7f514243a669fffd48c43f1792e5a Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Thu, 6 May 2021 16:55:21 +0800 Subject: [PATCH 22/70] feat(python): sequence on validation set --- python-package/lightgbm/basic.py | 28 +++++++++++++++++++----- tests/python_package_test/test_basic.py | 29 ++++++++++++++++++++----- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index e6d5b6ce144b..21dc36dbde75 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1224,6 +1224,15 @@ def create_sample_indices(self, total_nrow): )) return indices + def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset'): + self.handle = ctypes.c_void_p() + _safe_call(_LIB.LGBM_DatasetCreateByReference( + ref_dataset, + ctypes.c_int64(total_nrow), + ctypes.byref(self.handle), + )) + return self + def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): """Create Dataset from sampled data structures. @@ -1537,12 +1546,21 @@ def __init_from_seqs(self, seqs: List[Sequence], params_str: str, ref_dataset: ' Supports random access and access by batch if properly defined by user """ total_nrow = sum(len(seq) for seq in seqs) - ncol = len(seqs[0][0]) - sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT - sample_cnt = min(sample_cnt, total_nrow) - sample_data, col_indices = self.__sample(seqs, total_nrow) - self.init_from_sample(sample_data, col_indices, sample_cnt, total_nrow) + # check uniformity: + # ncol = len(seqs[0][0]) + + # create validation dataset from ref_dataset + if ref_dataset: + if self.params.get("bin_construct_sample_cnt"): + _log_warning('Option `bin_construct_sample_cnt` will be ignored when creating validation dataset.') + self.init_from_ref_dataset(total_nrow, ref_dataset) + else: + sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT + sample_cnt = min(sample_cnt, total_nrow) + + sample_data, col_indices = self.__sample(seqs, total_nrow) + self.init_from_sample(sample_data, col_indices, sample_cnt, total_nrow) for seq in seqs: nrow = len(seq) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 0266a876ae83..b2d40bcacbe4 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -126,7 +126,8 @@ def __len__(self): @pytest.mark.parametrize('include_0', [False, True]) @pytest.mark.parametrize('include_nan', [False, True]) @pytest.mark.parametrize('num_seq', [1, 3]) -def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq): +@pytest.mark.parametrize('create_valid', [False, True]) +def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq, create_valid): params = { "bin_construct_sample_cnt": sample_count, } @@ -135,6 +136,11 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ half_nrow = nrow // 2 ncol = 11 data = np.arange(nrow * ncol).reshape((nrow, ncol)).astype('float64') + + if create_valid: + # select some head and tail rows + ref_data = data[[0,1,2,-1,-2],:] + ref_dataset = lgb.Dataset(ref_data[:, :-1], label=ref_data[:, -1], params=params) # total col if include_0: @@ -159,8 +165,12 @@ def 
test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ # X, Y split X = data[:, :-1] Y = data[:, -1] - # truth - ds = lgb.Dataset(X, label=Y, params=params) + + if create_valid: + ds = lgb.Dataset(X, label=Y, reference=ref_dataset) + else: + # truth + ds = lgb.Dataset(X, label=Y, params=params) ds.save_binary(os.path.join(tmpdir, "seq.truth.bin")) @@ -175,10 +185,19 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ seq = NumpySequence(X[start:end]) seq.batch_size = batch_size seqs.append(seq) - ds = lgb.Dataset(seqs, label=Y, params=params) + + if create_valid: + ds = lgb.Dataset(seqs, label=Y, reference=ref_dataset) + else: + ds = lgb.Dataset(seqs, label=Y, params=params) ds.save_binary(os.path.join(tmpdir, "seq.seq.bin")) - assert filecmp.cmp(os.path.join(tmpdir, "seq.truth.bin"), os.path.join(tmpdir, "seq.seq.bin")) + if create_valid: + # TODO: verify validation dataset somehow + # Some metadata are not initialized while validation dataset are constructed + ... + else: + assert filecmp.cmp(os.path.join(tmpdir, "seq.truth.bin"), os.path.join(tmpdir, "seq.seq.bin")) def test_chunked_dataset(): From ea6560c9728280e0bc14b72a7c0d58d276d4b58d Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Thu, 6 May 2021 16:55:42 +0800 Subject: [PATCH 23/70] minor(python): comment --- python-package/lightgbm/basic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 21dc36dbde75..8cb967c16c83 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1544,6 +1544,8 @@ def __init_from_seqs(self, seqs: List[Sequence], params_str: str, ref_dataset: ' Sequence: Generic Data Access Object Supports random access and access by batch if properly defined by user + + Data scheme uniformity are trusted, not checked """ total_nrow = sum(len(seq) for seq in seqs) From b07c33cde18fde36116759a45dfdbba54bbe42a7 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Thu, 6 May 2021 16:56:26 +0800 Subject: [PATCH 24/70] minor(python): test option hint --- tests/python_package_test/test_basic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index b2d40bcacbe4..0a03893cd04d 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -130,6 +130,7 @@ def __len__(self): def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq, create_valid): params = { "bin_construct_sample_cnt": sample_count, + # "data_random_seed": 0, } nrow = 31 From 206c8bf1b59e13ebe880b91c0618a4a0042a63a6 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Thu, 6 May 2021 17:12:31 +0800 Subject: [PATCH 25/70] style(python): fix code linting --- python-package/lightgbm/basic.py | 6 +++--- tests/python_package_test/test_basic.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 8cb967c16c83..4684842faa4c 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -10,7 +10,7 @@ from functools import wraps from logging import Logger from tempfile import NamedTemporaryFile -from typing import Any, Dict, Iterable, List, Set, Union, Tuple +from typing import Any, Dict, Iterable, List, Set, Tuple, Union import numpy as np import scipy.sparse @@ -1545,11 +1545,11 @@ def __init_from_seqs(self, seqs: List[Sequence], params_str: str, 
ref_dataset: ' Sequence: Generic Data Access Object Supports random access and access by batch if properly defined by user - Data scheme uniformity are trusted, not checked + Data scheme uniformity are trusted, not checked """ total_nrow = sum(len(seq) for seq in seqs) - # check uniformity: + # check uniformity: # ncol = len(seqs[0][0]) # create validation dataset from ref_dataset diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 0a03893cd04d..02a1e8d11f63 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -137,10 +137,10 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ half_nrow = nrow // 2 ncol = 11 data = np.arange(nrow * ncol).reshape((nrow, ncol)).astype('float64') - + if create_valid: # select some head and tail rows - ref_data = data[[0,1,2,-1,-2],:] + ref_data = data[[0, 1, 2, -1, -2], :] ref_dataset = lgb.Dataset(ref_data[:, :-1], label=ref_data[:, -1], params=params) # total col From 4c8f04d7e3d0c8bd9d41765b270f6b510652b6a3 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Thu, 6 May 2021 17:20:09 +0800 Subject: [PATCH 26/70] style(python): add pydoc for ref_dataset --- python-package/lightgbm/basic.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 4684842faa4c..5162f1f23535 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1225,6 +1225,18 @@ def create_sample_indices(self, total_nrow): return indices def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset'): + """Create dataset from a reference dataset. + + Args + ---------- + total_nrow (int): number of rows expected to add to dataset + ref_dataset (Dataset): referance dataset to extract meta from + + Returns + ------- + self : Dataset + Constructed Dataset object. + """ self.handle = ctypes.c_void_p() _safe_call(_LIB.LGBM_DatasetCreateByReference( ref_dataset, From 4639b17eb896a50eed9355fc92b8fa899376dbff Mon Sep 17 00:00:00 2001 From: Willian Z Date: Fri, 7 May 2021 13:31:55 +0800 Subject: [PATCH 27/70] doc(python): sequence Co-authored-by: shiyu1994 --- examples/python-guide/dataset_from_multi_hdf5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/python-guide/dataset_from_multi_hdf5.py b/examples/python-guide/dataset_from_multi_hdf5.py index 8f24ad29b6b3..a425e30106f1 100644 --- a/examples/python-guide/dataset_from_multi_hdf5.py +++ b/examples/python-guide/dataset_from_multi_hdf5.py @@ -83,7 +83,7 @@ def generate_hdf(input_fname, output_basename, batch_size): df1 = df.iloc[:mid] df2 = df.iloc[mid:] - # We can store multiple dataset inside a single HDF5 file. + # We can store multiple datasets inside a single HDF5 file. # Separating X and Y for choosing best chunk size for data loading. 
save2hdf({'Y': df1.iloc[:, :1], 'X': df1.iloc[:, 1:]}, '{}1.h5'.format(output_basename), batch_size) save2hdf({'Y': df2.iloc[:, :1], 'X': df2.iloc[:, 1:]}, '{}2.h5'.format(output_basename), batch_size) From c6579ab74da76cbf92d17e01b194d0fb6e7bcc02 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Fri, 7 May 2021 13:35:14 +0800 Subject: [PATCH 28/70] revert(python): sequence class abc --- python-package/lightgbm/basic.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 5162f1f23535..ef40cc16c87b 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1,6 +1,5 @@ # coding: utf-8 """Wrapper for C API of LightGBM.""" -import abc import ctypes import json import os @@ -616,8 +615,6 @@ class Sequence(object): reduce memory usage. """ - __metaclass__ = abc.ABCMeta - batch_size = 4096 # Defaults to read 4K rows in each batch. @staticmethod @@ -636,7 +633,6 @@ def is_class(obj): return False return hasattr(obj, "__getitem__") and hasattr(obj, "__len__") - @abc.abstractmethod def __getitem__(self, idx: Union[int, slice]) -> np.ndarray: """Return data for given row index. @@ -663,7 +659,6 @@ def __getitem__(self, idx: Union[int, slice]) -> np.ndarray: """ raise NotImplementedError("remove this line if subclassing") - @abc.abstractmethod def __len__(self) -> int: """Return row count of this sequence.""" raise NotImplementedError From bdf80295e361fa9510f668410c5d6c78d57cce14 Mon Sep 17 00:00:00 2001 From: Willian Zhang Date: Fri, 7 May 2021 15:52:27 +0800 Subject: [PATCH 29/70] chore(python): remove rm_files --- tests/python_package_test/test_basic.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 02a1e8d11f63..5933509ba2b3 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -14,14 +14,6 @@ from .utils import load_breast_cancer -def rm_files(files): - if not isinstance(files, list): - files = [files] - for file in files: - if os.path.exists(file): - os.remove(file) - - def test_basic(tmp_path): X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1, random_state=2) From 88978e19ac71e9dab74b74b9752a8cb27e4c695f Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Fri, 7 May 2021 18:42:26 +0800 Subject: [PATCH 30/70] Remove useless static_assert. --- src/c_api.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index c2a7ccebdbeb..1bd7d55e48d2 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1056,8 +1056,6 @@ int LGBM_SampleIndices(int32_t total_nrow, config.Set(param); auto sample_indices = CreateSampleIndices(config, total_nrow); - - static_assert (sizeof(int) == 4, "int size is not 4"); memcpy(out, sample_indices.data(), sizeof(int32_t) * sample_indices.size()); API_END(); From 97367f4fd6ee700d2d72ca65d2d35f47403fc103 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 13 May 2021 16:31:22 +0800 Subject: [PATCH 31/70] refactor: test_basic test for sequence. 
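The refactored test asserts that a Dataset built through the Sequence
interface is byte-identical to one built directly from the ndarray. In
outline, the flow looks roughly like this (a sketch using the
NumpySequence helper defined in test_basic.py; file names illustrative):

    import filecmp
    import numpy as np
    import lightgbm as lgb

    data = np.arange(500, dtype=np.float64).reshape((50, 10))
    X, Y = data[:, :-1], data[:, -1]
    params = {'bin_construct_sample_cnt': 23}

    # Ground truth: Dataset constructed directly from the ndarray.
    lgb.Dataset(X, label=Y, params=params).save_binary('data_from_npy.bin')

    # The same data served through the Sequence interface must produce
    # an identical binary file.
    seqs = [NumpySequence(X[:25], batch_size=20),
            NumpySequence(X[25:], batch_size=20)]
    lgb.Dataset(seqs, label=Y, params=params).save_binary('data_from_seq.bin')

    assert filecmp.cmp('data_from_npy.bin', 'data_from_seq.bin')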
--- python-package/lightgbm/basic.py | 5 -- tests/python_package_test/test_basic.py | 104 +++++++++++++----------- 2 files changed, 56 insertions(+), 53 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index ef40cc16c87b..001f01c7fc46 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1556,13 +1556,8 @@ def __init_from_seqs(self, seqs: List[Sequence], params_str: str, ref_dataset: ' """ total_nrow = sum(len(seq) for seq in seqs) - # check uniformity: - # ncol = len(seqs[0][0]) - # create validation dataset from ref_dataset if ref_dataset: - if self.params.get("bin_construct_sample_cnt"): - _log_warning('Option `bin_construct_sample_cnt` will be ignored when creating validation dataset.') self.init_from_ref_dataset(total_nrow, ref_dataset) else: sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 5933509ba2b3..04624b24bad3 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -91,14 +91,15 @@ def test_basic(tmp_path): class NumpySequence(lgb.Sequence): - def __init__(self, ndarray): + def __init__(self, ndarray, batch_size): self.ndarray = ndarray + self.batch_size = batch_size def __getitem__(self, idx): # The simple implementation is just a single "return self.ndarray[idx]" # The following is for demo and testing purpose. if isinstance(idx, int): - return self.__get_one_line__(idx) + return self._get_one_line(idx) elif isinstance(idx, slice): if not (idx.step is None or idx.step is 1): raise NotImplementedError("No need to implement, caller will not set step by now") @@ -106,34 +107,39 @@ def __getitem__(self, idx): else: raise TypeError("Sequence Index must be an integer/list/slice, got {}".format(type(idx))) - def __get_one_line__(self, idx): + def _get_one_line(self, idx): return self.ndarray[idx] def __len__(self): return len(self.ndarray) -@pytest.mark.parametrize('sample_count', [2, 5]) +def _create_sequence_from_ndarray(data, num_seq, batch_size): + if num_seq == 1: + return NumpySequence(data, batch_size) + + nrow = data.shape[0] + seqs = [] + seq_size = nrow // num_seq + for start in range(0, nrow, seq_size): + end = min(start + seq_size, nrow) + seq = NumpySequence(data[start:end], batch_size) + seqs.append(seq) + return seqs + + +@pytest.mark.parametrize('sample_count', [11, 23, 100, None]) @pytest.mark.parametrize('batch_size', [3, 20, None]) @pytest.mark.parametrize('include_0', [False, True]) @pytest.mark.parametrize('include_nan', [False, True]) @pytest.mark.parametrize('num_seq', [1, 3]) -@pytest.mark.parametrize('create_valid', [False, True]) -def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq, create_valid): - params = { - "bin_construct_sample_cnt": sample_count, - # "data_random_seed": 0, - } +def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq): + params = { 'bin_construct_sample_cnt': sample_count } - nrow = 31 + nrow = 50 half_nrow = nrow // 2 ncol = 11 - data = np.arange(nrow * ncol).reshape((nrow, ncol)).astype('float64') - - if create_valid: - # select some head and tail rows - ref_data = data[[0, 1, 2, -1, -2], :] - ref_dataset = lgb.Dataset(ref_data[:, :-1], label=ref_data[:, -1], params=params) + data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol)) # total col if include_0: @@ -155,42 +161,44 @@ def 
test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ if include_0: data[half_nrow:-2, 4] = 0 - # X, Y split X = data[:, :-1] Y = data[:, -1] - if create_valid: - ds = lgb.Dataset(X, label=Y, reference=ref_dataset) - else: - # truth - ds = lgb.Dataset(X, label=Y, params=params) + npy_bin_fname = os.path.join(tmpdir, 'data_from_npy.bin') + seq_bin_fname = os.path.join(tmpdir, 'data_from_seq.bin') - ds.save_binary(os.path.join(tmpdir, "seq.truth.bin")) + # Create dataset from numpy array directly. + ds = lgb.Dataset(X, label=Y, params=params) + ds.save_binary(npy_bin_fname) - # seq - if num_seq == 1: - seqs = NumpySequence(X) - else: - seqs = [] - seq_size = nrow // num_seq - for start in range(0, nrow, seq_size): - end = min(start + seq_size, nrow) - seq = NumpySequence(X[start:end]) - seq.batch_size = batch_size - seqs.append(seq) - - if create_valid: - ds = lgb.Dataset(seqs, label=Y, reference=ref_dataset) - else: - ds = lgb.Dataset(seqs, label=Y, params=params) - ds.save_binary(os.path.join(tmpdir, "seq.seq.bin")) - - if create_valid: - # TODO: verify validation dataset somehow - # Some metadata are not initialized while validation dataset are constructed - ... - else: - assert filecmp.cmp(os.path.join(tmpdir, "seq.truth.bin"), os.path.join(tmpdir, "seq.seq.bin")) + # Create dataset using Sequence. + seqs = _create_sequence_from_ndarray(X, num_seq, batch_size) + seq_ds = lgb.Dataset(seqs, label=Y, params=params) + seq_ds.save_binary(seq_bin_fname) + + assert filecmp.cmp(npy_bin_fname, seq_bin_fname) + + # Test for validation set. + # Select some random rows as valid data. + rng = np.random.default_rng() # Pass integer to set seed when needed. + valid_idx = (rng.random(10) * nrow).astype(np.int) + valid_data = data[valid_idx, :] + valid_X = valid_data[:, :-1] + valid_Y = valid_data[:, -1] + + valid_npy_bin_fname = os.path.join(tmpdir, 'valid_data_from_npy.bin') + valid_seq_bin_fname = os.path.join(tmpdir, 'valid_data_from_seq.bin') + + valid_ds = lgb.Dataset(valid_X, label=valid_Y, params=params, reference=ds) + valid_ds.save_binary(valid_npy_bin_fname) + valid_ds._dump_text(os.path.join(tmpdir, 'valid_numpy.txt')) + + valid_seqs = _create_sequence_from_ndarray(valid_X, num_seq, batch_size) + valid_seq_ds = lgb.Dataset(valid_seqs, label=valid_Y, params=params, reference=valid_ds) + valid_seq_ds.save_binary(valid_seq_bin_fname) + valid_seq_ds._dump_text(os.path.join(tmpdir, 'valid_seq.txt')) + + assert filecmp.cmp(valid_npy_bin_fname, valid_seq_bin_fname) def test_chunked_dataset(): From fb9368756b280beb7b7a31e4719b0d59271f484c Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 13 May 2021 19:54:14 +0800 Subject: [PATCH 32/70] fix lint complaint. 
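For context on the validation-set path that the previous patch's test
exercises: when a reference dataset is given, the Sequence data is binned
with the mappers of that reference instead of being re-sampled. A sketch
of the intended user-facing flow (names mirror the test; the train() call
is illustrative and not part of the test):

    train_ds = lgb.Dataset(seqs, label=Y, params=params)
    valid_seq = NumpySequence(valid_X, batch_size=20)
    valid_ds = lgb.Dataset(valid_seq, label=valid_Y, reference=train_ds)
    booster = lgb.train(params, train_ds, valid_sets=[valid_ds])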
--- tests/python_package_test/test_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 04624b24bad3..222763d25c29 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -134,7 +134,7 @@ def _create_sequence_from_ndarray(data, num_seq, batch_size): @pytest.mark.parametrize('include_nan', [False, True]) @pytest.mark.parametrize('num_seq', [1, 3]) def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq): - params = { 'bin_construct_sample_cnt': sample_count } + params = {'bin_construct_sample_cnt': sample_count} nrow = 50 half_nrow = nrow // 2 From 9ebaa65956dbe3b55772952581d686dce0c410a8 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Fri, 14 May 2021 09:07:24 +0800 Subject: [PATCH 33/70] remove dataset._dump_text in sequence test. --- tests/python_package_test/test_basic.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 222763d25c29..2b3b3d929fa9 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -191,12 +191,10 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ valid_ds = lgb.Dataset(valid_X, label=valid_Y, params=params, reference=ds) valid_ds.save_binary(valid_npy_bin_fname) - valid_ds._dump_text(os.path.join(tmpdir, 'valid_numpy.txt')) valid_seqs = _create_sequence_from_ndarray(valid_X, num_seq, batch_size) valid_seq_ds = lgb.Dataset(valid_seqs, label=valid_Y, params=params, reference=valid_ds) valid_seq_ds.save_binary(valid_seq_bin_fname) - valid_seq_ds._dump_text(os.path.join(tmpdir, 'valid_seq.txt')) assert filecmp.cmp(valid_npy_bin_fname, valid_seq_bin_fname) From d9d2eab29c7f80ecf10316c38695eec942993f40 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Sat, 15 May 2021 19:47:01 +0800 Subject: [PATCH 34/70] Fix reverting typo fix. --- src/io/dataset.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index 51879351178e..e5cabe682caa 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -932,7 +932,7 @@ bool Dataset::GetIntField(const char* field_name, data_size_t* out_len, void Dataset::SaveBinaryFile(const char* bin_filename) { if (bin_filename != nullptr && std::string(bin_filename) == data_filename_) { - Log::Warning("Bianry file %s already exists", bin_filename); + Log::Warning("Binary file %s already exists", bin_filename); return; } // if not pass a filename, just append ".bin" of original file From 7512e1b5b53b426e2df9e2009996b3a2d874b8ee Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Mon, 24 May 2021 09:56:52 +0800 Subject: [PATCH 35/70] Apply suggestions from code review Co-authored-by: James Lamb --- python-package/lightgbm/basic.py | 45 ++++++++++++++++++++------------ src/c_api.cpp | 4 +-- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 001f01c7fc46..33f3d23bb9dc 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -595,7 +595,7 @@ def _load_pandas_categorical(file_name=None, model_str=None): return None -class Sequence(object): +class Sequence: """ Generic data access interface. @@ -618,7 +618,7 @@ class Sequence(object): batch_size = 4096 # Defaults to read 4K rows in each batch. 
@staticmethod - def is_class(obj): + def is_class(obj) -> bool: """Check if objection is instance of class Sequence. Args: @@ -1192,13 +1192,16 @@ def __del__(self): except AttributeError: pass - def create_sample_indices(self, total_nrow): - """Create sample indices for the given parameter of the Dataset. + def create_sample_indices(self, total_nrow: int) -> np.ndarray: + """Get an array of randomly chosen indices from this ``Dataset``. + + Indices are sampled without replacement. Parameters ---------- total_nrow: int Total number of rows to sample from. + If this value is greater than the value of parameter ``bin_construct_sample_cnt``, only ``bin_construct_sample_cnt`` indices will be used. If Dataset has multiple input data, this should be the sum of rows of every file. Returns @@ -1207,7 +1210,7 @@ def create_sample_indices(self, total_nrow): Indices for sampled data. """ param_str = param_dict_to_str(self.params) - sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT + sample_cnt = self.params.get("bin_construct_sample_cnt", DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT) sample_cnt = min(sample_cnt, total_nrow) indices = np.zeros(sample_cnt, dtype=np.int32) ptr_data, _, _ = c_int_array(indices) @@ -1219,18 +1222,18 @@ def create_sample_indices(self, total_nrow): )) return indices - def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset'): + def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset') -> 'Dataset': """Create dataset from a reference dataset. - Args - ---------- + Parameters + --------------- total_nrow (int): number of rows expected to add to dataset ref_dataset (Dataset): referance dataset to extract meta from Returns ------- - self : Dataset - Constructed Dataset object. + self : Dataset + Constructed Dataset object. """ self.handle = ctypes.c_void_p() _safe_call(_LIB.LGBM_DatasetCreateByReference( @@ -1240,7 +1243,13 @@ def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset'): )) return self - def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): + def init_from_sample( + self, + sample_data: List[np.array[float64]], + sample_indices: List[np.array[int]], + sample_cnt: int, + total_nrow: int + ) -> "Dataset": """Create Dataset from sampled data structures. Parameters @@ -1259,9 +1268,9 @@ def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): self : Dataset Constructed Dataset object. """ - assert len(sample_data) == len(sample_indices), "#sample data column != #column indices" - ncol = len(sample_indices) + assert len(sample_data) == ncol, "#sample data column != #column indices" + for i in range(ncol): if sample_data[i].dtype != np.double: @@ -1296,11 +1305,13 @@ def init_from_sample(self, sample_data, sample_indices, sample_cnt, total_nrow): )) return self - def push_rows(self, data): + def push_rows(self, data: np.ndarray) -> 'Dataset': """Add rows to Dataset. - Args: - data: numpy 1-D array + Parameters + ---------- + data : numpy 1-D array + New data to add to the Dataset. 
Returns ------- @@ -1560,7 +1571,7 @@ def __init_from_seqs(self, seqs: List[Sequence], params_str: str, ref_dataset: ' if ref_dataset: self.init_from_ref_dataset(total_nrow, ref_dataset) else: - sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT + sample_cnt = self.params.get("bin_construct_sample_cnt", DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT) sample_cnt = min(sample_cnt, total_nrow) sample_data, col_indices = self.__sample(seqs, total_nrow) diff --git a/src/c_api.cpp b/src/c_api.cpp index 1bd7d55e48d2..95cfe1db8085 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -1043,8 +1043,8 @@ static inline std::vector CreateSampleIndices(const Config& config, int int LGBM_SampleIndices(int32_t total_nrow, - const char* parameters, - void* out) { + const char* parameters, + void* out) { // This API is to keep python binding's behavior the same with C++ implementation. // Sample count, random seed etc. should be provided in parameters. API_BEGIN(); From 39ff868ec075cebfb90b00beb6cf5ceb579036f3 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Wed, 26 May 2021 17:11:13 +0800 Subject: [PATCH 36/70] Fix type hint, code and doc style. --- python-package/lightgbm/basic.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 33f3d23bb9dc..792deee24b85 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1194,7 +1194,7 @@ def __del__(self): def create_sample_indices(self, total_nrow: int) -> np.ndarray: """Get an array of randomly chosen indices from this ``Dataset``. - + Indices are sampled without replacement. Parameters @@ -1226,7 +1226,7 @@ def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset') -> 'Dat """Create dataset from a reference dataset. Parameters - --------------- + ---------- total_nrow (int): number of rows expected to add to dataset ref_dataset (Dataset): referance dataset to extract meta from @@ -1245,22 +1245,22 @@ def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset') -> 'Dat def init_from_sample( self, - sample_data: List[np.array[float64]], - sample_indices: List[np.array[int]], + sample_data: List[np.ndarray], + sample_indices: List[np.ndarray], sample_cnt: int, - total_nrow: int + total_nrow: int, ) -> "Dataset": """Create Dataset from sampled data structures. Parameters ---------- - sample_data: List[np.array[float64]] + sample_data: Sample data for each column - sample_indices: List[np.array[int]] + sample_indices: Sample data row index for each column. - sample_cnt: int + sample_cnt: Number of samples. - total_nrow: int + total_nrow: Total number of rows for all input file. 
Returns @@ -1271,10 +1271,9 @@ def init_from_sample( ncol = len(sample_indices) assert len(sample_data) == ncol, "#sample data column != #column indices" - for i in range(ncol): if sample_data[i].dtype != np.double: - raise ValueError("sample data type {} is not double".format(sample_data.dtype)) + raise ValueError("sample_data[{}] type {} is not double".format(i, sample_data[i].dtype)) if sample_indices[i].dtype != np.int32: raise ValueError("sample_indices[{}] type {} is not int32".format(i, sample_indices[i].dtype)) @@ -1527,7 +1526,7 @@ def __yield_row_from(self, seqs: List[Sequence], indices: Iterable[int]): row = seq[int(id_in_seq)] yield row if row.flags['OWNDATA'] else row.copy() - def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[np.ndarray, List[np.ndarray]]: + def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]: """Sample data from seqs. Mimics behavior in c_api.cpp:LGBM_DatasetCreateFromMats() @@ -1539,8 +1538,7 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[np.ndarray, L indices = self.create_sample_indices(total_nrow) # Select sampled rows, transpose to column order. - sampled = [row for row in self.__yield_row_from(seqs, indices)] - sampled = np.array(sampled) + sampled = np.array(row for row in self.__yield_row_from(seqs, indices)) sampled = sampled.T filtered = [] From 4a50382f04f194b613d77702d94c06d01e1516ab Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Wed, 26 May 2021 17:53:38 +0800 Subject: [PATCH 37/70] fix failing test_basic. --- python-package/lightgbm/basic.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 792deee24b85..2c7fbbf7c502 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1210,7 +1210,8 @@ def create_sample_indices(self, total_nrow: int) -> np.ndarray: Indices for sampled data. """ param_str = param_dict_to_str(self.params) - sample_cnt = self.params.get("bin_construct_sample_cnt", DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT) + # Note self.params may contain 'bin_construct_sample_cnt' but is None. + sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT sample_cnt = min(sample_cnt, total_nrow) indices = np.zeros(sample_cnt, dtype=np.int32) ptr_data, _, _ = c_int_array(indices) @@ -1538,7 +1539,7 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr indices = self.create_sample_indices(total_nrow) # Select sampled rows, transpose to column order. - sampled = np.array(row for row in self.__yield_row_from(seqs, indices)) + sampled = np.array([row for row in self.__yield_row_from(seqs, indices)]) sampled = sampled.T filtered = [] @@ -1569,7 +1570,7 @@ def __init_from_seqs(self, seqs: List[Sequence], params_str: str, ref_dataset: ' if ref_dataset: self.init_from_ref_dataset(total_nrow, ref_dataset) else: - sample_cnt = self.params.get("bin_construct_sample_cnt", DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT) + sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT sample_cnt = min(sample_cnt, total_nrow) sample_data, col_indices = self.__sample(seqs, total_nrow) From 2fdbda4e2009e1aad2024b678446bde77d6d89b2 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Wed, 26 May 2021 09:25:35 +0800 Subject: [PATCH 38/70] Remove TODO about keep constant in sync with cpp. 
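The removed TODO concerned DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT, whose value
(200000) mirrors the C++ default for bin_construct_sample_cnt in
include/LightGBM/config.h. A later patch in this series (PATCH 46)
consolidates the lookup into a helper along these lines:

    DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000

    def _get_sample_count(params, total_nrow):
        # The key may be present with value None, hence `or` rather
        # than a dict.get() default.
        sample_count = params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT
        return min(sample_count, total_nrow)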
--- python-package/lightgbm/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 2c7fbbf7c502..0e21ea090d47 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -17,7 +17,7 @@ from .compat import PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_DataFrame, pd_Series from .libpath import find_lib_path -# TODO: how to keep the default values the same with C++ config.h + ZERO_THRESHOLD = 1e-35 DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000 From 37d5e84bcb951e4fb322bca0a1497dc91edeb3b9 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Wed, 26 May 2021 09:30:08 +0800 Subject: [PATCH 39/70] Install h5py only when running python-examples. --- .ci/test_windows.ps1 | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1 index 0254f5f33980..12735018a499 100644 --- a/.ci/test_windows.ps1 +++ b/.ci/test_windows.ps1 @@ -49,7 +49,7 @@ if ($env:TASK -eq "swig") { Exit 0 } -conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy h5py ; Check-Output $? +conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest python-graphviz scikit-learn scipy ; Check-Output $? if ($env:TASK -eq "regular") { mkdir $env:BUILD_SOURCESDIRECTORY/build; cd $env:BUILD_SOURCESDIRECTORY/build @@ -105,6 +105,7 @@ if (($env:TASK -eq "regular") -or (($env:APPVEYOR -eq "true") -and ($env:TASK -e cd $env:BUILD_SOURCESDIRECTORY/examples/python-guide @("import matplotlib", "matplotlib.use('Agg')") + (Get-Content "plot_example.py") | Set-Content "plot_example.py" (Get-Content "plot_example.py").replace('graph.render(view=True)', 'graph.render(view=False)') | Set-Content "plot_example.py" # prevent interactive window mode + conda install -q -y -n $env:CONDA_ENV h5py ; Check-Output $? foreach ($file in @(Get-ChildItem *.py)) { @("import sys, warnings", "warnings.showwarning = lambda message, category, filename, lineno, file=None, line=None: sys.stdout.write(warnings.formatwarning(message, category, filename, lineno, line))") + (Get-Content $file) | Set-Content $file python $file ; Check-Output $? From ed56e83b22b9b9ccc7f247a19ec37318d02d2a24 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 27 May 2021 06:53:58 +0800 Subject: [PATCH 40/70] Fix lint complaint. 
--- python-package/lightgbm/basic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 0e21ea090d47..65144d1e8226 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -17,7 +17,6 @@ from .compat import PANDAS_INSTALLED, concat, dt_DataTable, is_dtype_sparse, pd_DataFrame, pd_Series from .libpath import find_lib_path - ZERO_THRESHOLD = 1e-35 DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000 From f501bb66183681f4ececd60869825cd8968a1d74 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Mon, 7 Jun 2021 08:55:56 +0800 Subject: [PATCH 41/70] Apply suggestions from code review Co-authored-by: James Lamb --- .../python-guide/dataset_from_multi_hdf5.py | 4 ++-- python-package/lightgbm/basic.py | 24 ++++++++++--------- tests/python_package_test/test_basic.py | 2 +- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/examples/python-guide/dataset_from_multi_hdf5.py b/examples/python-guide/dataset_from_multi_hdf5.py index a425e30106f1..584cfc3e5827 100644 --- a/examples/python-guide/dataset_from_multi_hdf5.py +++ b/examples/python-guide/dataset_from_multi_hdf5.py @@ -85,8 +85,8 @@ def generate_hdf(input_fname, output_basename, batch_size): # We can store multiple datasets inside a single HDF5 file. # Separating X and Y for choosing best chunk size for data loading. - save2hdf({'Y': df1.iloc[:, :1], 'X': df1.iloc[:, 1:]}, '{}1.h5'.format(output_basename), batch_size) - save2hdf({'Y': df2.iloc[:, :1], 'X': df2.iloc[:, 1:]}, '{}2.h5'.format(output_basename), batch_size) + save2hdf({'Y': df1.iloc[:, :1], 'X': df1.iloc[:, 1:]}, f'{output_basename}1.h5', batch_size) + save2hdf({'Y': df2.iloc[:, :1], 'X': df2.iloc[:, 1:]}, f'{output_basename}2.h5', batch_size) def main(): diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 65144d1e8226..8506d4cd54b0 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -644,7 +644,7 @@ def __getitem__(self, idx: Union[int, slice]) -> np.ndarray: elif isinstance(idx, slice): return np.stack(self.__get_one_line__(i) for i in range(idx.start, idx.stop)) else: - raise TypeError("Sequence index must be integer or slice, got {}".format(type(idx))) + raise TypeError(f"Sequence index must be integer or slice, got {type(idx)}") Parameters ---------- @@ -656,11 +656,11 @@ def __getitem__(self, idx: Union[int, slice]) -> np.ndarray: result : numpy 1-D array, numpy 2-D array 1-D array if idx is int, 2-D array if idx is slice """ - raise NotImplementedError("remove this line if subclassing") + raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()") def __len__(self) -> int: """Return row count of this sequence.""" - raise NotImplementedError + raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __len__()") class _InnerPredictor: @@ -1227,8 +1227,10 @@ def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset') -> 'Dat Parameters ---------- - total_nrow (int): number of rows expected to add to dataset - ref_dataset (Dataset): referance dataset to extract meta from + total_nrow : int + number of rows expected to add to dataset + ref_dataset : Dataset + reference dataset to extract meta from Returns ------- @@ -1254,13 +1256,13 @@ def init_from_sample( Parameters ---------- - sample_data: + sample_data : list of numpy arrays Sample data for each column - sample_indices: + sample_indices : list of numpy arrays Sample data row index for each 
column. - sample_cnt: + sample_cnt : int Number of samples. - total_nrow: + total_nrow : int Total number of rows for all input file. Returns @@ -1273,9 +1275,9 @@ def init_from_sample( for i in range(ncol): if sample_data[i].dtype != np.double: - raise ValueError("sample_data[{}] type {} is not double".format(i, sample_data[i].dtype)) + raise ValueError(f"sample_data[{i}] type {sample_data[i].dtype} is not double") if sample_indices[i].dtype != np.int32: - raise ValueError("sample_indices[{}] type {} is not int32".format(i, sample_indices[i].dtype)) + raise ValueError(f"sample_indices[{i}] type {sample_indices[i].dtype} is not int32") # c type: double** # each double* element points to start of each column of sample data. diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 2b3b3d929fa9..15a599a1bba9 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -105,7 +105,7 @@ def __getitem__(self, idx): raise NotImplementedError("No need to implement, caller will not set step by now") return self.ndarray[idx.start: idx.stop] else: - raise TypeError("Sequence Index must be an integer/list/slice, got {}".format(type(idx))) + raise TypeError(f"Sequence Index must be an integer/list/slice, got {type(idx)}") def _get_one_line(self, idx): return self.ndarray[idx] From f5f270a10649b04ff6dfeda2e7f6795273f97d68 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Mon, 7 Jun 2021 09:47:50 +0800 Subject: [PATCH 42/70] Doc fixes, remove unused params_str in __init_from_seqs. --- python-package/lightgbm/basic.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 8506d4cd54b0..d6de77c6b4e5 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1198,14 +1198,14 @@ def create_sample_indices(self, total_nrow: int) -> np.ndarray: Parameters ---------- - total_nrow: int + total_nrow : int Total number of rows to sample from. If this value is greater than the value of parameter ``bin_construct_sample_cnt``, only ``bin_construct_sample_cnt`` indices will be used. If Dataset has multiple input data, this should be the sum of rows of every file. Returns ------- - indices: numpy array + indices : numpy array Indices for sampled data. """ param_str = param_dict_to_str(self.params) @@ -1482,11 +1482,11 @@ def _lazy_init(self, data, label=None, reference=None, if all(isinstance(x, np.ndarray) for x in data): self.__init_from_list_np2d(data, params_str, ref_dataset) elif all(Sequence.is_class(x) for x in data): - self.__init_from_seqs(data, params_str, ref_dataset) + self.__init_from_seqs(data, ref_dataset) else: raise TypeError('Data list can only be of ndarray or Sequence') elif Sequence.is_class(data): - self.__init_from_seqs([data], params_str, ref_dataset) + self.__init_from_seqs([data], ref_dataset) elif isinstance(data, dt_DataTable): self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset) else: @@ -1556,9 +1556,9 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr return filtered, filtered_idx - def __init_from_seqs(self, seqs: List[Sequence], params_str: str, ref_dataset: 'Dataset'): + def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: 'Dataset'): """ - Initialize data from a Sequence object. + Initialize data from list of Sequence objects. 
Sequence: Generic Data Access Object Supports random access and access by batch if properly defined by user From 2faf834d899df21e73814fc3d29d6a1c2bef2aa1 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 17 Jun 2021 08:28:44 +0800 Subject: [PATCH 43/70] Apply suggestions from code review Co-authored-by: Nikita Titov --- .ci/test.sh | 3 +-- .ci/test_windows.ps1 | 2 +- examples/python-guide/README.md | 4 ++-- examples/python-guide/dataset_from_multi_hdf5.py | 12 +++++++----- include/LightGBM/c_api.h | 2 +- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index f9cac622f181..fe627ce87077 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -226,9 +226,8 @@ import matplotlib\ matplotlib.use\(\"Agg\"\)\ ' plot_example.py # prevent interactive window mode sed -i'.bak' 's/graph.render(view=True)/graph.render(view=False)/' plot_example.py - conda install -q -y -n $CONDA_ENV h5py # requirements for example + conda install -q -y -n $CONDA_ENV h5py ipywidgets notebook # requirements for examples for f in *.py **/*.py; do python $f || exit -1; done # run all examples cd $BUILD_DIRECTORY/examples/python-guide/notebooks - conda install -q -y -n $CONDA_ENV ipywidgets notebook jupyter nbconvert --ExecutePreprocessor.timeout=180 --to notebook --execute --inplace *.ipynb || exit -1 # run all notebooks fi diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1 index 12735018a499..70204bf495e6 100644 --- a/.ci/test_windows.ps1 +++ b/.ci/test_windows.ps1 @@ -105,7 +105,7 @@ if (($env:TASK -eq "regular") -or (($env:APPVEYOR -eq "true") -and ($env:TASK -e cd $env:BUILD_SOURCESDIRECTORY/examples/python-guide @("import matplotlib", "matplotlib.use('Agg')") + (Get-Content "plot_example.py") | Set-Content "plot_example.py" (Get-Content "plot_example.py").replace('graph.render(view=True)', 'graph.render(view=False)') | Set-Content "plot_example.py" # prevent interactive window mode - conda install -q -y -n $env:CONDA_ENV h5py ; Check-Output $? + conda install -q -y -n $env:CONDA_ENV h5py ipywidgets notebook foreach ($file in @(Get-ChildItem *.py)) { @("import sys, warnings", "warnings.showwarning = lambda message, category, filename, lineno, file=None, line=None: sys.stdout.write(warnings.formatwarning(message, category, filename, lineno, line))") + (Get-Content $file) | Set-Content $file python $file ; Check-Output $? diff --git a/examples/python-guide/README.md b/examples/python-guide/README.md index a938749bd4c3..b34e0fa6e002 100644 --- a/examples/python-guide/README.md +++ b/examples/python-guide/README.md @@ -62,5 +62,5 @@ Examples include: - Plot one specified tree - Plot one specified tree with Graphviz - [dataset_from_multi_hdf5.py](https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/dataset_from_multi_hdf5.py) - - Construct Dataset from multiple HDF5 file - - Avoids loading all data into memory + - Construct Dataset from multiple HDF5 files + - Avoid loading all data into memory diff --git a/examples/python-guide/dataset_from_multi_hdf5.py b/examples/python-guide/dataset_from_multi_hdf5.py index 584cfc3e5827..f7b6d665127b 100644 --- a/examples/python-guide/dataset_from_multi_hdf5.py +++ b/examples/python-guide/dataset_from_multi_hdf5.py @@ -8,12 +8,14 @@ class HDFSequence(lgb.Sequence): def __init__(self, hdf_dataset, batch_size): """ + Construct a sequence object from HDF5 with required interface. 
+ Parameters ---------- - hdf_dataset: h5py.Dataset - dataset in HDF5 file - batch_size: int - when reading data to construct lightgbm Dataset, each read reads batch_size rows + hdf_dataset : h5py.Dataset + Dataset in HDF5 file. + batch_size : int + Size of a batch. When reading data to construct lightgbm Dataset, each read reads batch_size rows. """ # We can also open HDF5 file once and get access to self.data = hdf_dataset @@ -61,7 +63,7 @@ def save2hdf(input_data, fname, batch_size): nrow, ncol = data.shape if ncol == 1: # Y has a single column and we read it in single shot. So store it as an 1-d array. - chunk = (nrow, ) + chunk = (nrow,) data = data.values.flatten() else: # We use random access for data sampling when creating LightGBM Dataset from Sequence. diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index fd01123c6824..3e334341e52b 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -215,7 +215,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr, DatasetHandle* out); /*! - * \brief Create sample indices for total nrow. + * \brief Create sample indices for total number of rows. * \param total_nrow Number of all data rows * \param parameters Additional parameters, specify sample count and random seed in parameter * \param[out] out Created indices, type is int32_t, caller should insure out contains enough space to hold indices From 0679dc7abdc72727300e392b14f7fad08d983d98 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 17 Jun 2021 08:33:13 +0800 Subject: [PATCH 44/70] Remove unnecessary conda install in windows ci script. --- .ci/test_windows.ps1 | 1 - 1 file changed, 1 deletion(-) diff --git a/.ci/test_windows.ps1 b/.ci/test_windows.ps1 index 70204bf495e6..cb191c0804ee 100644 --- a/.ci/test_windows.ps1 +++ b/.ci/test_windows.ps1 @@ -111,6 +111,5 @@ if (($env:TASK -eq "regular") -or (($env:APPVEYOR -eq "true") -and ($env:TASK -e python $file ; Check-Output $? } # run all examples cd $env:BUILD_SOURCESDIRECTORY/examples/python-guide/notebooks - conda install -q -y -n $env:CONDA_ENV ipywidgets notebook jupyter nbconvert --ExecutePreprocessor.timeout=180 --to notebook --execute --inplace *.ipynb ; Check-Output $? # run all notebooks } From ebdada5ddc422d1b0d840a1c764608479303a1ab Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 17 Jun 2021 08:34:08 +0800 Subject: [PATCH 45/70] Keep param as example in dataset_from_multi_hdf5.py --- .../python-guide/dataset_from_multi_hdf5.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/python-guide/dataset_from_multi_hdf5.py b/examples/python-guide/dataset_from_multi_hdf5.py index f7b6d665127b..4ec5c9fd310f 100644 --- a/examples/python-guide/dataset_from_multi_hdf5.py +++ b/examples/python-guide/dataset_from_multi_hdf5.py @@ -36,11 +36,10 @@ def create_dataset_from_multiple_hdf(input_flist, batch_size): data.append(HDFSequence(f['X'], batch_size)) ylist.append(f['Y'][:]) - # params = { - # 'bin_construct_sample_cnt': 200000, - # 'max_bin': 255, - # } - params = None + params = { + 'bin_construct_sample_cnt': 200000, + 'max_bin': 255, + } y = np.concatenate(ylist) dataset = lgb.Dataset(data, label=y, params=params) # With binary dataset created, we can use either Python API or cmdline version to train. @@ -87,17 +86,20 @@ def generate_hdf(input_fname, output_basename, batch_size): # We can store multiple datasets inside a single HDF5 file. # Separating X and Y for choosing best chunk size for data loading. 
- save2hdf({'Y': df1.iloc[:, :1], 'X': df1.iloc[:, 1:]}, f'{output_basename}1.h5', batch_size) - save2hdf({'Y': df2.iloc[:, :1], 'X': df2.iloc[:, 1:]}, f'{output_basename}2.h5', batch_size) + fname1 = f'{output_basename}1.h5' + fname2 = f'{output_basename}2.h5' + save2hdf({'Y': df1.iloc[:, :1], 'X': df1.iloc[:, 1:]}, fname1, batch_size) + save2hdf({'Y': df2.iloc[:, :1], 'X': df2.iloc[:, 1:]}, fname2, batch_size) + + return [fname1, fname2] def main(): batch_size = 64 - generate_hdf('../regression/regression.train', 'regression', 64) - create_dataset_from_multiple_hdf( - ['regression1.h5', 'regression2.h5'], - batch_size=batch_size, - ) + output_basename = 'regression' + hdf_files = generate_hdf('../regression/regression.train', output_basename, 64) + + create_dataset_from_multiple_hdf(hdf_files, batch_size=batch_size) if __name__ == '__main__': From 96ebbfe74991bb915a96e753a200cf8772bac166 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 17 Jun 2021 08:43:13 +0800 Subject: [PATCH 46/70] Add _get_sample_count function to remove code duplication. --- python-package/lightgbm/basic.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index d6de77c6b4e5..0ecf66a6cdc5 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -21,6 +21,11 @@ DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000 +def _get_sample_count(params: Dict[str, str], total_nrow: int): + sample_count = params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT + return min(sample_count, total_nrow) + + class _DummyLogger: def info(self, msg): print(msg) @@ -1210,8 +1215,7 @@ def create_sample_indices(self, total_nrow: int) -> np.ndarray: """ param_str = param_dict_to_str(self.params) # Note self.params may contain 'bin_construct_sample_cnt' but is None. - sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT - sample_cnt = min(sample_cnt, total_nrow) + sample_cnt = _get_sample_count(self.params, total_nrow) indices = np.zeros(sample_cnt, dtype=np.int32) ptr_data, _, _ = c_int_array(indices) @@ -1571,8 +1575,7 @@ def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: 'Dataset'): if ref_dataset: self.init_from_ref_dataset(total_nrow, ref_dataset) else: - sample_cnt = self.params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT - sample_cnt = min(sample_cnt, total_nrow) + sample_cnt = _get_sample_count(self.params, total_nrow) sample_data, col_indices = self.__sample(seqs, total_nrow) self.init_from_sample(sample_data, col_indices, sample_cnt, total_nrow) From 12dcdff53ffbfc63f8605f074ba01f996f100e24 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Thu, 17 Jun 2021 09:05:52 +0800 Subject: [PATCH 47/70] Use batch_size parameter in generate_hdf. 
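As a companion note to the HDF5 example touched below: h5py datasets
already support len() plus integer and slice indexing, so a minimal
Sequence over one can be as small as the following sketch (illustrative,
independent of the HDFSequence class in the example file):

    import h5py
    import lightgbm as lgb

    class MinimalHDFSequence(lgb.Sequence):
        def __init__(self, hdf_dataset, batch_size):
            self.data = hdf_dataset        # random access stays on disk
            self.batch_size = batch_size   # rows per read when pushing data

        def __getitem__(self, idx):
            # h5py returns a 1-D row for an int index and a 2-D block
            # for a slice, matching the Sequence contract.
            return self.data[idx]

        def __len__(self):
            return len(self.data)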
--- examples/python-guide/dataset_from_multi_hdf5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/python-guide/dataset_from_multi_hdf5.py b/examples/python-guide/dataset_from_multi_hdf5.py index 4ec5c9fd310f..98022679af33 100644 --- a/examples/python-guide/dataset_from_multi_hdf5.py +++ b/examples/python-guide/dataset_from_multi_hdf5.py @@ -97,7 +97,7 @@ def generate_hdf(input_fname, output_basename, batch_size): def main(): batch_size = 64 output_basename = 'regression' - hdf_files = generate_hdf('../regression/regression.train', output_basename, 64) + hdf_files = generate_hdf('../regression/regression.train', output_basename, batch_size) create_dataset_from_multiple_hdf(hdf_files, batch_size=batch_size) From e8cb2ee59c88151747056e2372141a891a6215bf Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Fri, 18 Jun 2021 09:10:31 +0800 Subject: [PATCH 48/70] Apply suggestions from code review Co-authored-by: Nikita Titov --- python-package/lightgbm/basic.py | 39 ++++++++++++++++---------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 0ecf66a6cdc5..0e08cf496c7f 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -21,7 +21,7 @@ DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000 -def _get_sample_count(params: Dict[str, str], total_nrow: int): +def _get_sample_count(params: Dict[str, Any], total_nrow: int): sample_count = params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT return min(sample_count, total_nrow) @@ -659,7 +659,7 @@ def __getitem__(self, idx: Union[int, slice]) -> np.ndarray: Returns ------- result : numpy 1-D array, numpy 2-D array - 1-D array if idx is int, 2-D array if idx is slice + 1-D array if idx is int, 2-D array if idx is slice. """ raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()") @@ -1196,7 +1196,7 @@ def __del__(self): except AttributeError: pass - def create_sample_indices(self, total_nrow: int) -> np.ndarray: + def _create_sample_indices(self, total_nrow: int) -> np.ndarray: """Get an array of randomly chosen indices from this ``Dataset``. Indices are sampled without replacement. @@ -1213,28 +1213,28 @@ def create_sample_indices(self, total_nrow: int) -> np.ndarray: indices : numpy array Indices for sampled data. """ - param_str = param_dict_to_str(self.params) + param_str = param_dict_to_str(self.get_params()) # Note self.params may contain 'bin_construct_sample_cnt' but is None. - sample_cnt = _get_sample_count(self.params, total_nrow) - indices = np.zeros(sample_cnt, dtype=np.int32) + sample_cnt = _get_sample_count(self.get_params(), total_nrow) + indices = np.empty(sample_cnt, dtype=np.int32) ptr_data, _, _ = c_int_array(indices) _safe_call(_LIB.LGBM_SampleIndices( - ctypes.c_int(total_nrow), + ctypes.c_int32(total_nrow), c_str(param_str), ptr_data, )) return indices - def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset') -> 'Dataset': + def _init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset') -> 'Dataset': """Create dataset from a reference dataset. Parameters ---------- total_nrow : int - number of rows expected to add to dataset + Number of rows expected to add to dataset. ref_dataset : Dataset - reference dataset to extract meta from + Reference dataset to extract meta from. 
Returns ------- @@ -1249,7 +1249,7 @@ def init_from_ref_dataset(self, total_nrow: int, ref_dataset: 'Dataset') -> 'Dat )) return self - def init_from_sample( + def _init_from_sample( self, sample_data: List[np.ndarray], sample_indices: List[np.ndarray], @@ -1261,13 +1261,13 @@ def init_from_sample( Parameters ---------- sample_data : list of numpy arrays - Sample data for each column + Sample data for each column. sample_indices : list of numpy arrays Sample data row index for each column. sample_cnt : int Number of samples. total_nrow : int - Total number of rows for all input file. + Total number of rows for all input files. Returns ------- @@ -1297,7 +1297,7 @@ def init_from_sample( num_per_col_ptr, _, _ = c_int_array(num_per_col) self.handle = ctypes.c_void_p() - params_str = param_dict_to_str(self.params) + params_str = param_dict_to_str(self.get_params()) _safe_call(_LIB.LGBM_DatasetCreateFromSampledColumn( ctypes.cast(sample_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_double))), ctypes.cast(indices_col_ptr, ctypes.POINTER(ctypes.POINTER(ctypes.c_int32))), @@ -1310,7 +1310,7 @@ def init_from_sample( )) return self - def push_rows(self, data: np.ndarray) -> 'Dataset': + def _push_rows(self, data: np.ndarray) -> 'Dataset': """Add rows to Dataset. Parameters @@ -1529,7 +1529,7 @@ def __yield_row_from(self, seqs: List[Sequence], indices: Iterable[int]): seq_id += 1 seq = seqs[seq_id] id_in_seq = row_id - offset - row = seq[int(id_in_seq)] + row = seq[id_in_seq] yield row if row.flags['OWNDATA'] else row.copy() def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarray], List[np.ndarray]]: @@ -1560,7 +1560,7 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr return filtered, filtered_idx - def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: 'Dataset'): + def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: Optional['Dataset'] = None): """ Initialize data from list of Sequence objects. @@ -1572,10 +1572,10 @@ def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: 'Dataset'): total_nrow = sum(len(seq) for seq in seqs) # create validation dataset from ref_dataset - if ref_dataset: + if ref_dataset is not None: self.init_from_ref_dataset(total_nrow, ref_dataset) else: - sample_cnt = _get_sample_count(self.params, total_nrow) + sample_cnt = _get_sample_count(self.get_params(), total_nrow) sample_data, col_indices = self.__sample(seqs, total_nrow) self.init_from_sample(sample_data, col_indices, sample_cnt, total_nrow) @@ -1586,6 +1586,7 @@ def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: 'Dataset'): for start in range(0, nrow, batch_size): end = min(start + batch_size, nrow) self.push_rows(seq[start:end]) + return self def __init_from_np2d(self, mat, params_str, ref_dataset): """Initialize data from a 2-D numpy matrix.""" From 8e614b3b81b7dd36bef9cc4436b350472274c391 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Fri, 18 Jun 2021 09:15:02 +0800 Subject: [PATCH 49/70] Fix after applying suggestions. 
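With the helpers renamed to private in the previous patch, the end-to-end
construction inside __init_from_seqs reads, in outline (paraphrased for
review convenience, not a verbatim quote of the resulting code):

    total_nrow = sum(len(seq) for seq in seqs)
    if ref_dataset is not None:
        # Validation set: reuse bin mappers from the reference dataset.
        self._init_from_ref_dataset(total_nrow, ref_dataset)
    else:
        # Training set: sample rows, construct bin mappers from the
        # sample, then push all rows batch by batch.
        sample_cnt = _get_sample_count(self.get_params(), total_nrow)
        sample_data, col_indices = self.__sample(seqs, total_nrow)
        self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)
    for seq in seqs:
        batch_size = seq.batch_size or Sequence.batch_size
        for start in range(0, len(seq), batch_size):
            self._push_rows(seq[start:min(start + batch_size, len(seq))])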
--- python-package/lightgbm/basic.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 0e08cf496c7f..7a84d8b8a505 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -9,7 +9,7 @@
 from functools import wraps
 from logging import Logger
 from tempfile import NamedTemporaryFile
-from typing import Any, Dict, Iterable, List, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Set, Tuple, Union, Optional
 
 import numpy as np
 import scipy.sparse
@@ -22,6 +22,7 @@
 
 
 def _get_sample_count(params: Dict[str, Any], total_nrow: int):
+    # Note: self.params may contain 'bin_construct_sample_cnt', but its value may be None.
     sample_count = params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT
     return min(sample_count, total_nrow)
 
@@ -623,7 +624,7 @@ class Sequence:
 
     @staticmethod
     def is_class(obj) -> bool:
-        """Check if objection is instance of class Sequence.
+        """Check if object is instance of class Sequence.
 
         Args:
         -------
@@ -1214,7 +1215,6 @@ def _create_sample_indices(self, total_nrow: int) -> np.ndarray:
             Indices for sampled data.
         """
         param_str = param_dict_to_str(self.get_params())
-        # Note: self.params may contain 'bin_construct_sample_cnt', but its value may be None.
         sample_cnt = _get_sample_count(self.get_params(), total_nrow)
         indices = np.empty(sample_cnt, dtype=np.int32)
         ptr_data, _, _ = c_int_array(indices)
@@ -1541,7 +1541,7 @@ def __sample(self, seqs: List[Sequence], total_nrow: int) -> Tuple[List[np.ndarr
         -------
         sampled_rows, sampled_row_indices
         """
-        indices = self.create_sample_indices(total_nrow)
+        indices = self._create_sample_indices(total_nrow)
 
         # Select sampled rows, transpose to column order.
         sampled = np.array([row for row in self.__yield_row_from(seqs, indices)])
@@ -1573,19 +1573,19 @@ def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: Optional['Dataset'
 
         # create validation dataset from ref_dataset
         if ref_dataset is not None:
-            self.init_from_ref_dataset(total_nrow, ref_dataset)
+            self._init_from_ref_dataset(total_nrow, ref_dataset)
         else:
             sample_cnt = _get_sample_count(self.get_params(), total_nrow)
             sample_data, col_indices = self.__sample(seqs, total_nrow)
-            self.init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)
+            self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)
 
         for seq in seqs:
             nrow = len(seq)
             batch_size = seq.batch_size or Sequence.batch_size
             for start in range(0, nrow, batch_size):
                 end = min(start + batch_size, nrow)
-                self.push_rows(seq[start:end])
+                self._push_rows(seq[start:end])
         return self
 
     def __init_from_np2d(self, mat, params_str, ref_dataset):

From b75ee56f485d3bb60134461f7b1a3b6d0c2b4001 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Fri, 18 Jun 2021 09:22:59 +0800
Subject: [PATCH 50/70] Fix test, check idx is instance of numbers.Integral.
---
 tests/python_package_test/test_basic.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 15a599a1bba9..294abcc78142 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 import filecmp
+import numbers
 import os
 
 import numpy as np
@@ -98,7 +99,7 @@ def __init__(self, ndarray, batch_size):
     def __getitem__(self, idx):
         # The simple implementation is just a single "return self.ndarray[idx]"
         # The following is for demo and testing purpose.
-        if isinstance(idx, int):
+        if isinstance(idx, numbers.Integral):
             return self._get_one_line(idx)
         elif isinstance(idx, slice):
             if not (idx.step is None or idx.step is 1):

From 4c2321037916f630e6d2b9bc511e08c072469e31 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Fri, 18 Jun 2021 09:49:22 +0800
Subject: [PATCH 51/70] Update python-package/lightgbm/basic.py

Co-authored-by: Nikita Titov
---
 python-package/lightgbm/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 7a84d8b8a505..99a7f06be573 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -1324,7 +1324,7 @@ def _push_rows(self, data: np.ndarray) -> 'Dataset':
             Dataset object.
         """
         nrow, ncol = data.shape
-        data = np.array(data.reshape(data.size), dtype=data.dtype, copy=False)
+        data = data.reshape(data.size)
         data_ptr, data_type, _ = c_float_array(data)
 
         _safe_call(_LIB.LGBM_DatasetPushRows(

From aadc12522d266188bd19a1d8fdcd95bbc924202c Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Fri, 18 Jun 2021 09:52:43 +0800
Subject: [PATCH 52/70] Expose Sequence class in Python-API doc.
---
 docs/Python-API.rst              |  1 +
 python-package/lightgbm/basic.py | 19 ++++++++++---------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/docs/Python-API.rst b/docs/Python-API.rst
index a50ff43f0a4e..f57125c46a5a 100644
--- a/docs/Python-API.rst
+++ b/docs/Python-API.rst
@@ -12,6 +12,7 @@ Data Structure API
     Dataset
     Booster
     CVBooster
+    Sequence
 
 Training API
 ------------

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 99a7f06be573..3dd425fc8a2d 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -604,7 +604,7 @@ class Sequence:
     """
     Generic data access interface.
 
-    Object should support the following operations:
+    Object should support the following operations::
 
         # Get total row number.
         >>> len(seq)
@@ -615,24 +615,25 @@ class Sequence:
         # Optionally specify batch_size to control range data read size.
         >>> seq.batch_size
 
-    With random access, data sampling does not need to go through all data.
-    With range data access, there's no need to read all data into memory thus
-    reduce memory usage.
+    - With random access, **data sampling does not need to go through all data**.
+    - With range data access, there's **no need to read all data into memory, thus reducing memory usage**.
     """
 
     batch_size = 4096    # Defaults to read 4K rows in each batch.
 
     @staticmethod
     def is_class(obj) -> bool:
-        """Check if object is instance of class Sequence.
+        """Check whether object satisfies ``Sequence`` interface requirements.
 
-        Args:
-        -------
-        obj ([any]): object to be checked
+        Parameters
+        ----------
+        obj: Any
+            object to be checked.
 
         Returns
         -------
-        [bool]: is Sequence class
+        result: bool
+            ``True`` if object satisfies ``Sequence`` interface requirements, ``False`` otherwise.
         """
         if isinstance(obj, list) or hasattr(obj, "getformat"):
             return False
         return hasattr(obj, "__getitem__") and hasattr(obj, "__len__")

From b857b6051475945106f8a1284818382305965c55 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Fri, 18 Jun 2021 10:07:27 +0800
Subject: [PATCH 53/70] Handle Sequence object not having batch_size.
--- python-package/lightgbm/basic.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 3dd425fc8a2d..aa5671404818 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -635,6 +635,7 @@ def is_class(obj) -> bool:
         result: bool
             ``True`` if object satisfies ``Sequence`` interface requirements, ``False`` otherwise.
         """
+        # Sparse matrices also have __getitem__ and __len__, so we have to exclude them here.
         if isinstance(obj, list) or hasattr(obj, "getformat"):
             return False
         return hasattr(obj, "__getitem__") and hasattr(obj, "__len__")
@@ -1583,7 +1584,7 @@ def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: Optional['Dataset'
 
         for seq in seqs:
             nrow = len(seq)
-            batch_size = seq.batch_size or Sequence.batch_size
+            batch_size = getattr(seq, 'batch_size', None) or Sequence.batch_size
             for start in range(0, nrow, batch_size):
                 end = min(start + batch_size, nrow)
                 self._push_rows(seq[start:end])

From dd0ce1d2302a4b20ceb57a52fac55471462bbd2b Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Fri, 18 Jun 2021 10:19:23 +0800
Subject: [PATCH 54/70] Fix isort lint complaint.
---
 python-package/lightgbm/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index aa5671404818..00abe5559831 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -9,7 +9,7 @@
 from functools import wraps
 from logging import Logger
 from tempfile import NamedTemporaryFile
-from typing import Any, Dict, Iterable, List, Set, Tuple, Union, Optional
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import scipy.sparse

From 4b68e561271bd286aee5687f19d33d676138109c Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Tue, 22 Jun 2021 07:47:07 +0800
Subject: [PATCH 55/70] Apply suggestions from code review

Co-authored-by: Nikita Titov
---
 python-package/lightgbm/basic.py        |  5 +++++
 tests/python_package_test/test_basic.py | 10 +++++-----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 00abe5559831..d09b01f3c65a 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -617,6 +617,11 @@ class Sequence:
 
     - With random access, **data sampling does not need to go through all data**.
     - With range data access, there's **no need to read all data into memory, thus reducing memory usage**.
+
+    Attributes
+    ----------
+    batch_size : int
+        Default size of a batch.
     """
 
     batch_size = 4096    # Defaults to read 4K rows in each batch.
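
As context for the ``getattr`` fallback introduced in PATCH 53 above: at this point in the series, ``is_class`` accepts any object exposing ``__getitem__`` and ``__len__``, so a user-supplied sequence need not inherit from ``Sequence`` and may lack a ``batch_size`` attribute entirely. A minimal sketch of such a duck-typed sequence (illustrative only; the ``ArrayRows`` class is hypothetical and not part of the patch, and the 4096 default mirrors ``Sequence.batch_size``):

.. code:: python

    import numpy as np

    class ArrayRows:  # hypothetical: passes the duck-typed check without subclassing Sequence
        def __init__(self, array):
            self.array = array

        def __getitem__(self, idx):
            # an integer index yields one row; a slice yields a batch of rows
            return self.array[idx]

        def __len__(self):
            return len(self.array)

    seq = ArrayRows(np.arange(20.0).reshape(10, 2))
    # Mirrors the fallback used in __init_from_seqs: no batch_size attribute,
    # so batched reading falls back to Sequence.batch_size (4096 by default).
    batch_size = getattr(seq, 'batch_size', None) or 4096
    assert batch_size == 4096
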
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 294abcc78142..ba75992ff9fb 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -102,9 +102,9 @@ def __getitem__(self, idx): if isinstance(idx, numbers.Integral): return self._get_one_line(idx) elif isinstance(idx, slice): - if not (idx.step is None or idx.step is 1): + if not (idx.step is None or idx.step == 1): raise NotImplementedError("No need to implement, caller will not set step by now") - return self.ndarray[idx.start: idx.stop] + return self.ndarray[idx.start:idx.stop] else: raise TypeError(f"Sequence Index must be an integer/list/slice, got {type(idx)}") @@ -151,14 +151,14 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_ # half col if include_nan: # nan col - data[0:half_nrow, 2] = np.nan + data[:half_nrow, 2] = np.nan if include_0: # 0 col - data[0:half_nrow, 3] = 0 + data[:half_nrow, 3] = 0 # nan + 0 col if include_nan: - data[0:half_nrow, 4] = np.nan + data[:half_nrow, 4] = np.nan if include_0: data[half_nrow:-2, 4] = 0 From 75e3fff4addc1524ed5f7219d5fa885e7c9c2c64 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Tue, 22 Jun 2021 07:51:30 +0800 Subject: [PATCH 56/70] Update docstring to mention Sequence as data input. --- python-package/lightgbm/basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index d09b01f3c65a..d6f2b49df1ef 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -1140,7 +1140,7 @@ def __init__(self, data, label=None, reference=None, Parameters ---------- - data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse or list of numpy arrays + data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays Data source of Dataset. If string, it represents the path to txt file. label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) @@ -1778,7 +1778,7 @@ def create_valid(self, data, label=None, weight=None, group=None, Parameters ---------- - data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse or list of numpy arrays + data : string, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequences or list of numpy arrays Data source of Dataset. If string, it represents the path to txt file. label : list, numpy 1-D array, pandas Series / one-column DataFrame or None, optional (default=None) From 4c2e8069c3ba3a1642119ccdd847a75afd563e05 Mon Sep 17 00:00:00 2001 From: Chen Yufei Date: Tue, 22 Jun 2021 07:58:01 +0800 Subject: [PATCH 57/70] Remove get_one_line in test_basic.py --- tests/python_package_test/test_basic.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index ba75992ff9fb..390a620a8469 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -100,7 +100,7 @@ def __getitem__(self, idx): # The simple implementation is just a single "return self.ndarray[idx]" # The following is for demo and testing purpose. 
         if isinstance(idx, numbers.Integral):
-            return self._get_one_line(idx)
+            return self.ndarray[idx]
         elif isinstance(idx, slice):
             if not (idx.step is None or idx.step == 1):
                 raise NotImplementedError("No need to implement, caller will not set step by now")
             return self.ndarray[idx.start:idx.stop]
         else:
             raise TypeError(f"Sequence Index must be an integer/list/slice, got {type(idx)}")
 
-    def _get_one_line(self, idx):
-        return self.ndarray[idx]
-
     def __len__(self):
         return len(self.ndarray)

From 60985d9f88d3e2c3ee79594b9d82c6f85fd1d566 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Tue, 22 Jun 2021 07:55:19 +0800
Subject: [PATCH 58/70] Make Sequence an abstract class.
---
 python-package/lightgbm/basic.py | 28 ++++++----------------------
 1 file changed, 6 insertions(+), 22 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index d6f2b49df1ef..d2dbd68a6617 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 """Wrapper for C API of LightGBM."""
+import abc
 import ctypes
 import json
 import os
@@ -600,7 +601,7 @@ def _load_pandas_categorical(file_name=None, model_str=None):
     return None
 
 
-class Sequence:
+class Sequence(abc.ABC):
     """
     Generic data access interface.
 
@@ -626,25 +627,7 @@ class Sequence:
 
     batch_size = 4096    # Defaults to read 4K rows in each batch.
 
-    @staticmethod
-    def is_class(obj) -> bool:
-        """Check whether object satisfies ``Sequence`` interface requirements.
-
-        Parameters
-        ----------
-        obj: Any
-            object to be checked.
-
-        Returns
-        -------
-        result: bool
-            ``True`` if object satisfies ``Sequence`` interface requirements, ``False`` otherwise.
-        """
-        # Sparse matrices also have __getitem__ and __len__, so we have to exclude them here.
-        if isinstance(obj, list) or hasattr(obj, "getformat"):
-            return False
-        return hasattr(obj, "__getitem__") and hasattr(obj, "__len__")
-
+    @abc.abstractmethod
     def __getitem__(self, idx: Union[int, slice]) -> np.ndarray:
         """Return data for given row index.
 
@@ ...
         """
         raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __getitem__()")
 
+    @abc.abstractmethod
     def __len__(self) -> int:
         """Return row count of this sequence."""
         raise NotImplementedError("Sub-classes of lightgbm.Sequence must implement __len__()")
@@ -1492,11 +1476,11 @@ def _lazy_init(self, data, label=None, reference=None,
         elif isinstance(data, list) and len(data) > 0:
             if all(isinstance(x, np.ndarray) for x in data):
                 self.__init_from_list_np2d(data, params_str, ref_dataset)
-            elif all(Sequence.is_class(x) for x in data):
+            elif all(isinstance(x, Sequence) for x in data):
                 self.__init_from_seqs(data, ref_dataset)
             else:
                 raise TypeError('Data list can only be of ndarray or Sequence')
-        elif Sequence.is_class(data):
+        elif isinstance(data, Sequence):
             self.__init_from_seqs([data], ref_dataset)
         elif isinstance(data, dt_DataTable):
             self.__init_from_np2d(data.to_numpy(), params_str, ref_dataset)

From 20dac553b11e72350245ef19a29039d8d873fa96 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Tue, 22 Jun 2021 08:21:26 +0800
Subject: [PATCH 59/70] Reduce number of tests for test_sequence.
--- tests/python_package_test/test_basic.py | 27 +++++++++----------------
 1 file changed, 9 insertions(+), 18 deletions(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 390a620a8469..32038613d84f 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -126,12 +126,11 @@ def _create_sequence_from_ndarray(data, num_seq, batch_size):
     return seqs
 
 
-@pytest.mark.parametrize('sample_count', [11, 23, 100, None])
-@pytest.mark.parametrize('batch_size', [3, 20, None])
-@pytest.mark.parametrize('include_0', [False, True])
-@pytest.mark.parametrize('include_nan', [False, True])
+@pytest.mark.parametrize('sample_count', [11, 100, None])
+@pytest.mark.parametrize('batch_size', [3, None])
+@pytest.mark.parametrize('include_0_and_nan', [False, True])
 @pytest.mark.parametrize('num_seq', [1, 3])
-def test_sequence(tmpdir, sample_count, batch_size, include_0, include_nan, num_seq):
+def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
     params = {'bin_construct_sample_cnt': sample_count}
 
     nrow = 50
@@ -139,25 +138,17 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
     ncol = 11
     data = np.arange(nrow * ncol, dtype=np.float64).reshape((nrow, ncol))
 
-    # total col
-    if include_0:
+    if include_0_and_nan:
+        # whole col
         data[:, 0] = 0
-    if include_nan:
         data[:, 1] = np.nan
 
-    # half col
-    if include_nan:
-        # nan col
-        data[:half_nrow, 2] = np.nan
-    if include_0:
-        # 0 col
+        # half col
         data[:half_nrow, 3] = 0
+        data[:half_nrow, 2] = np.nan
 
-    # nan + 0 col
-    if include_nan:
-        data[:half_nrow, 4] = np.nan
-    if include_0:
         data[half_nrow:-2, 4] = 0
+        data[:half_nrow, 4] = np.nan
 
     X = data[:, :-1]
     Y = data[:, -1]

From 60fe7d575c148519f8adfe6a40fa7edfdcdfd5c0 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Tue, 22 Jun 2021 09:20:16 +0800
Subject: [PATCH 60/70] Add c_api: LGBM_SampleCount, fix potential bug in LGBM_SampleIndices.
---
 include/LightGBM/c_api.h         | 35 ++++++++++-----
 python-package/lightgbm/basic.py | 21 ++++++---
 src/c_api.cpp                    | 73 ++++++++++++++++++++------------
 3 files changed, 83 insertions(+), 46 deletions(-)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index 3e334341e52b..e85adf17f288 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -53,6 +53,30 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError();
  */
 LIGHTGBM_C_EXPORT int LGBM_RegisterLogCallback(void (*callback)(const char*));
 
+/*!
+ * \brief Get number of samples based on parameter and total number of rows of data.
+ * \param total_nrow Number of all data rows
+ * \param parameters Additional parameters, specify sample count
+ * \param[out] out Number of samples. You should pre-allocate memory to hold sample indices when calling ``LGBM_SampleIndices``.
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_SampleCount(int32_t total_nrow,
+                                       const char* parameters,
+                                       int* out);
+
+/*!
+ * \brief Create sample indices for total number of rows.
+ * \param total_nrow Number of all data rows
+ * \param parameters Additional parameters, specify sample count and random seed in parameter
+ * \param[out] out Created indices, type is int32_t, caller should insure out contains enough space to hold indices
+ * \param[out_len] out Number of indices. This maybe less than the one returned by ``LGBM_SampleCount``.
+ * \return 0 when succeed, -1 when failure happens
+ */
+LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t total_nrow,
+                                         const char* parameters,
+                                         void* out,
+                                         int64_t* out_len);
+
 // --- start Dataset interface
 
 /*!
@@ -214,17 +238,6 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromCSC(const void* col_ptr,
                                                 const DatasetHandle reference,
                                                 DatasetHandle* out);
 
-/*!
- * \brief Create sample indices for total number of rows.
- * \param total_nrow Number of all data rows
- * \param parameters Additional parameters, specify sample count and random seed in parameter
- * \param[out] out Created indices, type is int32_t, caller should insure out contains enough space to hold indices
- * \return 0 when succeed, -1 when failure happens
- */
-LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t total_nrow,
-                                         const char* parameters,
-                                         void* out);
-
 /*!
  * \brief Create dataset from dense matrix.
  * \param data Pointer to the data space

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index d2dbd68a6617..68d30a4b0aaa 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -22,10 +22,14 @@
 DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000
 
 
-def _get_sample_count(params: Dict[str, Any], total_nrow: int):
-    # Note: self.params may contain 'bin_construct_sample_cnt', but its value may be None.
-    sample_count = params.get("bin_construct_sample_cnt") or DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT
-    return min(sample_count, total_nrow)
+def _get_sample_count(total_nrow: int, params: str):
+    sample_cnt = ctypes.c_int(0)
+    _safe_call(_LIB.LGBM_SampleCount(
+        ctypes.c_int32(total_nrow),
+        c_str(params),
+        ctypes.byref(sample_cnt),
+    ))
+    return sample_cnt.value
 
@@ -1206,16 +1210,18 @@ def _create_sample_indices(self, total_nrow: int) -> np.ndarray:
             Indices for sampled data.
         """
         param_str = param_dict_to_str(self.get_params())
-        sample_cnt = _get_sample_count(self.get_params(), total_nrow)
+        sample_cnt = _get_sample_count(total_nrow, param_str)
         indices = np.empty(sample_cnt, dtype=np.int32)
         ptr_data, _, _ = c_int_array(indices)
+        actual_sample_cnt = ctypes.c_int64(0)
 
         _safe_call(_LIB.LGBM_SampleIndices(
             ctypes.c_int32(total_nrow),
             c_str(param_str),
             ptr_data,
+            ctypes.byref(actual_sample_cnt),
         ))
-        return indices
+        return indices[:actual_sample_cnt.value]
 
@@ -1566,7 +1572,8 @@ def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: Optional['Dataset'
         if ref_dataset is not None:
             self._init_from_ref_dataset(total_nrow, ref_dataset)
         else:
-            sample_cnt = _get_sample_count(self.get_params(), total_nrow)
+            param_str = param_dict_to_str(self.get_params())
+            sample_cnt = _get_sample_count(total_nrow, param_str)
             sample_data, col_indices = self.__sample(seqs, total_nrow)
             self._init_from_sample(sample_data, col_indices, sample_cnt, total_nrow)

diff --git a/src/c_api.cpp b/src/c_api.cpp
index 95cfe1db8085..1b13e89ad6b2 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -894,6 +894,51 @@ int LGBM_RegisterLogCallback(void (*callback)(const char*)) {
   API_END();
 }
 
+static inline int SampleCount(int32_t total_nrow, const Config& config) {
+  return static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
+}
+
+static inline std::vector<int32_t> CreateSampleIndices(int32_t total_nrow, const Config& config) {
+  Random rand(config.data_random_seed);
+  int sample_cnt = SampleCount(total_nrow, config);
+  return rand.Sample(total_nrow, sample_cnt);
+}
+
+int LGBM_SampleCount(int32_t total_nrow,
+                     const char* parameters,
+                     int* out) {
+  API_BEGIN();
+  if (out == nullptr) {
+    Log::Fatal("LGBM_SampleCount output is nullptr");
+  }
+  auto param = Config::Str2Map(parameters);
+  Config config;
+  config.Set(param);
+
+  *out = SampleCount(total_nrow, config);
+  API_END();
+}
+
+int LGBM_SampleIndices(int32_t total_nrow,
+                       const char* parameters,
+                       void* out,
+                       int64_t* out_len) {
+  // This API is to keep python binding's behavior the same as the C++ implementation.
+  // Sample count, random seed etc. should be provided in parameters.
+  API_BEGIN();
+  if (out == nullptr) {
+    Log::Fatal("LGBM_SampleIndices output is nullptr");
+  }
+  auto param = Config::Str2Map(parameters);
+  Config config;
+  config.Set(param);
+
+  auto sample_indices = CreateSampleIndices(total_nrow, config);
+  memcpy(out, sample_indices.data(), sizeof(int32_t) * sample_indices.size());
+  *out_len = sample_indices.size();
+  API_END();
+}
+
 int LGBM_DatasetCreateFromFile(const char* filename,
                                const char* parameters,
                                const DatasetHandle reference,
@@ -1034,34 +1079,6 @@ int LGBM_DatasetCreateFromMat(const void* data,
                out);
 }
 
-
-static inline std::vector<int32_t> CreateSampleIndices(const Config& config, int32_t total_nrow) {
-  Random rand(config.data_random_seed);
-  int sample_cnt = static_cast<int>(total_nrow < config.bin_construct_sample_cnt ? total_nrow : config.bin_construct_sample_cnt);
-  return rand.Sample(total_nrow, sample_cnt);
-}
-
-
-int LGBM_SampleIndices(int32_t total_nrow,
-                       const char* parameters,
-                       void* out) {
-  // This API is to keep python binding's behavior the same as the C++ implementation.
-  // Sample count, random seed etc. should be provided in parameters.
-  API_BEGIN();
-  if (out == nullptr) {
-    Log::Fatal("sample indices output is nullptr");
-  }
-  auto param = Config::Str2Map(parameters);
-  Config config;
-  config.Set(param);
-
-  auto sample_indices = CreateSampleIndices(config, total_nrow);
-  memcpy(out, sample_indices.data(), sizeof(int32_t) * sample_indices.size());
-
-  API_END();
-}
-
-
 int LGBM_DatasetCreateFromMats(int32_t nmat,
                                const void** data,
                                int data_type,

From 79465a53855a672a97b412719cd9256b2236c8de Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Tue, 22 Jun 2021 10:00:30 +0800
Subject: [PATCH 61/70] empty commit to trigger ci

From 90d342daa2231d68a2104e98382bf2c5d71ef242 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Mon, 28 Jun 2021 07:59:06 +0800
Subject: [PATCH 62/70] Apply suggestions from code review

Co-authored-by: Nikita Titov
---
 include/LightGBM/c_api.h | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index e85adf17f288..82874cdf2a1e 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -54,22 +54,24 @@ LIGHTGBM_C_EXPORT int LGBM_RegisterLogCallback(void (*callback)(const char*));
 
 /*!
- * \brief Get number of samples based on parameter and total number of rows of data.
+ * \brief Get number of samples based on parameters and total number of rows of data.
  * \param total_nrow Number of all data rows
- * \param parameters Additional parameters, specify sample count
- * \param[out] out Number of samples. You should pre-allocate memory to hold sample indices when calling ``LGBM_SampleIndices``.
+ * \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` is used to calculate returned value
+ * \param[out] out Number of samples. This value is used to pre-allocate memory to hold sample indices when calling ``LGBM_SampleIndices``
  * \return 0 when succeed, -1 when failure happens
  */
-LIGHTGBM_C_EXPORT int LGBM_SampleCount(int32_t total_nrow,
+LIGHTGBM_C_EXPORT int LGBM_GetSampleCount(int32_t total_nrow,
                                        const char* parameters,
                                        int* out);
 
 /*!
  * \brief Create sample indices for total number of rows.
+ * \note
+ * You should pre-allocate memory for ``out``, you can get its length by ``LGBM_SampleCount``.
  * \param total_nrow Number of all data rows
- * \param parameters Additional parameters, specify sample count and random seed in parameter
- * \param[out] out Created indices, type is int32_t, caller should insure out contains enough space to hold indices
- * \param[out_len] out Number of indices. This maybe less than the one returned by ``LGBM_SampleCount``.
+ * \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` and ``data_random_seed`` are used to produce the output
+ * \param[out] out Created indices, type is int32_t
+ * \param[out] out_len Number of indices. This maybe less than the one returned by ``LGBM_SampleCount``
  * \return 0 when succeed, -1 when failure happens
  */
 LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t total_nrow,
                                          const char* parameters,
                                          void* out,
                                          int64_t* out_len);

From a5bb2e6dd71c1448bff82de78279e7d1dbebbd40 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Mon, 28 Jun 2021 08:26:45 +0800
Subject: [PATCH 63/70] Rename to LGBM_GetSampleCount, change LGBM_SampleIndices out_len to int32_t.

Also rename total_nrow to num_total_row in c_api.h for consistency.
---
 include/LightGBM/c_api.h         | 10 +++++-----
 python-package/lightgbm/basic.py |  6 +++---
 src/c_api.cpp                    | 12 ++++++------
 3 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index 82874cdf2a1e..6ed3148848d4 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -55,12 +55,12 @@ LIGHTGBM_C_EXPORT int LGBM_RegisterLogCallback(void (*callback)(const char*));
 
 /*!
  * \brief Get number of samples based on parameters and total number of rows of data.
- * \param total_nrow Number of all data rows
+ * \param num_total_row Number of total rows
  * \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` is used to calculate returned value
  * \param[out] out Number of samples. This value is used to pre-allocate memory to hold sample indices when calling ``LGBM_SampleIndices``
  * \return 0 when succeed, -1 when failure happens
  */
-LIGHTGBM_C_EXPORT int LGBM_GetSampleCount(int32_t total_nrow,
+LIGHTGBM_C_EXPORT int LGBM_GetSampleCount(int32_t num_total_row,
                                        const char* parameters,
                                        int* out);
 
 /*!
  * \brief Create sample indices for total number of rows.
  * \note
  * You should pre-allocate memory for ``out``, you can get its length by ``LGBM_SampleCount``.
- * \param total_nrow Number of all data rows
+ * \param num_total_row Number of total rows
  * \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` and ``data_random_seed`` are used to produce the output
  * \param[out] out Created indices, type is int32_t
  * \param[out] out_len Number of indices. This maybe less than the one returned by ``LGBM_SampleCount``
  * \return 0 when succeed, -1 when failure happens
  */
-LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t total_nrow,
+LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t num_total_row,
                                          const char* parameters,
                                          void* out,
-                                         int64_t* out_len);
+                                         int32_t* out_len);

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 68d30a4b0aaa..6435b94cefcd 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -24,7 +24,7 @@ def _get_sample_count(total_nrow: int, params: str):
     sample_cnt = ctypes.c_int(0)
-    _safe_call(_LIB.LGBM_SampleCount(
+    _safe_call(_LIB.LGBM_GetSampleCount(
         ctypes.c_int32(total_nrow),
         c_str(params),
         ctypes.byref(sample_cnt),
@@ -613,9 +613,9 @@ class Sequence(abc.ABC):
         # Get total row number.
         >>> len(seq)
-        # Random access by row index. Use for data sampling.
+        # Random access by row index. Used for data sampling.
         >>> seq[10]
-        # Range data access. Use to read data in batch when constructing Dataset.
+        # Range data access. Used to read data in batch when constructing Dataset.
         >>> seq[0:100]
         # Optionally specify batch_size to control range data read size.
         >>> seq.batch_size

diff --git a/src/c_api.cpp b/src/c_api.cpp
index 1b13e89ad6b2..1f98409a7a40 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -904,7 +904,7 @@ static inline std::vector<int32_t> CreateSampleIndices(int32_t total_nrow, const
   return rand.Sample(total_nrow, sample_cnt);
 }
 
-int LGBM_SampleCount(int32_t total_nrow,
+int LGBM_GetSampleCount(int32_t num_total_row,
                      const char* parameters,
                      int* out) {
   API_BEGIN();
-  *out = SampleCount(total_nrow, config);
+  *out = SampleCount(num_total_row, config);
   API_END();
 }
 
-int LGBM_SampleIndices(int32_t total_nrow,
+int LGBM_SampleIndices(int32_t num_total_row,
                        const char* parameters,
                        void* out,
-                       int64_t* out_len) {
+                       int32_t* out_len) {
-  auto sample_indices = CreateSampleIndices(total_nrow, config);
+  auto sample_indices = CreateSampleIndices(num_total_row, config);
   memcpy(out, sample_indices.data(), sizeof(int32_t) * sample_indices.size());
-  *out_len = sample_indices.size();
+  *out_len = static_cast<int32_t>(sample_indices.size());
   API_END();
 }

From 9d3a9cd4672cf1649c336f5b4c559ea1b059347a Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Mon, 28 Jun 2021 08:49:48 +0800
Subject: [PATCH 64/70] Doc about Sequence in docs/Python-Intro.rst.
---
 docs/Python-Intro.rst | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/docs/Python-Intro.rst b/docs/Python-Intro.rst
index 2fd9c33def2e..ec3aadac0d98 100644
--- a/docs/Python-Intro.rst
+++ b/docs/Python-Intro.rst
@@ -39,6 +39,8 @@ The LightGBM Python module can load data from:
 
 - LightGBM binary file
 
+- LightGBM ``Sequence`` objects
+
 The data is stored in a ``Dataset`` object.
 
 Many of the examples in this page use functionality from ``numpy``. To run the e
     csr = scipy.sparse.csr_matrix((dat, (row, col)))
     train_data = lgb.Dataset(csr)
 
+**Load from Sequence objects:**
+
+We can implement the ``Sequence`` interface to read binary files. The following example shows how to read an HDF5 file with ``h5py``.
+
+.. code:: python
+
+    import h5py
+
+    class HDFSequence(lgb.Sequence):
+        def __init__(self, hdf_dataset, batch_size):
+            self.data = hdf_dataset
+            self.batch_size = batch_size
+
+        def __getitem__(self, idx):
+            return self.data[idx]
+
+        def __len__(self):
+            return len(self.data)
+
+    f = h5py.File('train.hdf5', 'r')
+    train_data = lgb.Dataset(HDFSequence(f['X'], 8192), label=f['Y'][:])
+
+Features of using the ``Sequence`` interface:
+
+- Data sampling uses random access, thus does not go through the whole dataset
+- Reading data in batch thus saves memory when constructing ``Dataset`` object
+- Supports creating ``Dataset`` from multiple data files
+
+Please refer to ``Sequence`` `API doc `__.
+`dataset_from_multi_hdf5.py `__
+is a detailed example.
+
 **Saving Dataset into a LightGBM binary file will make loading faster:**

From fe049f210012a7130b31601fb04911850853a54e Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Mon, 28 Jun 2021 08:50:45 +0800
Subject: [PATCH 65/70] Fix: basic.py change LGBM_SampleIndices out_len to int32.
---
 python-package/lightgbm/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 6435b94cefcd..7c15a9b1d648 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -1213,7 +1213,7 @@ def _create_sample_indices(self, total_nrow: int) -> np.ndarray:
         sample_cnt = _get_sample_count(total_nrow, param_str)
         indices = np.empty(sample_cnt, dtype=np.int32)
         ptr_data, _, _ = c_int_array(indices)
-        actual_sample_cnt = ctypes.c_int64(0)
+        actual_sample_cnt = ctypes.c_int32(0)
 
         _safe_call(_LIB.LGBM_SampleIndices(

From a3a9c4a366e96b52f0f8aa2e2203448108c81077 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Mon, 28 Jun 2021 08:51:27 +0800
Subject: [PATCH 66/70] Add create_valid test case with Dataset from Sequence.
---
 tests/python_package_test/test_basic.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 32038613d84f..069dd1b2a63d 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -177,16 +177,22 @@ def test_sequence(tmpdir, sample_count, batch_size, include_0_and_nan, num_seq):
 
     valid_npy_bin_fname = os.path.join(tmpdir, 'valid_data_from_npy.bin')
     valid_seq_bin_fname = os.path.join(tmpdir, 'valid_data_from_seq.bin')
+    valid_seq2_bin_fname = os.path.join(tmpdir, 'valid_data_from_seq2.bin')
     valid_ds = lgb.Dataset(valid_X, label=valid_Y, params=params, reference=ds)
     valid_ds.save_binary(valid_npy_bin_fname)
 
+    # From Dataset constructor, with dataset from numpy array.
     valid_seqs = _create_sequence_from_ndarray(valid_X, num_seq, batch_size)
-    valid_seq_ds = lgb.Dataset(valid_seqs, label=valid_Y, params=params, reference=valid_ds)
+    valid_seq_ds = lgb.Dataset(valid_seqs, label=valid_Y, params=params, reference=ds)
     valid_seq_ds.save_binary(valid_seq_bin_fname)
     assert filecmp.cmp(valid_npy_bin_fname, valid_seq_bin_fname)
 
+    # From Dataset.create_valid, with dataset from sequence.
+    valid_seq_ds2 = seq_ds.create_valid(valid_seqs, label=valid_Y, params=params)
+    valid_seq_ds2.save_binary(valid_seq2_bin_fname)
+    assert filecmp.cmp(valid_npy_bin_fname, valid_seq2_bin_fname)
+
 
 def test_chunked_dataset():
     X_train, X_test, y_train, y_test = train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.1,

From 519d4b96e4bd4e5888157709a8c98784bfc1d763 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Tue, 29 Jun 2021 07:48:30 +0800
Subject: [PATCH 67/70] Apply suggestions from code review

Co-authored-by: Nikita Titov
---
 docs/Python-Intro.rst    | 10 +++++-----
 include/LightGBM/c_api.h |  8 ++++----
 src/c_api.cpp            |  6 +++---
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/docs/Python-Intro.rst b/docs/Python-Intro.rst
index ec3aadac0d98..063dbf172445 100644
--- a/docs/Python-Intro.rst
+++ b/docs/Python-Intro.rst
@@ -39,7 +39,7 @@ The LightGBM Python module can load data from:
 
 - LightGBM binary file
 
-- LightGBM ``Sequence`` objects
+- LightGBM ``Sequence`` object(s)
 
 The data is stored in a ``Dataset`` object.
@@ -96,12 +96,12 @@ We can implement the ``Sequence`` interface to read binary files. The following
 
 Features of using the ``Sequence`` interface:
 
 - Data sampling uses random access, thus does not go through the whole dataset
-- Reading data in batch thus saves memory when constructing ``Dataset`` object
+- Reading data in batch, thus saves memory when constructing ``Dataset`` object
 - Supports creating ``Dataset`` from multiple data files
 
-Please refer to ``Sequence`` `API doc `__.
-`dataset_from_multi_hdf5.py `__
-is a detailed example.
+Please refer to ``Sequence`` `API doc <./Python-API.rst#data-structure-api>`__.
+
+`dataset_from_multi_hdf5.py <https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/dataset_from_multi_hdf5.py>`__ is a detailed example.

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index 6ed3148848d4..f2705b8da8ab 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -61,17 +61,17 @@ LIGHTGBM_C_EXPORT int LGBM_RegisterLogCallback(void (*callback)(const char*));
  * \return 0 when succeed, -1 when failure happens
  */
 LIGHTGBM_C_EXPORT int LGBM_GetSampleCount(int32_t num_total_row,
-                                       const char* parameters,
-                                       int* out);
+                                          const char* parameters,
+                                          int* out);
 
 /*!
  * \brief Create sample indices for total number of rows.
  * \note
- * You should pre-allocate memory for ``out``, you can get its length by ``LGBM_SampleCount``.
+ * You should pre-allocate memory for ``out``; you can get its length by ``LGBM_GetSampleCount``.
  * \param num_total_row Number of total rows
  * \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` and ``data_random_seed`` are used to produce the output
  * \param[out] out Created indices, type is int32_t
- * \param[out] out_len Number of indices. This maybe less than the one returned by ``LGBM_SampleCount``
+ * \param[out] out_len Number of indices. This maybe less than the one returned by ``LGBM_GetSampleCount``
  * \return 0 when succeed, -1 when failure happens
  */
 LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t num_total_row,
                                          const char* parameters,
                                          void* out,
                                          int32_t* out_len);

diff --git a/src/c_api.cpp b/src/c_api.cpp
index 1f98409a7a40..26ef43184a65 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -905,11 +905,11 @@ static inline std::vector<int32_t> CreateSampleIndices(int32_t total_nrow, const
 }
 
 int LGBM_GetSampleCount(int32_t num_total_row,
-                     const char* parameters,
-                     int* out) {
+                        const char* parameters,
+                        int* out) {
   API_BEGIN();
   if (out == nullptr) {
-    Log::Fatal("LGBM_SampleCount output is nullptr");
+    Log::Fatal("LGBM_GetSampleCount output is nullptr");
   }

From 618758a67daa478d05ed9f94bde7e1e31419c03f Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Wed, 30 Jun 2021 07:41:25 +0800
Subject: [PATCH 68/70] Apply suggestions from code review

Co-authored-by: shiyu1994
---
 include/LightGBM/c_api.h         | 2 +-
 python-package/lightgbm/basic.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index f2705b8da8ab..01e5f7f8ea0c 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -71,7 +71,7 @@ LIGHTGBM_C_EXPORT int LGBM_GetSampleCount(int32_t num_total_row,
  * \param num_total_row Number of total rows
  * \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` and ``data_random_seed`` are used to produce the output
  * \param[out] out Created indices, type is int32_t
- * \param[out] out_len Number of indices. This maybe less than the one returned by ``LGBM_GetSampleCount``
+ * \param[out] out_len Number of indices. This may be less than the one returned by ``LGBM_GetSampleCount``
  * \return 0 when succeed, -1 when failure happens
  */
 LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t num_total_row,

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 7c15a9b1d648..5129bf35cae1 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -609,7 +609,7 @@ class Sequence(abc.ABC):
     """
     Generic data access interface.
 
-    Object should support the following operations::
+    Object should support the following operations:
 
         # Get total row number.
         >>> len(seq)

From 553cecc2f83033d548c4602c78552ed7f2c560a1 Mon Sep 17 00:00:00 2001
From: Chen Yufei
Date: Wed, 30 Jun 2021 07:44:36 +0800
Subject: [PATCH 69/70] Remove no longer used DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT.
---
 python-package/lightgbm/basic.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 5129bf35cae1..1959153ebda4 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -19,7 +19,6 @@
 from .libpath import find_lib_path
 
 ZERO_THRESHOLD = 1e-35
-DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT = 200000

From a614436ccd89de57ca53827d124141414230a1c8 Mon Sep 17 00:00:00 2001
From: shiyu1994
Date: Wed, 30 Jun 2021 22:17:51 +0800
Subject: [PATCH 70/70] Update python-package/lightgbm/basic.py

Co-authored-by: Nikita Titov
---
 python-package/lightgbm/basic.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 1959153ebda4..77a1ddf08fd7 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -610,6 +610,8 @@ class Sequence(abc.ABC):
 
     Object should support the following operations:
 
+    .. code-block::
+
         # Get total row number.
         >>> len(seq)
         # Random access by row index. Used for data sampling.
         >>> seq[10]
         # Range data access. Used to read data in batch when constructing Dataset.
         >>> seq[0:100]
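
Taken together, the series ends with ``lightgbm.Sequence`` as an abstract base class: a subclass must implement ``__getitem__`` (integer and slice access) and ``__len__``, and may optionally set ``batch_size``. Below is a minimal end-to-end sketch of the resulting API, assuming a lightgbm build that includes these patches; the ``NumpySequence`` class and the variable names are illustrative, not part of the patch series.

.. code:: python

    import numpy as np
    import lightgbm as lgb

    class NumpySequence(lgb.Sequence):
        def __init__(self, data, batch_size=256):
            self.data = data
            self.batch_size = batch_size  # optional; Sequence.batch_size (4096) is the fallback

        def __getitem__(self, idx):
            # idx is an int during sampling and a slice during batched row pushing
            return self.data[idx]

        def __len__(self):
            return len(self.data)

    rng = np.random.default_rng(0)
    X, y = rng.random((1000, 10)), rng.random(1000)
    # Two Sequence objects stand in for two data files; labels are passed separately.
    train_ds = lgb.Dataset([NumpySequence(X[:600]), NumpySequence(X[600:])], label=y,
                           params={'bin_construct_sample_cnt': 256})
    booster = lgb.train({'objective': 'regression', 'verbosity': -1}, train_ds, num_boost_round=5)

Construction first samples at most ``bin_construct_sample_cnt`` rows (via ``LGBM_SampleIndices``) to build the bin mappers, then pushes rows batch by batch, so the full dataset never has to be materialized in memory at once.
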