# compat.py
# coding: utf-8
"""Compatibility library."""
from typing import TYPE_CHECKING, Any, List
# scikit-learn is intentionally imported first here,
# see https://github.com/microsoft/LightGBM/issues/6509
"""sklearn"""
try:
from sklearn import __version__ as _sklearn_version
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
try:
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
except ImportError:
from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
from sklearn.utils.validation import NotFittedError
try:
from sklearn.utils.validation import _check_sample_weight
except ImportError:
from sklearn.utils.validation import check_consistent_length
# dummy function to support older version of scikit-learn
def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
check_consistent_length(sample_weight, X)
return sample_weight
try:
from sklearn.utils.validation import validate_data
except ImportError:
# validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
# It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
def validate_data(
_estimator,
X,
y="no_validation",
accept_sparse: bool = True,
# 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
ensure_all_finite: bool = False,
ensure_min_samples: int = 1,
# trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
**ignored_kwargs,
):
# it's safe to import _num_features unconditionally because:
#
# * it was first added in scikit-learn 0.24.2
# * lightgbm cannot be used with scikit-learn versions older than that
# * this validate_data() re-implementation will not be called in scikit-learn>=1.6
#
from sklearn.utils.validation import _num_features
# _num_features() raises a TypeError on 1-dimensional input. That's a problem
# because scikit-learn's 'check_fit1d' estimator check sets that expectation that
# estimators must raise a ValueError when a 1-dimensional input is passed to fit().
#
# So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
if hasattr(X, "shape") and len(X.shape) == 1:
n_features_in_ = 1
else:
n_features_in_ = _num_features(X)
no_val_y = isinstance(y, str) and y == "no_validation"
# NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
if no_val_y:
X = check_array(
X,
accept_sparse=accept_sparse,
force_all_finite=ensure_all_finite,
ensure_min_samples=ensure_min_samples,
)
else:
X, y = check_X_y(
X,
y,
accept_sparse=accept_sparse,
force_all_finite=ensure_all_finite,
ensure_min_samples=ensure_min_samples,
)
# this only needs to be updated at fit() time
_estimator.n_features_in_ = n_features_in_
# raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
raise ValueError(
f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
f"is expecting {_estimator._n_features} features as input."
)
if no_val_y:
return X
else:
return X, y
SKLEARN_INSTALLED = True
_LGBMBaseCrossValidator = BaseCrossValidator
_LGBMModelBase = BaseEstimator
_LGBMRegressorBase = RegressorMixin
_LGBMClassifierBase = ClassifierMixin
_LGBMLabelEncoder = LabelEncoder
LGBMNotFittedError = NotFittedError
_LGBMStratifiedKFold = StratifiedKFold
_LGBMGroupKFold = GroupKFold
_LGBMCheckSampleWeight = _check_sample_weight
_LGBMAssertAllFinite = assert_all_finite
_LGBMCheckClassificationTargets = check_classification_targets
_LGBMComputeSampleWeight = compute_sample_weight
_LGBMValidateData = validate_data
except ImportError:
SKLEARN_INSTALLED = False
class _LGBMModelBase: # type: ignore
"""Dummy class for sklearn.base.BaseEstimator."""
pass
class _LGBMClassifierBase: # type: ignore
"""Dummy class for sklearn.base.ClassifierMixin."""
pass
class _LGBMRegressorBase: # type: ignore
"""Dummy class for sklearn.base.RegressorMixin."""
pass
_LGBMBaseCrossValidator = None
_LGBMLabelEncoder = None
LGBMNotFittedError = ValueError
_LGBMStratifiedKFold = None
_LGBMGroupKFold = None
_LGBMCheckSampleWeight = None
_LGBMAssertAllFinite = None
_LGBMCheckClassificationTargets = None
_LGBMComputeSampleWeight = None
_LGBMValidateData = None
_sklearn_version = None
# additional scikit-learn imports only for type hints
# (this branch is skipped at runtime, so '_sklearn_Tags' is only visible to type checkers)
if TYPE_CHECKING:
    # sklearn.utils.Tags can be imported unconditionally once
    # lightgbm's minimum scikit-learn version is 1.6 or higher
    try:
        from sklearn.utils import Tags as _sklearn_Tags
    except ImportError:
        _sklearn_Tags = None
"""pandas"""
try:
from pandas import DataFrame as pd_DataFrame
from pandas import Series as pd_Series
from pandas import concat
try:
from pandas import CategoricalDtype as pd_CategoricalDtype
except ImportError:
from pandas.api.types import CategoricalDtype as pd_CategoricalDtype
PANDAS_INSTALLED = True
except ImportError:
PANDAS_INSTALLED = False
class pd_Series: # type: ignore
"""Dummy class for pandas.Series."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class pd_DataFrame: # type: ignore
"""Dummy class for pandas.DataFrame."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class pd_CategoricalDtype: # type: ignore
"""Dummy class for pandas.CategoricalDtype."""
def __init__(self, *args: Any, **kwargs: Any):
pass
concat = None
"""matplotlib"""
try:
import matplotlib # noqa: F401
MATPLOTLIB_INSTALLED = True
except ImportError:
MATPLOTLIB_INSTALLED = False
"""graphviz"""
try:
import graphviz # noqa: F401
GRAPHVIZ_INSTALLED = True
except ImportError:
GRAPHVIZ_INSTALLED = False
"""datatable"""
try:
import datatable
if hasattr(datatable, "Frame"):
dt_DataTable = datatable.Frame
else:
dt_DataTable = datatable.DataTable
DATATABLE_INSTALLED = True
except ImportError:
DATATABLE_INSTALLED = False
class dt_DataTable: # type: ignore
"""Dummy class for datatable.DataTable."""
def __init__(self, *args: Any, **kwargs: Any):
pass
"""dask"""
try:
from dask import delayed
from dask.array import Array as dask_Array
from dask.array import from_delayed as dask_array_from_delayed
from dask.bag import from_delayed as dask_bag_from_delayed
from dask.dataframe import DataFrame as dask_DataFrame
from dask.dataframe import Series as dask_Series
from dask.distributed import Client, Future, default_client, wait
DASK_INSTALLED = True
# catching 'ValueError' here because of this:
# https://github.com/microsoft/LightGBM/issues/6365#issuecomment-2002330003
#
# That's potentially risky as dask does some significant import-time processing,
# like loading configuration from environment variables and files, and catching
# ValueError here might hide issues with that config-loading.
#
# But in exchange, it's less likely that 'import lightgbm' will fail for
# dask-related reasons, which is beneficial for any workloads that are using
# lightgbm but not its Dask functionality.
except (ImportError, ValueError):
DASK_INSTALLED = False
dask_array_from_delayed = None # type: ignore[assignment]
dask_bag_from_delayed = None # type: ignore[assignment]
delayed = None
default_client = None # type: ignore[assignment]
wait = None # type: ignore[assignment]
class Client: # type: ignore
"""Dummy class for dask.distributed.Client."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class Future: # type: ignore
"""Dummy class for dask.distributed.Future."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class dask_Array: # type: ignore
"""Dummy class for dask.array.Array."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class dask_DataFrame: # type: ignore
"""Dummy class for dask.dataframe.DataFrame."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class dask_Series: # type: ignore
"""Dummy class for dask.dataframe.Series."""
def __init__(self, *args: Any, **kwargs: Any):
pass
"""pyarrow"""
try:
import pyarrow.compute as pa_compute
from pyarrow import Array as pa_Array
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table
from pyarrow import chunked_array as pa_chunked_array
from pyarrow.cffi import ffi as arrow_cffi
from pyarrow.types import is_boolean as arrow_is_boolean
from pyarrow.types import is_floating as arrow_is_floating
from pyarrow.types import is_integer as arrow_is_integer
PYARROW_INSTALLED = True
except ImportError:
PYARROW_INSTALLED = False
class pa_Array: # type: ignore
"""Dummy class for pa.Array."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class pa_ChunkedArray: # type: ignore
"""Dummy class for pa.ChunkedArray."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class pa_Table: # type: ignore
"""Dummy class for pa.Table."""
def __init__(self, *args: Any, **kwargs: Any):
pass
class arrow_cffi: # type: ignore
"""Dummy class for pyarrow.cffi.ffi."""
CData = None
addressof = None
cast = None
new = None
def __init__(self, *args: Any, **kwargs: Any):
pass
class pa_compute: # type: ignore
"""Dummy class for pyarrow.compute."""
all = None
equal = None
pa_chunked_array = None
arrow_is_boolean = None
arrow_is_integer = None
arrow_is_floating = None
"""cpu_count()"""
try:
from joblib import cpu_count
def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
return cpu_count(only_physical_cores=only_physical_cores)
except ImportError:
try:
from psutil import cpu_count
def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
return cpu_count(logical=not only_physical_cores) or 1
except ImportError:
from multiprocessing import cpu_count
def _LGBMCpuCount(only_physical_cores: bool = True) -> int:
return cpu_count()
# Empty on purpose: a star-import of this module exports no names,
# so every name defined above has to be imported explicitly.
__all__: List[str] = []