Skip to content

Commit

Permalink
[python-package] Infer feature names from pyarrow.Table (#6781)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlondschien authored Jan 11, 2025
1 parent e0c34e7 commit e61bcbe
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2126,6 +2126,8 @@ def _lazy_init(
categorical_feature=categorical_feature,
pandas_categorical=self.pandas_categorical,
)
elif _is_pyarrow_table(data) and feature_name == "auto":
feature_name = data.column_names

# process for args
params = {} if params is None else params
Expand Down Expand Up @@ -2185,7 +2187,6 @@ def _lazy_init(
self.__init_from_np2d(data, params_str, ref_dataset)
elif _is_pyarrow_table(data):
self.__init_from_pyarrow_table(data, params_str, ref_dataset)
feature_name = data.column_names
elif isinstance(data, list) and len(data) > 0:
if _is_list_of_numpy_arrays(data):
self.__init_from_list_np2d(data, params_str, ref_dataset)
Expand Down
22 changes: 22 additions & 0 deletions tests/python_package_test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,25 @@ def test_predict_ranking():
num_boost_round=5,
)
assert_equal_predict_arrow_pandas(booster, data)


def test_arrow_feature_name_auto():
data = generate_dummy_arrow_table()
dataset = lgb.Dataset(
data, label=pa.array([0, 1, 0, 0, 1]), params=dummy_dataset_params(), categorical_feature=["a"]
)
booster = lgb.train({"num_leaves": 7}, dataset, num_boost_round=5)
assert booster.feature_name() == ["a", "b"]


def test_arrow_feature_name_manual():
data = generate_dummy_arrow_table()
dataset = lgb.Dataset(
data,
label=pa.array([0, 1, 0, 0, 1]),
params=dummy_dataset_params(),
feature_name=["c", "d"],
categorical_feature=["c"],
)
booster = lgb.train({"num_leaves": 7}, dataset, num_boost_round=5)
assert booster.feature_name() == ["c", "d"]

0 comments on commit e61bcbe

Please sign in to comment.