[Data] Deprecate dataset_format (#33437)
Deprecates dataset_format.

Makes the following changes in order to do so:

- Preprocessors can implement a preferred_batch_format, just like Predictor, instead of querying dataset_format.
- KeyFn validation for the sort operator can be done entirely through the schema instead of dataset_format.
- to_arrow_refs looks at the schema directly.
- Use runtime checks for RandomAccessDataset.
- fast_repartition uses the schema directly.
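The pattern shared by all of these changes: instead of asking a Dataset for a format string, callers branch on the type of the schema object. A minimal sketch of that dispatch, assuming the schema types used in the diffs below (infer_block_format is a hypothetical helper, not part of this PR):

    import pyarrow as pa
    from ray.data._internal.pandas_block import PandasBlockSchema

    def infer_block_format(schema) -> str:
        # Hypothetical helper illustrating the schema-based dispatch.
        if schema is None:
            raise ValueError("Dataset is empty or cleared; format is unknown.")
        if isinstance(schema, type):       # simple blocks report a Python type
            return "simple"
        if isinstance(schema, pa.Schema):  # Arrow blocks report a pyarrow.Schema
            return "arrow"
        if isinstance(schema, PandasBlockSchema):
            return "pandas"
        raise TypeError(f"Unexpected schema type: {type(schema)}")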

---------

Signed-off-by: amogkam <amogkamsetty@yahoo.com>
amogkam authored Mar 22, 2023
1 parent c50ecbc commit f8f7374
Showing 20 changed files with 341 additions and 279 deletions.
22 changes: 17 additions & 5 deletions python/ray/data/_internal/fast_repartition.py
@@ -25,7 +25,6 @@ def fast_repartition(blocks, num_blocks, ctx: Optional[TaskContext] = None):
     )
     # Compute the (n-1) indices needed for an equal split of the data.
     count = wrapped_ds.count()
-    dataset_format = wrapped_ds.dataset_format()
     indices = []
     cur_idx = 0
     for _ in range(num_blocks - 1):
@@ -59,6 +58,9 @@ def fast_repartition(blocks, num_blocks, ctx: Optional[TaskContext] = None):
 
     owned_by_consumer = blocks._owned_by_consumer
 
+    # Schema is safe to fetch here since we have already called
+    # get_internal_block_refs and executed the dataset.
+    schema = wrapped_ds.schema(fetch_if_missing=True)
     # Early-release memory.
     del splits, blocks, wrapped_ds
 
@@ -75,13 +77,23 @@ def fast_repartition(blocks, num_blocks, ctx: Optional[TaskContext] = None):
     from ray.data._internal.pandas_block import PandasBlockBuilder
     from ray.data._internal.simple_block import SimpleBlockBuilder
 
+    import pyarrow as pa
+    from ray.data._internal.pandas_block import PandasBlockSchema
+
     num_empties = num_blocks - len(new_blocks)
-    if dataset_format == "arrow":
+
+    if schema is None:
+        raise ValueError(
+            "Dataset is empty or cleared, can't determine the format of "
+            "the dataset."
+        )
+    elif isinstance(schema, type):
+        builder = SimpleBlockBuilder()
+    elif isinstance(schema, pa.Schema):
         builder = ArrowBlockBuilder()
-    elif dataset_format == "pandas":
+    elif isinstance(schema, PandasBlockSchema):
         builder = PandasBlockBuilder()
     else:
         builder = SimpleBlockBuilder()
 
     empty_block = builder.build()
     empty_meta = BlockAccessor.for_block(empty_block).get_metadata(
         input_files=None, exec_stats=None
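The padding blocks above are produced by picking a builder from the inferred schema type and building an empty block in the matching format. A rough sketch of that flow using the internal APIs shown in this diff (the import path for ArrowBlockBuilder is an assumption):

    from ray.data._internal.arrow_block import ArrowBlockBuilder
    from ray.data.block import BlockAccessor

    builder = ArrowBlockBuilder()   # chosen from the schema type, as above
    empty_block = builder.build()   # zero-row block in the matching format
    empty_meta = BlockAccessor.for_block(empty_block).get_metadata(
        input_files=None, exec_stats=None
    )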
20 changes: 9 additions & 11 deletions python/ray/data/block.py
@@ -60,35 +60,33 @@
 
 def _validate_key_fn(ds: "Dataset", key: KeyFn) -> None:
     """Check the key function is valid on the given dataset."""
-    try:
-        fmt = ds.dataset_format()
-    except ValueError:
+    schema = ds.schema(fetch_if_missing=True)
+    if schema is None:
         # Dataset is empty/cleared, validation not possible.
         return
+    is_simple_format = isinstance(schema, type)
     if isinstance(key, str):
-        if fmt == "simple":
+        if is_simple_format:
             raise ValueError(
                 "String key '{}' requires dataset format to be "
-                "'arrow' or 'pandas', was '{}'.".format(key, fmt)
+                "'arrow' or 'pandas', was 'simple'.".format(key)
             )
         # Raises KeyError if key is not present in the schema.
-        schema = ds.schema(fetch_if_missing=True)
         if len(schema.names) > 0 and key not in schema.names:
             raise ValueError(
                 "The column '{}' does not exist in the "
                 "schema '{}'.".format(key, schema)
             )
     elif key is None:
-        if fmt != "simple":
+        if not is_simple_format:
             raise ValueError(
                 "The `None` key '{}' requires dataset format to be "
-                "'simple', was '{}'.".format(key, fmt)
+                "'simple'.".format(key)
             )
     elif callable(key):
-        if fmt != "simple":
+        if not is_simple_format:
             raise ValueError(
                 "Callable key '{}' requires dataset format to be "
-                "'simple', was '{}'.".format(key, fmt)
+                "'simple'".format(key)
             )
     else:
         raise TypeError("Invalid key type {} ({}).".format(key, type(key)))
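In practice the rules enforced above are: string keys require a tabular (Arrow or pandas) dataset, while None and callable keys require simple blocks. A usage sketch against the Ray Data API of this era (range_table and range produced Arrow and simple blocks, respectively):

    import ray

    tabular = ray.data.range_table(3)   # Arrow blocks with a "value" column
    tabular.sort("value")               # OK: string key on a tabular dataset

    simple = ray.data.range(3)          # simple blocks of ints
    simple.sort()                       # OK: None key on simple blocks
    simple.sort(lambda x: -x)           # OK: callable key on simple blocks
    # simple.sort("value")              # ValueError: string key on simple format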
29 changes: 17 additions & 12 deletions python/ray/data/dataset.py
@@ -3641,9 +3641,13 @@ def to_arrow_refs(self) -> List[ObjectRef["pyarrow.Table"]]:
         Returns:
             A list of remote Arrow tables created from this dataset.
         """
-        blocks: List[ObjectRef[Block]] = self.get_internal_block_refs()
-
-        if self.dataset_format() == BlockFormat.ARROW:
+        import pyarrow as pa
+
+        blocks: List[ObjectRef["pyarrow.Table"]] = self.get_internal_block_refs()
+        # Schema is safe to call since we have already triggered execution with
+        # get_internal_block_refs.
+        schema = self.schema(fetch_if_missing=True)
+        if isinstance(schema, pa.Schema):
             # Zero-copy path.
             return blocks
 
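The zero-copy path now applies exactly when the blocks are already Arrow; other formats are converted further down in remote tasks. A short usage sketch (assuming an Arrow-backed dataset such as one from range_table):

    import ray

    ds = ray.data.range_table(8)   # Arrow-backed blocks
    refs = ds.to_arrow_refs()      # zero-copy: returns the block refs directly
    tables = ray.get(refs)         # a list of pyarrow.Table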
@@ -4240,13 +4244,22 @@ def default_batch_format(self) -> Type:
         pattern="for the first block.",
         insert_after=True,
     )
+    @Deprecated(message="`dataset_format` is deprecated for streaming execution.")
     def dataset_format(self) -> BlockFormat:
         """The format of the dataset's underlying data blocks. Possible values
         are: "arrow", "pandas" and "simple".
 
         This may block; if the schema is unknown, this will synchronously fetch
         the schema for the first block.
         """
+        context = DatasetContext.get_current()
+        if context.use_streaming_executor:
+            raise DeprecationWarning(
+                "`dataset_format` is deprecated for streaming execution. To use "
+                "`dataset_format`, you must explicitly enable bulk execution by "
+                "setting `use_streaming_executor` to False in the `DatasetContext`"
+            )
+
         # We need schema to properly validate, so synchronously
         # fetch it if necessary.
         schema = self.schema(fetch_if_missing=True)
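Callers that still depend on dataset_format can opt back into bulk execution before calling it; under the streaming executor, the call now raises. A hedged sketch of that opt-out, using the same DatasetContext flag the code above checks:

    from ray.data.context import DatasetContext

    ctx = DatasetContext.get_current()
    ctx.use_streaming_executor = False   # explicitly enable bulk execution
    fmt = ds.dataset_format()            # deprecated, but allowed in bulk mode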
@@ -4294,18 +4307,10 @@ def _build_multicolumn_aggs(
         """Build set of aggregations for applying a single aggregation to
         multiple columns.
         """
-
         # Expand None into an aggregation for each column.
         if on is None:
-            try:
-                dataset_format = self.dataset_format()
-            except ValueError:
-                dataset_format = None
-            if dataset_format in [BlockFormat.ARROW, BlockFormat.PANDAS]:
-                # This should be cached from the .dataset_format() check, so we
-                # don't fetch and we assert that the schema is not None.
-                schema = self.schema(fetch_if_missing=False)
-                assert schema is not None
+            schema = self.schema(fetch_if_missing=True)
+            if schema is not None and not isinstance(schema, type):
                 if not skip_cols:
                     skip_cols = []
                 if len(schema.names) > 0:
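For users, the visible behavior is unchanged: on a tabular dataset, an aggregation with on=None still expands into one aggregation per non-skipped column, driven by schema.names. A sketch (column names and result shape are illustrative):

    import ray

    ds = ray.data.from_items([{"a": 1, "b": 10}, {"a": 3, "b": 30}])
    ds.mean()   # expands to per-column means, e.g. mean(a) == 2.0, mean(b) == 20.0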
93 changes: 36 additions & 57 deletions python/ray/data/preprocessor.py
@@ -3,7 +3,7 @@
 from enum import Enum
 from typing import TYPE_CHECKING, Optional, Union, Dict, Any
 
-from ray.air.util.data_batch_conversion import BatchFormat, BlockFormat
+from ray.air.util.data_batch_conversion import BatchFormat
 from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
 
 if TYPE_CHECKING:
@@ -216,69 +216,41 @@ def _fit(self, dataset: "Dataset") -> "Preprocessor":
         """Sub-classes should override this instead of fit()."""
         raise NotImplementedError()
 
-    def _determine_transform_to_use(self, data_format: BlockFormat) -> BatchFormat:
-        """Determine which transform to use based on data format and implementation.
+    def _determine_transform_to_use(self) -> BatchFormat:
+        """Determine which batch format to use based on Preprocessor implementation.
 
         We will infer and pick the best transform to use:
-            * ``pandas`` data format prioritizes ``pandas`` transform if available.
-            * ``arrow`` and ``numpy`` data format prioritizes ``numpy`` transform if available.  # noqa: E501
-            * Fall back to what's available if no preferred path found.
+            * If only `_transform_pandas` is implemented, then use ``pandas`` batch format.
+            * If only `_transform_numpy` is implemented, then use ``numpy`` batch format.
+            * If both are implemented, then use the Preprocessor defined preferred batch
+              format.
         """
-
-        assert data_format in (
-            "pandas",
-            "arrow",
-            "numpy",
-        ), f"Unsupported data format: {data_format}"
-
         has_transform_pandas = (
             self.__class__._transform_pandas != Preprocessor._transform_pandas
         )
         has_transform_numpy = (
             self.__class__._transform_numpy != Preprocessor._transform_numpy
         )
 
-        # Infer transform type by prioritizing native transformation to minimize
-        # data conversion cost.
-        if data_format == BlockFormat.PANDAS:
-            # Perform native pandas transformation if possible.
-            if has_transform_pandas:
-                transform_type = BatchFormat.PANDAS
-            elif has_transform_numpy:
-                transform_type = BatchFormat.NUMPY
-            else:
-                raise NotImplementedError(
-                    "None of `_transform_numpy` or `_transform_pandas` "
-                    f"are implemented for dataset format `{data_format}`."
-                )
-        elif data_format == BlockFormat.ARROW or data_format == "numpy":
-            # Arrow -> Numpy is more efficient
-            if has_transform_numpy:
-                transform_type = BatchFormat.NUMPY
-            elif has_transform_pandas:
-                transform_type = BatchFormat.PANDAS
-            else:
-                raise NotImplementedError(
-                    "None of `_transform_numpy` or `_transform_pandas` "
-                    f"are implemented for dataset format `{data_format}`."
-                )
-
-        return transform_type
+        if has_transform_numpy and has_transform_pandas:
+            return self.preferred_batch_format()
+        elif has_transform_numpy:
+            return BatchFormat.NUMPY
+        elif has_transform_pandas:
+            return BatchFormat.PANDAS
+        else:
+            raise NotImplementedError(
+                "None of `_transform_numpy` or `_transform_pandas` are implemented. "
+                "At least one of these transform functions must be implemented "
+                "for Preprocessor transforms."
+            )
 
     def _transform(
         self, dataset: Union["Dataset", "DatasetPipeline"]
     ) -> Union["Dataset", "DatasetPipeline"]:
         # TODO(matt): Expose `batch_size` or similar configurability.
         # The default may be too small for some datasets and too large for others.
 
-        dataset_format = dataset.dataset_format()
-        if dataset_format not in (BlockFormat.PANDAS, BlockFormat.ARROW):
-            raise ValueError(
-                f"Unsupported Dataset format: '{dataset_format}'. Only 'pandas' "
-                "and 'arrow' Dataset formats are supported."
-            )
-
-        transform_type = self._determine_transform_to_use(dataset_format)
+        transform_type = self._determine_transform_to_use()
 
         # Our user-facing batch format should only be pandas or NumPy, other
         # formats {arrow, simple} are internal.
@@ -318,20 +290,14 @@ def _transform_batch(self, data: "DataBatchType") -> "DataBatchType":
         except ImportError:
             pyarrow = None
 
-        if isinstance(data, pd.DataFrame):
-            data_format = BlockFormat.PANDAS
-        elif pyarrow is not None and isinstance(data, pyarrow.Table):
-            data_format = BlockFormat.ARROW
-        elif isinstance(data, (dict, np.ndarray)):
-            data_format = "numpy"
-        else:
-            raise NotImplementedError(
+        if not isinstance(data, (pd.DataFrame, pyarrow.Table, dict, np.ndarray)):
+            raise ValueError(
                 "`transform_batch` is currently only implemented for Pandas "
                 "DataFrames, pyarrow Tables, NumPy ndarray and dictionary of "
                 f"ndarray. Got {type(data)}."
             )
 
-        transform_type = self._determine_transform_to_use(data_format)
+        transform_type = self._determine_transform_to_use()
 
         if transform_type == BatchFormat.PANDAS:
             return self._transform_pandas(convert_batch_type_to_pandas(data))
@@ -349,3 +315,16 @@ def _transform_numpy(
     ) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
         """Run the transformation on a data batch in a NumPy ndarray format."""
         raise NotImplementedError()
+
+    @classmethod
+    @DeveloperAPI
+    def preferred_batch_format(cls) -> BatchFormat:
+        """Batch format hint for upstream producers to try yielding best block format.
+
+        The preferred batch format to use if both `_transform_pandas` and
+        `_transform_numpy` are implemented. Defaults to Pandas.
+
+        Can be overridden by Preprocessor classes depending on which transform
+        path is the most optimal.
+        """
+        return BatchFormat.PANDAS
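When both transforms are implemented, the new preferred_batch_format hook breaks the tie. A minimal sketch of a custom Preprocessor using it (hypothetical class, written against the base-class API shown in this diff):

    import pandas as pd
    from ray.air.util.data_batch_conversion import BatchFormat
    from ray.data.preprocessor import Preprocessor

    class DoubleEverything(Preprocessor):
        _is_fittable = False

        def _transform_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
            return df * 2

        def _transform_numpy(self, np_data):
            return {k: v * 2 for k, v in np_data.items()}

        @classmethod
        def preferred_batch_format(cls) -> BatchFormat:
            # Both transforms exist, so this hint decides which one runs.
            return BatchFormat.NUMPY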
6 changes: 3 additions & 3 deletions python/ray/data/preprocessors/batch_mapper.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 
-from ray.air.util.data_batch_conversion import BatchFormat, BlockFormat
+from ray.air.util.data_batch_conversion import BatchFormat
 from ray.data.preprocessor import Preprocessor
 from ray.util.annotations import PublicAPI
 
@@ -106,11 +106,11 @@ def _transform_numpy(
     def _transform_pandas(self, df: "pandas.DataFrame") -> "pandas.DataFrame":
         return self.fn(df)
 
-    def _determine_transform_to_use(self, data_format: BlockFormat):
+    def _determine_transform_to_use(self):
         if self.batch_format:
             return self.batch_format
         else:
-            return super()._determine_transform_to_use(data_format)
+            return super()._determine_transform_to_use()
 
     def _get_transform_config(self) -> Dict[str, Any]:
         return {"batch_size": self.batch_size}
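BatchMapper bypasses the inference entirely when the user supplied an explicit batch_format. Usage sketch (constructor signature as of this era of the API; hedged):

    from ray.data.preprocessors import BatchMapper

    # The user-specified batch format wins over the inherited inference.
    mapper = BatchMapper(lambda df: df * 2, batch_format="pandas")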
6 changes: 3 additions & 3 deletions python/ray/data/preprocessors/chain.py
@@ -1,5 +1,5 @@
 from typing import TYPE_CHECKING, Union
-from ray.air.util.data_batch_conversion import BatchFormat, BlockFormat
+from ray.air.util.data_batch_conversion import BatchFormat
 from ray.data import Dataset, DatasetPipeline
 from ray.data.preprocessor import Preprocessor
 from ray.util.annotations import PublicAPI
@@ -99,9 +99,9 @@ def __repr__(self):
         arguments = ", ".join(repr(preprocessor) for preprocessor in self.preprocessors)
         return f"{self.__class__.__name__}({arguments})"
 
-    def _determine_transform_to_use(self, data_format: BlockFormat) -> BatchFormat:
+    def _determine_transform_to_use(self) -> BatchFormat:
         # This is relevant for BatchPrediction.
         # For Chain preprocessor, we picked the first one as entry point.
         # TODO (jiaodong): We should revisit if our Chain preprocessor is
         # still optimal with context of lazy execution.
-        return self.preprocessors[0]._determine_transform_to_use(data_format)
+        return self.preprocessors[0]._determine_transform_to_use()
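For a Chain, the first preprocessor's choice drives the whole pipeline: whatever batch format it selects is what the chained transform runs in. Sketch (the preprocessor instances are illustrative):

    from ray.data.preprocessors import Chain

    chain = Chain(mapper, scaler)   # hypothetical preprocessors
    # chain picks mapper's batch format for the entire chain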
4 changes: 4 additions & 0 deletions python/ray/data/preprocessors/torch.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 
+from ray.air.util.data_batch_conversion import BatchFormat
 from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
 from ray.data.preprocessor import Preprocessor
 from ray.util.annotations import PublicAPI
@@ -120,3 +121,6 @@ def transform_batch(batch: np.ndarray) -> np.ndarray:
         outputs = transform_batch(np_data)
 
         return outputs
+
+    def preferred_batch_format(cls) -> BatchFormat:
+        return BatchFormat.NUMPY
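NumPy is the natural preference here because torch tensors can wrap ndarrays without copying. A quick illustration of that zero-copy conversion (standard torch/NumPy behavior, not part of this diff):

    import numpy as np
    import torch

    arr = np.zeros((2, 3), dtype=np.float32)
    t = torch.from_numpy(arr)   # shares memory with arr; no copy
    t[0, 0] = 1.0
    assert arr[0, 0] == 1.0     # the ndarray observes the write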