Skip to content

Commit

Permalink
[python-package] consolidate pandas-to-numpy conversion code
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Oct 27, 2023
1 parent fcf76bc commit 94de782
Showing 1 changed file with 20 additions and 28 deletions.
48 changes: 20 additions & 28 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,24 @@ def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')


def _pandas_to_numpy(data: pd_DataFrame) -> np.ndarray:
_check_for_bad_pandas_dtypes(data.dtypes)
df_dtypes = [dtype.type for dtype in data.dtypes]
# so that the target dtype considers floats
df_dtypes.append(np.float32)
target_dtype = np.result_type(*df_dtypes)
try:
# most common case (no nullable dtypes)
return data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
return data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
return data.to_numpy(dtype=target_dtype, na_value=np.nan)


def _data_from_pandas(
data: pd_DataFrame,
feature_name: _LGBM_FeatureNameConfiguration,
Expand Down Expand Up @@ -721,22 +739,7 @@ def _data_from_pandas(
else: # use cat cols specified by user
categorical_feature = list(categorical_feature) # type: ignore[assignment]

# get numpy representation of the data
_check_for_bad_pandas_dtypes(data.dtypes)
df_dtypes = [dtype.type for dtype in data.dtypes]
df_dtypes.append(np.float32) # so that the target dtype considers floats
target_dtype = np.result_type(*df_dtypes)
try:
# most common case (no nullable dtypes)
data = data.to_numpy(dtype=target_dtype, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
data = data.astype(target_dtype, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
data = data.to_numpy(dtype=target_dtype, na_value=np.nan)
return data, feature_name, categorical_feature, pandas_categorical
return _pandas_to_numpy(data), feature_name, categorical_feature, pandas_categorical


def _dump_pandas_categorical(
Expand Down Expand Up @@ -2678,18 +2681,7 @@ def set_label(self, label: Optional[_LGBM_LabelType]) -> "Dataset":
if isinstance(label, pd_DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
_check_for_bad_pandas_dtypes(label.dtypes)
try:
# most common case (no nullable dtypes)
label = label.to_numpy(dtype=np.float32, copy=False)
except TypeError:
# 1.0 <= pd version < 1.1 and nullable dtypes, least common case
# raises error because array is casted to type(pd.NA) and there's no na_value argument
label = label.astype(np.float32, copy=False).values
except ValueError:
# data has nullable dtypes, but we can specify na_value argument and copy will be made
label = label.to_numpy(dtype=np.float32, na_value=np.nan)
label_array = np.ravel(label)
label_array = np.ravel(_pandas_to_numpy(label))
else:
label_array = _list_to_1d_numpy(label, dtype=np.float32, name='label')
self.set_field('label', label_array)
Expand Down

0 comments on commit 94de782

Please sign in to comment.