Skip to content

Commit

Permalink
[air] Move storage handling to pyarrow.fs.FileSystem (#23370)
Browse files Browse the repository at this point in the history
  • Loading branch information
krfricke authored Apr 13, 2022
1 parent 65d9a41 commit e3bd598
Show file tree
Hide file tree
Showing 22 changed files with 437 additions and 396 deletions.
2 changes: 1 addition & 1 deletion python/ray/data/impl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

logger = logging.getLogger(__name__)

MIN_PYARROW_VERSION = (4, 0, 1)
MIN_PYARROW_VERSION = (6, 0, 1)
_VERSION_VALIDATED = False


Expand Down
18 changes: 13 additions & 5 deletions python/ray/data/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@
def aws_credentials():
    """Expose fake AWS credentials via the environment for the test's duration.

    The credentials are written to the AWS environment variables and yielded
    as a dict that can be passed as keyword arguments to
    ``pa.fs.S3FileSystem``. When the consumer finishes (or fails), every
    variable touched here is reset to the value it had before, or removed
    if it was previously unset.

    Yields:
        dict: ``access_key``, ``secret_key`` and ``session_token`` entries.
    """
    import os

    # Credentials dict that can be passed as kwargs to pa.fs.S3FileSystem
    credentials = dict(
        access_key="testing", secret_key="testing", session_token="testing"
    )
    env_overrides = {
        "AWS_ACCESS_KEY_ID": credentials["access_key"],
        "AWS_SECRET_ACCESS_KEY": credentials["secret_key"],
        "AWS_SECURITY_TOKEN": "testing",
        "AWS_SESSION_TOKEN": credentials["session_token"],
    }

    # Snapshot the previous values. `old_env = os.environ` would only alias
    # the live mapping, so assigning it back later would restore nothing.
    previous = {name: os.environ.get(name) for name in env_overrides}
    os.environ.update(env_overrides)
    try:
        yield credentials
    finally:
        # Restore even if the consumer raised, so the fake credentials do
        # not leak into later tests.
        for name, value in previous.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value


Expand Down Expand Up @@ -56,7 +62,9 @@ def s3_fs_with_space(aws_credentials, s3_server, s3_path_with_space):
def _s3_fs(aws_credentials, s3_server, s3_path):
import urllib.parse

fs = pa.fs.S3FileSystem(region="us-west-2", endpoint_override=s3_server)
fs = pa.fs.S3FileSystem(
region="us-west-2", endpoint_override=s3_server, **aws_credentials
)
if s3_path.startswith("s3://"):
s3_path = s3_path[len("s3://") :]
s3_path = urllib.parse.quote(s3_path)
Expand Down
2 changes: 1 addition & 1 deletion python/ray/data/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,7 +990,7 @@ def test_convert_types(ray_start_regular_shared):

arrow_ds = ray.data.range_arrow(1)
assert arrow_ds.map(lambda x: "plain_{}".format(x["value"])).take() == ["plain_0"]
assert arrow_ds.map(lambda x: {"a": (x["value"],)}).take() == [{"a": (0,)}]
assert arrow_ds.map(lambda x: {"a": (x["value"],)}).take() == [{"a": [0]}]


def test_from_items(ray_start_regular_shared):
Expand Down
32 changes: 18 additions & 14 deletions python/ray/ml/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import io
import os
import shutil
import tarfile
import tempfile

import os
from typing import Optional, Union, Tuple

import ray
from ray import cloudpickle as pickle
from ray.util.annotations import DeveloperAPI
from ray.util.ml_utils.cloud import (
upload_to_bucket,
is_cloud_target,
download_from_bucket,
from ray.ml.utils.remote_storage import (
upload_to_uri,
is_non_local_path_uri,
download_from_uri,
fs_hint,
)

from ray.util.annotations import DeveloperAPI

_DICT_CHECKPOINT_FILE_NAME = "dict_checkpoint.pkl"
_FS_CHECKPOINT_KEY = "fs_checkpoint"
Expand Down Expand Up @@ -331,7 +330,7 @@ def to_directory(self, path: Optional[str] = None) -> str:
shutil.copytree(local_path, path)
elif external_path:
# If this exists on external storage (e.g. cloud), download
download_from_bucket(bucket=external_path, local_path=path)
download_from_uri(uri=external_path, local_path=path)
else:
raise RuntimeError(
f"No valid location found for checkpoint {self}: {self._uri}"
Expand All @@ -358,7 +357,7 @@ def from_uri(cls, uri: str) -> "Checkpoint":
def to_uri(self, uri: str) -> str:
"""Write checkpoint data to location URI (e.g. cloud storage).
ARgs:
Args:
uri (str): Target location URI to write data to.
Returns:
Expand All @@ -368,7 +367,12 @@ def to_uri(self, uri: str) -> str:
local_path = uri[7:]
return self.to_directory(local_path)

assert is_cloud_target(uri)
if not is_non_local_path_uri(uri):
raise RuntimeError(
f"Cannot upload checkpoint to URI: Provided URI "
f"does not belong to a registered storage provider: `{uri}`. "
f"Hint: {fs_hint(uri)}"
)

cleanup = False

Expand All @@ -377,7 +381,7 @@ def to_uri(self, uri: str) -> str:
cleanup = True
local_path = self.to_directory()

upload_to_bucket(bucket=uri, local_path=local_path)
upload_to_uri(local_path=local_path, uri=uri)

if cleanup:
shutil.rmtree(local_path)
Expand Down Expand Up @@ -429,7 +433,7 @@ def __setstate__(self, state):

def _get_local_path(path: Optional[str]) -> Optional[str]:
"""Check if path is a local path. Otherwise return None."""
if path is None or is_cloud_target(path):
if path is None or is_non_local_path_uri(path):
return None
if path.startswith("file://"):
path = path[7:]
Expand All @@ -440,7 +444,7 @@ def _get_local_path(path: Optional[str]) -> Optional[str]:

def _get_external_path(path: Optional[str]) -> Optional[str]:
"""Check if path is an external path. Otherwise return None."""
if not isinstance(path, str) or not is_cloud_target(path):
if not isinstance(path, str) or not is_non_local_path_uri(path):
return None
return path

Expand Down
102 changes: 78 additions & 24 deletions python/ray/ml/tests/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,40 @@
import tempfile
import unittest
from typing import Any
from unittest.mock import patch

import ray
from ray.ml.checkpoint import Checkpoint
from ray.ml.tests.utils import mock_s3_sync
from ray.ml.utils.remote_storage import delete_at_uri, _ensure_directory


class CheckpointsConversionTest(unittest.TestCase):
def setUp(self):
self.tmpdir = os.path.realpath(tempfile.mkdtemp())
self.tmpdir_pa = os.path.realpath(tempfile.mkdtemp())

self.checkpoint_dict_data = {"metric": 5, "step": 4}
self.checkpoint_dir_data = {"metric": 2, "step": 6}

self.cloud_uri = "s3://invalid"
self.local_mock_cloud_path = os.path.realpath(tempfile.mkdtemp())
self.mock_s3 = mock_s3_sync(self.local_mock_cloud_path)
# We test two different in-memory filesystems as "cloud" providers,
# one for fsspec and one for pyarrow.fs

# fsspec URI
self.cloud_uri = "memory:///cloud/bucket"
# pyarrow URI
self.cloud_uri_pa = "mock://cloud/bucket/"

self.checkpoint_dir = os.path.join(self.tmpdir, "existing_checkpoint")
os.mkdir(self.checkpoint_dir, 0o755)
with open(os.path.join(self.checkpoint_dir, "test_data.pkl"), "wb") as fp:
pickle.dump(self.checkpoint_dir_data, fp)

self.old_dir = os.getcwd()
os.chdir(self.tmpdir)

def tearDown(self):
os.chdir(self.old_dir)
shutil.rmtree(self.tmpdir)
shutil.rmtree(self.local_mock_cloud_path)
shutil.rmtree(self.tmpdir_pa)

def _prepare_dict_checkpoint(self) -> Checkpoint:
# Create checkpoint from dict
Expand Down Expand Up @@ -111,17 +117,35 @@ def test_dict_checkpoint_uri(self):
"""Test conversion from dict to cloud checkpoint and back."""
checkpoint = self._prepare_dict_checkpoint()

with patch("subprocess.check_call", self.mock_s3):
# Convert into dict checkpoint
location = checkpoint.to_uri(self.cloud_uri)
self.assertIsInstance(location, str)
self.assertIn("s3://", location)
# Upload the checkpoint to the (fsspec) cloud URI
location = checkpoint.to_uri(self.cloud_uri)
self.assertIsInstance(location, str)
self.assertIn("memory://", location)

# Create from URI
checkpoint = Checkpoint.from_uri(location)
self.assertTrue(checkpoint._uri)

self._assert_dict_checkpoint(checkpoint)

def test_dict_checkpoint_uri_pa(self):
"""Test conversion from dict to cloud checkpoint and back."""
checkpoint = self._prepare_dict_checkpoint()

# Clean up mock bucket
delete_at_uri(self.cloud_uri_pa)
_ensure_directory(self.cloud_uri_pa)

# Upload the checkpoint to the (pyarrow) cloud URI
location = checkpoint.to_uri(self.cloud_uri_pa)
self.assertIsInstance(location, str)
self.assertIn("mock://", location)

# Create from dict
checkpoint = Checkpoint.from_uri(location)
self.assertTrue(checkpoint._uri)
# Create from URI
checkpoint = Checkpoint.from_uri(location)
self.assertTrue(checkpoint._uri)

self._assert_dict_checkpoint(checkpoint)
self._assert_dict_checkpoint(checkpoint)

def _prepare_fs_checkpoint(self) -> Checkpoint:
# Create checkpoint from fs
Expand Down Expand Up @@ -204,17 +228,47 @@ def test_fs_checkpoint_uri(self):
"""Test conversion from fs to cloud checkpoint and back."""
checkpoint = self._prepare_fs_checkpoint()

with patch("subprocess.check_call", self.mock_s3):
# Convert into dict checkpoint
location = checkpoint.to_uri(self.cloud_uri)
self.assertIsInstance(location, str)
self.assertIn("s3://", location)
# Upload the checkpoint to the (fsspec) cloud URI
location = checkpoint.to_uri(self.cloud_uri)
self.assertIsInstance(location, str)
self.assertIn("memory://", location)

# Create from URI
checkpoint = Checkpoint.from_uri(location)
self.assertTrue(checkpoint._uri)

self._assert_fs_checkpoint(checkpoint)

def test_fs_checkpoint_uri_pa(self):
"""Test conversion from fs to cloud checkpoint and back."""
checkpoint = self._prepare_fs_checkpoint()

# Clean up mock bucket
delete_at_uri(self.cloud_uri_pa)
_ensure_directory(self.cloud_uri_pa)

# Upload the checkpoint to the (pyarrow) cloud URI
location = checkpoint.to_uri(self.cloud_uri_pa)
self.assertIsInstance(location, str)
self.assertIn("mock://", location)

# Create from URI
checkpoint = Checkpoint.from_uri(location)
self.assertTrue(checkpoint._uri)

self._assert_fs_checkpoint(checkpoint)

def test_fs_delete_at_uri(self):
"""Test that clear bucket utility works"""
checkpoint = self._prepare_fs_checkpoint()

# Create from dict
checkpoint = Checkpoint.from_uri(location)
self.assertTrue(checkpoint._uri)
# Upload the checkpoint to the cloud URI, then delete it again
location = checkpoint.to_uri(self.cloud_uri)
delete_at_uri(location)

self._assert_fs_checkpoint(checkpoint)
checkpoint = Checkpoint.from_uri(location)
with self.assertRaises(FileNotFoundError):
checkpoint.to_directory()


class CheckpointsSerdeTest(unittest.TestCase):
Expand Down
30 changes: 0 additions & 30 deletions python/ray/ml/tests/utils.py

This file was deleted.

Loading

0 comments on commit e3bd598

Please sign in to comment.