[Datasets] read_csv not filter out files by default (#29032)
Currently, read_csv filters out files without a .csv extension when reading. This behavior surprises users and has been reported as a bad user experience in 3+ user reports (#26605). We should change read_csv to NOT filter files by default.

Verified that Arrow (https://arrow.apache.org/docs/python/csv.html) and Spark (https://spark.apache.org/docs/latest/sql-data-sources-csv.html) do not filter out CSV files by default. I don't see a strong reason for Ray to behave differently.

Added documentation for users who want to use partition_filter to filter files, with an example that keeps only files with the .csv extension.
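
For example, a minimal sketch of opting back into extension-based filtering (the S3 paths here are hypothetical):

import ray
from ray.data.datasource import FileExtensionFilter

# New default: read every file under the directory, regardless of extension.
ds = ray.data.read_csv("s3://bucket/dir")

# Opt back into the old behavior: keep only paths with a .csv extension.
ds = ray.data.read_csv(
    "s3://bucket/dir",
    partition_filter=FileExtensionFilter("csv"),
)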

Also improved the error message raised when reading a CSV file fails.
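
With this change, reading a malformed or non-CSV file surfaces a ValueError rather than a raw pyarrow ArrowInvalid; a quick sketch (the Parquet path is hypothetical):

import ray

try:
    ray.data.read_csv("s3://bucket/dir/data.parquet")
except ValueError as e:
    # Prints "Failed to read CSV file: ... filter out non-CSV file with
    # 'partition_filter' field. See read_csv() documentation for more details."
    print(e)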
c21 authored Oct 7, 2022
1 parent 770bdf7 commit 92df1c1
Showing 3 changed files with 90 additions and 30 deletions.
36 changes: 22 additions & 14 deletions python/ray/data/datasource/csv_datasource.py
@@ -29,7 +29,7 @@ class CSVDatasource(FileBasedDatasource):
def _read_stream(
self, f: "pyarrow.NativeFile", path: str, **reader_args
) -> Iterator[Block]:
import pyarrow
import pyarrow as pa
from pyarrow import csv

read_options = reader_args.pop(
@@ -40,19 +40,27 @@ def _read_stream(
if hasattr(parse_options, "invalid_row_handler"):
parse_options.invalid_row_handler = parse_options.invalid_row_handler

reader = csv.open_csv(
f, read_options=read_options, parse_options=parse_options, **reader_args
)
schema = None
while True:
try:
batch = reader.read_next_batch()
table = pyarrow.Table.from_batches([batch], schema=schema)
if schema is None:
schema = table.schema
yield table
except StopIteration:
return
try:
reader = csv.open_csv(
f, read_options=read_options, parse_options=parse_options, **reader_args
)
schema = None
while True:
try:
batch = reader.read_next_batch()
table = pa.Table.from_batches([batch], schema=schema)
if schema is None:
schema = table.schema
yield table
except StopIteration:
return
except pa.lib.ArrowInvalid as e:
raise ValueError(
f"Failed to read CSV file: {path}. "
"Please check the CSV file has correct format, or filter out non-CSV "
"file with 'partition_filter' field. See read_csv() documentation for "
"more details."
) from e

def _write_block(
self,
24 changes: 15 additions & 9 deletions python/ray/data/read_api.py
@@ -577,9 +577,7 @@ def read_csv(
ray_remote_args: Dict[str, Any] = None,
arrow_open_stream_args: Optional[Dict[str, Any]] = None,
meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(),
partition_filter: Optional[
PathPartitionFilter
] = CSVDatasource.file_extension_filter(),
partition_filter: Optional[PathPartitionFilter] = None,
partitioning: Partitioning = Partitioning("hive"),
**arrow_csv_args,
) -> Dataset[ArrowRow]:
@@ -597,15 +595,13 @@
>>> ray.data.read_csv( # doctest: +SKIP
... ["s3://bucket/path1", "s3://bucket/path2"])
>>> # Read files that use a different delimiter. The partition_filter=None is needed here
>>> # because by default read_csv only reads .csv files. For more uses of ParseOptions see
>>> # Read files that use a different delimiter. For more uses of ParseOptions see
>>> # https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html # noqa: #501
>>> from pyarrow import csv
>>> parse_options = csv.ParseOptions(delimiter="\t")
>>> ray.data.read_csv( # doctest: +SKIP
... "example://iris.tsv",
... parse_options=parse_options,
... partition_filter=None)
... parse_options=parse_options)
>>> # Convert a date column with a custom format from a CSV file.
>>> # For more uses of ConvertOptions see
@@ -626,6 +622,15 @@
>>> ds.take(1) # doctest: +SKIP
[{'order_number': 10107, 'quantity': 30, 'year': '2022', 'month': '09'}]
By default, ``read_csv`` reads all files from file paths. If you want to filter
files by file extensions, set the ``partition_filter`` parameter.
>>> # Read only *.csv files from multiple directories.
>>> from ray.data.datasource import FileExtensionFilter
>>> ray.data.read_csv( # doctest: +SKIP
... ["s3://bucket/path1", "s3://bucket/path2"],
... partition_filter=FileExtensionFilter("csv"))
Args:
paths: A single file/directory path or a list of file/directory paths.
A list of paths can contain both files and directories.
@@ -639,8 +644,9 @@
be able to resolve file metadata more quickly and/or accurately.
partition_filter: Path-based partition filter, if any. Can be used
with a custom callback to read only selected partitions of a dataset.
By default, this filters out any file paths whose file extension does not
match "*.csv*".
By default, this does not filter out any files.
To keep only file paths whose extension matches e.g. "*.csv*",
pass a ``FileExtensionFilter("csv")``.
partitioning: A :class:`~ray.data.datasource.partitioning.Partitioning` object
that describes how paths are organized. By default, this function parses
`Hive-style partitions <https://athena.guide/articles/hive-style-partitioning/>`_.
60 changes: 53 additions & 7 deletions python/ray/data/tests/test_dataset_csv.py
@@ -23,7 +23,10 @@
PathPartitionEncoder,
PathPartitionFilter,
)
from ray.data.datasource.file_based_datasource import _unwrap_protocol
from ray.data.datasource.file_based_datasource import (
FileExtensionFilter,
_unwrap_protocol,
)


def df_to_csv(dataframe, path, **kwargs):
@@ -196,7 +199,12 @@ def test_csv_read(ray_start_regular_shared, fs, data_path, endpoint_url):
storage_options=storage_options,
)

ds = ray.data.read_csv(path, filesystem=fs, partitioning=None)
ds = ray.data.read_csv(
path,
filesystem=fs,
partition_filter=FileExtensionFilter("csv"),
partitioning=None,
)
assert ds.num_blocks() == 2
df = pd.concat([df1, df2], ignore_index=True)
dsdf = ds.to_pandas()
@@ -642,7 +650,7 @@ def test_csv_read_with_column_type_specified(shutdown_only, tmp_path):

# Incorrect to parse scientific notation in int64 as PyArrow represents
# it as double.
with pytest.raises(pa.lib.ArrowInvalid):
with pytest.raises(ValueError):
ray.data.read_csv(
file_path,
convert_options=csv.ConvertOptions(
@@ -661,15 +669,53 @@ def test_csv_read_with_column_type_specified(shutdown_only, tmp_path):
assert ds.to_pandas().equals(expected_df)


def test_csv_read_filter_no_file(shutdown_only, tmp_path):
def test_csv_read_filter_non_csv_file(shutdown_only, tmp_path):
df = pd.DataFrame({"one": [1, 2, 3], "two": ["a", "b", "c"]})

# CSV file with .csv extension.
path1 = os.path.join(tmp_path, "test2.csv")
df.to_csv(path1, index=False)

# CSV file without .csv extension.
path2 = os.path.join(tmp_path, "test3")
df.to_csv(path2, index=False)

# Directory of CSV files.
ds = ray.data.read_csv(tmp_path)
assert ds.to_pandas().equals(pd.concat([df, df], ignore_index=True))

# Non-CSV file in Parquet format.
table = pa.Table.from_pandas(df)
path = os.path.join(str(tmp_path), "test.parquet")
pq.write_table(table, path)
path3 = os.path.join(tmp_path, "test1.parquet")
pq.write_table(table, path3)

# Single non-CSV file.
error_message = "Failed to read CSV file"
with pytest.raises(ValueError, match=error_message):
ray.data.read_csv(path3)

# Single non-CSV file with filter.
error_message = "No input files found to read"
with pytest.raises(ValueError, match=error_message):
ray.data.read_csv(path)
ray.data.read_csv(path3, partition_filter=FileExtensionFilter("csv"))

# Single CSV file without extension.
ds = ray.data.read_csv(path2)
assert ds.to_pandas().equals(df)

# Single CSV file without extension with filter.
error_message = "No input files found to read"
with pytest.raises(ValueError, match=error_message):
ray.data.read_csv(path2, partition_filter=FileExtensionFilter("csv"))

# Directory of CSV and non-CSV files.
error_message = "Failed to read CSV file"
with pytest.raises(ValueError, match=error_message):
ray.data.read_csv(tmp_path)

# Directory of CSV and non-CSV files with filter.
ds = ray.data.read_csv(tmp_path, partition_filter=FileExtensionFilter("csv"))
assert ds.to_pandas().equals(df)


@pytest.mark.skipif(
