Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Data] [no_early_kickoff] Add column API to Dataset #35241

Merged
merged 7 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/data/api/dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ Inspecting Metadata
:toctree: doc/

Dataset.count
Dataset.columns
Dataset.schema
Dataset.default_batch_format
Dataset.num_blocks
Expand Down
37 changes: 33 additions & 4 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2158,6 +2158,39 @@ def schema(self, fetch_if_missing: bool = True) -> Optional["Schema"]:
else:
return base_schema

@ConsumptionAPI(
    if_more_than_read=True,
    datasource_metadata="schema",
    extra_condition="or if ``fetch_if_missing=True`` (the default)",
    pattern="Time complexity:",
)
def columns(self, fetch_if_missing: bool = True) -> Optional[List[str]]:
    """Return the column names of this Dataset.

    The names are read from the dataset's schema, as returned by
    :meth:`schema` called with the same ``fetch_if_missing`` value.

    Time complexity: O(1)

    Example:
        >>> import ray
        >>> # Create dataset from synthetic data.
        >>> ds = ray.data.range(1000)
        >>> ds.columns()
        ['id']

    Args:
        fetch_if_missing: Forwarded to :meth:`schema`. If True, the schema
            is fetched synchronously when it is not already known. If
            False, no fetch is attempted and None is returned when the
            schema is not known. Default is True.

    Returns:
        The list of column names for this Dataset, or None when the schema
        is not available (i.e. :meth:`schema` returned None).
    """
    # NOTE(review): the decorator is given pattern="Time complexity:", so it
    # presumably anchors on that docstring line -- keep the line intact.
    known_schema = self.schema(fetch_if_missing=fetch_if_missing)
    if known_schema is None:
        # No schema to read names from.
        return None
    return known_schema.names

def num_blocks(self) -> int:
"""Return the number of blocks of this dataset.

Expand Down Expand Up @@ -4361,10 +4394,6 @@ def __del__(self):
self._current_executor.shutdown()


# Backwards compatibility alias.
Dataset = Dataset


@PublicAPI
class MaterializedDataset(Dataset, Generic[T]):
"""A Dataset materialized in Ray memory, e.g., via `.materialize()`.
Expand Down
9 changes: 9 additions & 0 deletions python/ray/data/tests/test_consumption.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,15 @@ def test_schema_lazy(ray_start_regular_shared):
assert ds._plan.execute()._num_computed() == 0


def test_columns(ray_start_regular_shared):
    """Check ``Dataset.columns()`` against the schema and the lazy path."""
    dataset = ray.data.range(1)

    # columns() must agree with the names exposed by the schema ...
    assert ds_names_match_schema(dataset) if False else (
        dataset.columns() == dataset.schema().names
    )
    # ... and a range dataset has the single column "id".
    assert dataset.columns() == ["id"]

    # After a map, asking without fetching is expected to yield None.
    mapped = dataset.map(lambda row: row)
    assert mapped.columns(fetch_if_missing=False) is None


def test_schema_repr(ray_start_regular_shared):
ds = ray.data.from_items([{"text": "spam", "number": 0}])
# fmt: off
Expand Down