Skip to content

Commit

Permalink
[python-package] Create Dataset from multiple data files (#4089)
Browse files Browse the repository at this point in the history
* [python-package] create Dataset from sampled data.

* [python-package] create Dataset from List[Sequence].

1. Use random access for data sampling
2. Support read data from multiple input files
3. Read data in batch so no need to hold all data in memory

* [python-package] example: create Dataset from multiple HDF5 file.

* fix: revert is_class implementation for seq

* fix: unwanted memory view reference for seq

* fix: seq is_class accepts sklearn matrices

* fix: requirements for example

* fix: pycode

* feat: print static code linting stage

* fix: linting: avoid shell str regex conversion

* code style: doc style

* code style: isort

* fix ci dependency: h5py on windows

* [py] remove rm files in test seq
#4089 (comment)

* docs(python): init_from_sample summary

#4089 (comment)

* remove dataset dump sample data debugging code.

* remove typo fix.

Create separate PR for this.

* fix typo in src/c_api.cpp

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* style(linting): py3 type hint for seq

* test(basic): os.path style path handling

* Revert "feat: print static code linting stage"

This reverts commit 10bd79f.

* feat(python): sequence on validation set

* minor(python): comment

* minor(python): test option hint

* style(python): fix code linting

* style(python): add pydoc for ref_dataset

* doc(python): sequence

Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>

* revert(python): sequence class abc

* chore(python): remove rm_files

* Remove useless static_assert.

* refactor: test_basic test for sequence.

* fix lint complaint.

* remove dataset._dump_text in sequence test.

* Fix reverting typo fix.

* Apply suggestions from code review

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Fix type hint, code and doc style.

* fix failing test_basic.

* Remove TODO about keep constant in sync with cpp.

* Install h5py only when running python-examples.

* Fix lint complaint.

* Apply suggestions from code review

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* Doc fixes, remove unused params_str in __init_from_seqs.

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Remove unnecessary conda install in windows ci script.

* Keep param as example in dataset_from_multi_hdf5.py

* Add _get_sample_count function to remove code duplication.

* Use batch_size parameter in generate_hdf.

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Fix after applying suggestions.

* Fix test, check idx is instance of numbers.Integral.

* Update python-package/lightgbm/basic.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Expose Sequence class in Python-API doc.

* Handle Sequence object not having batch_size.

* Fix isort lint complaint.

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update docstring to mention Sequence as data input.

* Remove get_one_line in test_basic.py

* Make Sequence an abstract class.

* Reduce number of tests for test_sequence.

* Add c_api: LGBM_SampleCount, fix potential bug in LGBMSampleIndices.

* empty commit to trigger ci

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Rename to LGBM_GetSampleCount, change LGBM_SampleIndices out_len to int32_t.

Also rename total_nrow to num_total_row in c_api.h for consistency.

* Doc about Sequence in docs/Python-Intro.rst.

* Fix: basic.py change LGBM_SampleIndices out_len to int32.

* Add create_valid test case with Dataset from Sequence.

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Apply suggestions from code review

Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>

* Remove no longer used DEFAULT_BIN_CONSTRUCT_SAMPLE_CNT.

* Update python-package/lightgbm/basic.py

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

Co-authored-by: Willian Zhang <willian@willian.email>
Co-authored-by: Willian Z <Willian@Willian-Zhang.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
Co-authored-by: shiyu1994 <shiyu_k1994@qq.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
6 people authored Jul 2, 2021
1 parent f37b0d4 commit c359896
Show file tree
Hide file tree
Showing 11 changed files with 625 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .ci/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ import matplotlib\
matplotlib.use\(\"Agg\"\)\
' plot_example.py # prevent interactive window mode
sed -i'.bak' 's/graph.render(view=True)/graph.render(view=False)/' plot_example.py
conda install -q -y -n $CONDA_ENV h5py ipywidgets notebook # requirements for examples
for f in *.py **/*.py; do python $f || exit -1; done # run all examples
cd $BUILD_DIRECTORY/examples/python-guide/notebooks
conda install -q -y -n $CONDA_ENV ipywidgets notebook
jupyter nbconvert --ExecutePreprocessor.timeout=180 --to notebook --execute --inplace *.ipynb || exit -1 # run all notebooks
fi
2 changes: 1 addition & 1 deletion .ci/test_windows.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ if (($env:TASK -eq "regular") -or (($env:APPVEYOR -eq "true") -and ($env:TASK -e
cd $env:BUILD_SOURCESDIRECTORY/examples/python-guide
@("import matplotlib", "matplotlib.use('Agg')") + (Get-Content "plot_example.py") | Set-Content "plot_example.py"
(Get-Content "plot_example.py").replace('graph.render(view=True)', 'graph.render(view=False)') | Set-Content "plot_example.py" # prevent interactive window mode
conda install -q -y -n $env:CONDA_ENV h5py ipywidgets notebook
foreach ($file in @(Get-ChildItem *.py)) {
@("import sys, warnings", "warnings.showwarning = lambda message, category, filename, lineno, file=None, line=None: sys.stdout.write(warnings.formatwarning(message, category, filename, lineno, line))") + (Get-Content $file) | Set-Content $file
python $file ; Check-Output $?
} # run all examples
cd $env:BUILD_SOURCESDIRECTORY/examples/python-guide/notebooks
conda install -q -y -n $env:CONDA_ENV ipywidgets notebook
jupyter nbconvert --ExecutePreprocessor.timeout=180 --to notebook --execute --inplace *.ipynb ; Check-Output $? # run all notebooks
}
1 change: 1 addition & 0 deletions docs/Python-API.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Data Structure API
Dataset
Booster
CVBooster
Sequence

Training API
------------
Expand Down
34 changes: 34 additions & 0 deletions docs/Python-Intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ The LightGBM Python module can load data from:

- LightGBM binary file

- LightGBM ``Sequence`` object(s)

The data is stored in a ``Dataset`` object.

Many of the examples in this page use functionality from ``numpy``. To run the examples, be sure to import ``numpy`` in your session.
Expand Down Expand Up @@ -69,6 +71,38 @@ Many of the examples in this page use functionality from ``numpy``. To run the e
csr = scipy.sparse.csr_matrix((dat, (row, col)))
train_data = lgb.Dataset(csr)
**Load from Sequence objects:**

We can implement ``Sequence`` interface to read binary files. The following example shows reading HDF5 file with ``h5py``.

.. code:: python
import h5py
class HDFSequence(lgb.Sequence):
def __init__(self, hdf_dataset, batch_size):
self.data = hdf_dataset
self.batch_size = batch_size
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
f = h5py.File('train.hdf5', 'r')
train_data = lgb.Dataset(HDFSequence(f['X'], 8192), label=f['Y'][:])
Features of using ``Sequence`` interface:

- Data sampling uses random access, thus does not go through the whole dataset
- Reading data in batch, thus saves memory when constructing ``Dataset`` object
- Supports creating ``Dataset`` from multiple data files

Please refer to ``Sequence`` `API doc <./Python-API.rst#data-structure-api>`__.

`dataset_from_multi_hdf5.py <https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/dataset_from_multi_hdf5.py>`__ is a detailed example.

**Saving Dataset into a LightGBM binary file will make loading faster:**

.. code:: python
Expand Down
3 changes: 3 additions & 0 deletions examples/python-guide/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ Examples include:
- Plot split value histogram
- Plot one specified tree
- Plot one specified tree with Graphviz
- [dataset_from_multi_hdf5.py](https://github.com/microsoft/LightGBM/blob/master/examples/python-guide/dataset_from_multi_hdf5.py)
- Construct Dataset from multiple HDF5 files
- Avoid loading all data into memory
106 changes: 106 additions & 0 deletions examples/python-guide/dataset_from_multi_hdf5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import h5py
import numpy as np
import pandas as pd

import lightgbm as lgb


class HDFSequence(lgb.Sequence):
def __init__(self, hdf_dataset, batch_size):
"""
Construct a sequence object from HDF5 with required interface.
Parameters
----------
hdf_dataset : h5py.Dataset
Dataset in HDF5 file.
batch_size : int
Size of a batch. When reading data to construct lightgbm Dataset, each read reads batch_size rows.
"""
# We can also open HDF5 file once and get access to
self.data = hdf_dataset
self.batch_size = batch_size

def __getitem__(self, idx):
return self.data[idx]

def __len__(self):
return len(self.data)


def create_dataset_from_multiple_hdf(input_flist, batch_size):
data = []
ylist = []
for f in input_flist:
f = h5py.File(f, 'r')
data.append(HDFSequence(f['X'], batch_size))
ylist.append(f['Y'][:])

params = {
'bin_construct_sample_cnt': 200000,
'max_bin': 255,
}
y = np.concatenate(ylist)
dataset = lgb.Dataset(data, label=y, params=params)
# With binary dataset created, we can use either Python API or cmdline version to train.
#
# Note: in order to create exactly the same dataset with the one created in simple_example.py, we need
# to modify simple_example.py to pass numpy array instead of pandas DataFrame to Dataset constructor.
# The reason is that DataFrame column names will be used in Dataset. For a DataFrame with Int64Index
# as columns, Dataset will use column names like ["0", "1", "2", ...]. While for numpy array, column names
# are using the default one assigned in C++ code (dataset_loader.cpp), like ["Column_0", "Column_1", ...].
dataset.save_binary('regression.train.from_hdf.bin')


def save2hdf(input_data, fname, batch_size):
"""Store numpy array to HDF5 file.
Please note chunk size settings in the implementation for I/O performance optimization.
"""
with h5py.File(fname, 'w') as f:
for name, data in input_data.items():
nrow, ncol = data.shape
if ncol == 1:
# Y has a single column and we read it in single shot. So store it as an 1-d array.
chunk = (nrow,)
data = data.values.flatten()
else:
# We use random access for data sampling when creating LightGBM Dataset from Sequence.
# When accessing any element in a HDF5 chunk, it's read entirely.
# To save I/O for sampling, we should keep number of total chunks much larger than sample count.
# Here we are just creating a chunk size that matches with batch_size.
#
# Also note that the data is stored in row major order to avoid extra copy when passing to
# lightgbm Dataset.
chunk = (batch_size, ncol)
f.create_dataset(name, data=data, chunks=chunk, compression='lzf')


def generate_hdf(input_fname, output_basename, batch_size):
# Save to 2 HDF5 files for demonstration.
df = pd.read_csv(input_fname, header=None, sep='\t')

mid = len(df) // 2
df1 = df.iloc[:mid]
df2 = df.iloc[mid:]

# We can store multiple datasets inside a single HDF5 file.
# Separating X and Y for choosing best chunk size for data loading.
fname1 = f'{output_basename}1.h5'
fname2 = f'{output_basename}2.h5'
save2hdf({'Y': df1.iloc[:, :1], 'X': df1.iloc[:, 1:]}, fname1, batch_size)
save2hdf({'Y': df2.iloc[:, :1], 'X': df2.iloc[:, 1:]}, fname2, batch_size)

return [fname1, fname2]


def main():
batch_size = 64
output_basename = 'regression'
hdf_files = generate_hdf('../regression/regression.train', output_basename, batch_size)

create_dataset_from_multiple_hdf(hdf_files, batch_size=batch_size)


if __name__ == '__main__':
main()
26 changes: 26 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,32 @@ LIGHTGBM_C_EXPORT const char* LGBM_GetLastError();
*/
LIGHTGBM_C_EXPORT int LGBM_RegisterLogCallback(void (*callback)(const char*));

/*!
* \brief Get number of samples based on parameters and total number of rows of data.
* \param num_total_row Number of total rows
* \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` is used to calculate returned value
* \param[out] out Number of samples. This value is used to pre-allocate memory to hold sample indices when calling ``LGBM_SampleIndices``
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_GetSampleCount(int32_t num_total_row,
const char* parameters,
int* out);

/*!
* \brief Create sample indices for total number of rows.
* \note
* You should pre-allocate memory for ``out``, you can get its length by ``LGBM_GetSampleCount``.
* \param num_total_row Number of total rows
* \param parameters Additional parameters, namely, ``bin_construct_sample_cnt`` and ``data_random_seed`` are used to produce the output
* \param[out] out Created indices, type is int32_t
* \param[out] out_len Number of indices. This may be less than the one returned by ``LGBM_GetSampleCount``
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_SampleIndices(int32_t num_total_row,
const char* parameters,
void* out,
int32_t* out_len);

// --- start Dataset interface

/*!
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import os

from .basic import Booster, Dataset, register_logger
from .basic import Booster, Dataset, Sequence, register_logger
from .callback import early_stopping, print_evaluation, record_evaluation, reset_parameter
from .engine import CVBooster, cv, train

Expand All @@ -29,7 +29,7 @@
with open(os.path.join(dir_path, 'VERSION.txt')) as version_file:
__version__ = version_file.read().strip()

__all__ = ['Dataset', 'Booster', 'CVBooster',
__all__ = ['Dataset', 'Booster', 'CVBooster', 'Sequence',
'register_logger',
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
Expand Down
Loading

0 comments on commit c359896

Please sign in to comment.