From aab212a78285fb245cf980d39e62a1048af5005f Mon Sep 17 00:00:00 2001
From: Nikita Titov
Date: Fri, 5 Nov 2021 20:29:49 +0300
Subject: [PATCH] [python][sklearn] add `n_estimators_` and `n_iter_` post-fit
 attributes (#4753)

* add n_estimators_ and n_iter_ post-fit attributes

* address review comments
---
 python-package/lightgbm/sklearn.py        | 22 ++++++++++++++++++++++
 tests/python_package_test/test_sklearn.py | 11 +++++++++++
 2 files changed, 33 insertions(+)

diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 4981df873c7f..05a738292711 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -847,6 +847,28 @@ def objective_(self):
             raise LGBMNotFittedError('No objective found. Need to call fit beforehand.')
         return self._objective
 
+    @property
+    def n_estimators_(self) -> int:
+        """:obj:`int`: True number of boosting iterations performed.
+
+        This might be less than parameter ``n_estimators`` if early stopping was enabled or
+        if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
+        """
+        if not self.__sklearn_is_fitted__():
+            raise LGBMNotFittedError('No n_estimators found. Need to call fit beforehand.')
+        return self._Booster.current_iteration()
+
+    @property
+    def n_iter_(self) -> int:
+        """:obj:`int`: True number of boosting iterations performed.
+
+        This might be less than parameter ``n_estimators`` if early stopping was enabled or
+        if boosting stopped early due to limits on complexity like ``min_gain_to_split``.
+        """
+        if not self.__sklearn_is_fitted__():
+            raise LGBMNotFittedError('No n_iter found. Need to call fit beforehand.')
+        return self._Booster.current_iteration()
+
     @property
     def booster_(self):
         """Booster: The underlying Booster of this model."""
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index 152757c79634..4204ffb4ec0f 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -1158,6 +1158,17 @@ def test_continue_training_with_model():
     assert gbm.evals_result_['valid_0']['multi_logloss'][-1] < init_gbm.evals_result_['valid_0']['multi_logloss'][-1]
 
 
+def test_actual_number_of_trees():
+    X = [[1, 2, 3], [1, 2, 3]]
+    y = [1, 1]
+    n_estimators = 5
+    gbm = lgb.LGBMRegressor(n_estimators=n_estimators).fit(X, y)
+    assert gbm.n_estimators == n_estimators
+    assert gbm.n_estimators_ == 1
+    assert gbm.n_iter_ == 1
+    np.testing.assert_array_equal(gbm.predict(np.array(X) * 10), y)
+
+
 # sklearn < 0.22 requires passing "attributes" argument
 @pytest.mark.skipif(sk_version < parse_version('0.22'), reason='scikit-learn version is less than 0.22')
 def test_check_is_fitted():
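
Usage note (not part of the patch): a minimal sketch of how the new post-fit attributes behave once this change is applied, assuming a lightgbm release containing it plus numpy are installed. The synthetic data, the train/validation split, and the early-stopping callback below are illustrative assumptions, not taken from the PR.

    import numpy as np
    import lightgbm as lgb

    # Purely synthetic data for illustration.
    rng = np.random.default_rng(42)
    X = rng.normal(size=(1000, 10))
    y = rng.normal(size=1000)

    model = lgb.LGBMRegressor(n_estimators=100)
    model.fit(
        X[:800], y[:800],
        eval_set=[(X[800:], y[800:])],
        callbacks=[lgb.early_stopping(stopping_rounds=5)],
    )

    # With early stopping, boosting typically halts before 100 rounds;
    # the new attributes report how many iterations were really performed.
    print(model.n_estimators)    # 100: the constructor parameter is left untouched
    print(model.n_estimators_)   # number of boosting iterations actually performed
    print(model.n_iter_)         # same value, under the name scikit-learn conventionally uses

Exposing both names is a design choice: `n_estimators_` mirrors the estimator's own parameter, while `n_iter_` matches the attribute scikit-learn's iterative estimators expose after fitting.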