Commit b3db3cc

feat(python): lance write huggingface dataset directly (#1882)
Be able to directly write a HuggingFace dataset.
1 parent 1365378 commit b3db3cc

File tree: 6 files changed, +86 -12 lines changed

docs/integrations/huggingface.rst (+18)
@@ -0,0 +1,18 @@
+Lance ❤️ HuggingFace
+--------------------
+
+The HuggingFace Hub has become the go-to place for ML practitioners to find pre-trained models and useful datasets.
+
+HuggingFace datasets can be written directly into Lance format by using the
+:meth:`lance.write_dataset` method. You can write the entire dataset or a particular split. For example:
+
+
+.. code-block:: python
+
+    # Huggingface datasets
+    import datasets
+    import lance
+
+    lance.write_dataset(datasets.load_dataset(
+        "poloclub/diffusiondb", split="train[:10]",
+    ), "diffusiondb_train.lance")

docs/integrations/integrations.rst (+1)
@@ -3,4 +3,5 @@ Integrations
 
 .. toctree::
 
+   Huggingface <./huggingface>
    Tensorflow <./tensorflow>

python/pyproject.toml (+7, -11)
@@ -2,9 +2,7 @@
 name = "pylance"
 dependencies = ["pyarrow>=12", "numpy>=1.22"]
 description = "python wrapper for Lance columnar format"
-authors = [
-    { name = "Lance Devs", email = "dev@lancedb.com" },
-]
+authors = [{ name = "Lance Devs", email = "dev@lancedb.com" }]
 license = { file = "LICENSE" }
 repository = "https://github.com/eto-ai/lance"
 readme = "README.md"
@@ -48,20 +46,18 @@ build-backend = "maturin"
 
 [project.optional-dependencies]
 tests = [
-    "pandas",
-    "pytest",
+    "datasets",
     "duckdb",
     "ml_dtypes",
+    "pillow",
+    "pandas",
     "polars[pyarrow,pandas]",
+    "pytest",
     "tensorflow",
     "tqdm",
 ]
-benchmarks = [
-    "pytest-benchmark",
-]
-torch = [
-    "torch",
-]
+benchmarks = ["pytest-benchmark"]
+torch = ["torch"]
 
 [tool.ruff]
 select = ["F", "E", "W", "I", "G", "TCH", "PERF", "CPY001", "B019"]

python/python/lance/dataset.py (+16, -1)
@@ -41,7 +41,12 @@
 import pyarrow.dataset
 from pyarrow import RecordBatch, Schema
 
-from .dependencies import _check_for_numpy, _check_for_pandas, torch
+from .dependencies import (
+    _check_for_hugging_face,
+    _check_for_numpy,
+    _check_for_pandas,
+    torch,
+)
 from .dependencies import numpy as np
 from .dependencies import pandas as pd
 from .fragment import FragmentMetadata, LanceFragment
@@ -1992,6 +1997,7 @@ def write_dataset(
     data_obj: Reader-like
         The data to be written. Acceptable types are:
         - Pandas DataFrame, Pyarrow Table, Dataset, Scanner, or RecordBatchReader
+        - Huggingface dataset
     uri: str or Path
         Where to write the dataset to (directory)
     schema: Schema, optional
@@ -2020,6 +2026,15 @@ def write_dataset(
         a custom class that defines hooks to be called when each fragment is
         starting to write and finishing writing.
     """
+    if _check_for_hugging_face(data_obj):
+        # Huggingface datasets
+        from .dependencies import datasets
+
+        if isinstance(data_obj, datasets.Dataset):
+            if schema is None:
+                schema = data_obj.features.arrow_schema
+            data_obj = data_obj.data.to_batches()
+
     reader = _coerce_reader(data_obj, schema)
     _validate_schema(reader.schema)
     # TODO add support for passing in LanceDataset and LanceScanner here
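The new branch above is the whole feature: when the input might be a HuggingFace object and really is a datasets.Dataset, write_dataset pulls the Arrow schema from the dataset's features and streams the backing Arrow table as record batches. A minimal sketch of the equivalent manual conversion this replaces (hypothetical output path; the batches are wrapped here in a RecordBatchReader, one of the documented input types):

    import datasets
    import lance
    import pyarrow as pa

    hf_ds = datasets.load_dataset("poloclub/diffusiondb", split="train[:10]")

    # Mirror the branch above: take the Arrow schema from the features...
    schema = hf_ds.features.arrow_schema
    # ...and stream the backing Arrow table as record batches.
    reader = pa.RecordBatchReader.from_batches(schema, hf_ds.data.to_batches())

    lance.write_dataset(reader, "diffusiondb_manual.lance")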

python/python/lance/dependencies.py (+11)
@@ -34,6 +34,7 @@
 _PANDAS_AVAILABLE = True
 _POLARS_AVAILABLE = True
 _TORCH_AVAILABLE = True
+_HUGGING_FACE_AVAILABLE = True
 
 
 class _LazyModule(ModuleType):
@@ -164,6 +165,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
 
 
 if TYPE_CHECKING:
+    import datasets
     import numpy
     import pandas
     import polars
@@ -174,6 +176,7 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]:
     pandas, _PANDAS_AVAILABLE = _lazy_import("pandas")
     polars, _POLARS_AVAILABLE = _lazy_import("polars")
     torch, _TORCH_AVAILABLE = _lazy_import("torch")
+    datasets, _HUGGING_FACE_AVAILABLE = _lazy_import("datasets")
 
 
 @lru_cache(maxsize=None)
@@ -210,6 +213,12 @@ def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
     )
 
 
+def _check_for_hugging_face(obj: Any, *, check_type: bool = True) -> bool:
+    return _HUGGING_FACE_AVAILABLE and _might_be(
+        cast(Hashable, type(obj) if check_type else obj), "datasets"
+    )
+
+
 __all__ = [
     # lazy-load third party libs
     "numpy",
@@ -221,10 +230,12 @@ def _check_for_torch(obj: Any, *, check_type: bool = True) -> bool:
     "_check_for_pandas",
     "_check_for_polars",
     "_check_for_torch",
+    "_check_for_hugging_face",
     "_LazyModule",
     # exported flags/guards
     "_NUMPY_AVAILABLE",
     "_PANDAS_AVAILABLE",
     "_POLARS_AVAILABLE",
     "_TORCH_AVAILABLE",
+    "_HUGGING_FACE_AVAILABLE",
 ]
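_check_for_hugging_face follows the same shape as the existing _check_for_pandas/_check_for_polars/_check_for_torch guards: the datasets module is lazy-loaded, and _might_be only asks whether an object's type plausibly originates from the "datasets" package, so nothing is imported for non-HuggingFace inputs. A standalone sketch of that guard pattern (a hypothetical reimplementation for illustration; the module's actual _might_be may differ in detail):

    from functools import lru_cache
    from typing import Any

    @lru_cache(maxsize=None)
    def might_be(cls: type, module_name: str) -> bool:
        # Heuristic: does any class in the MRO mention the target module?
        # This avoids importing the module just to run an isinstance check.
        try:
            return any(f"{module_name}." in str(base) for base in cls.mro())
        except TypeError:
            return False

    def looks_like_hugging_face(obj: Any) -> bool:
        # Cheap type-origin test; the real isinstance check happens later,
        # inside write_dataset (see dataset.py above).
        return might_be(type(obj), "datasets")

Only after this guard passes does write_dataset import datasets and run the real isinstance check, keeping datasets an optional dependency.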
python/python/tests/test_huggingface.py (+33)
@@ -0,0 +1,33 @@
+# Copyright 2023 Lance Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+import lance
+import pytest
+
+datasets = pytest.importorskip("datasets")
+
+
+def test_write_hf_dataset(tmp_path: Path):
+    hf_ds = datasets.load_dataset(
+        "poloclub/diffusiondb",
+        name="2m_first_1k",
+        split="train[:50]",
+        trust_remote_code=True,
+    )
+
+    ds = lance.write_dataset(hf_ds, tmp_path)
+    assert ds.count_rows() == 50
+
+    assert ds.schema == hf_ds.features.arrow_schema
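Two details worth noting in the test: pytest.importorskip skips the whole module when datasets is not installed, keeping it an optional test dependency, and trust_remote_code=True is needed because poloclub/diffusiondb is a script-based dataset, which recent datasets releases refuse to load without explicit opt-in.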
