From df898702fe1b939b55352213ed75646e0e8cd92e Mon Sep 17 00:00:00 2001
From: Kai Fricke
Date: Thu, 1 Jun 2023 18:17:29 +0100
Subject: [PATCH] [air] Pass on KMS-related kwargs for s3fs (#35938)

We currently only parse and pass a limited selection of options to s3fs.
One recent request was related to passing KMS settings. This PR extends
the s3 URI string to allow configuration of the signature version, SSE,
SSE-KMS key ID, and ACLs in s3 URIs when s3fs is used.

This PR also changes the fs caching logic, which is required so that
options are parsed again, e.g. when a key ID changes in subsequent calls.
FS cache keys now include the query string, and cache items go stale
after 5 minutes, after which they are re-created. As a side effect, this
should fix any problems that come with cached filesystems, e.g. expiring
credentials.

Signed-off-by: Kai Fricke
Signed-off-by: e428265
---
 python/ray/air/_internal/remote_storage.py  | 84 ++++++++++++++++-----
 python/ray/air/tests/test_remote_storage.py | 81 ++++++++++++++++++++
 2 files changed, 147 insertions(+), 18 deletions(-)

diff --git a/python/ray/air/_internal/remote_storage.py b/python/ray/air/_internal/remote_storage.py
index 13aaf91ab919c..1167076a297e8 100644
--- a/python/ray/air/_internal/remote_storage.py
+++ b/python/ray/air/_internal/remote_storage.py
@@ -2,6 +2,7 @@
 import os
 import pathlib
 import sys
+import time
 import urllib.parse
 from pathlib import Path
 from pkg_resources import packaging
@@ -29,6 +30,10 @@
 from ray import logger
 
 
+# Re-create fs objects after this many seconds
+_CACHE_VALIDITY_S = 300
+
+
 class _ExcludingLocalFilesystem(LocalFileSystem):
     """LocalFileSystem wrapper to exclude files according to patterns.
 
@@ -133,8 +138,26 @@ def is_non_local_path_uri(uri: str) -> bool:
     return False
 
 
-# Cache fs objects
-_cached_fs = {}
+# Cache fs objects. Map from cache_key --> (timestamp, fs)
+_cached_fs: Dict[tuple, Tuple[float, pyarrow.fs.FileSystem]] = {}
+
+
+def _get_cache(cache_key: tuple) -> Optional[pyarrow.fs.FileSystem]:
+    ts, fs = _cached_fs.get(cache_key, (0, None))
+    if not fs:
+        return None
+
+    now = time.monotonic()
+    if now - ts >= _CACHE_VALIDITY_S:
+        _cached_fs.pop(cache_key)
+        return None
+
+    return fs
+
+
+def _put_cache(cache_key: tuple, fs: pyarrow.fs.FileSystem):
+    now = time.monotonic()
+    _cached_fs[cache_key] = (now, fs)
 
 
 def _get_network_mounts() -> List[str]:
@@ -182,6 +205,18 @@ def _is_local_windows_path(path: str) -> bool:
     return False
 
 
+def _translate_options(
+    option_map: Dict[str, str], options: Dict[str, List[str]]
+) -> Dict[str, str]:
+    """Given a mapping of old_name -> new_name in option_map, rename keys."""
+    translated = {}
+    for opt, target in option_map.items():
+        if opt in options:
+            translated[target] = options[opt][0]
+
+    return translated
+
+
 def _translate_s3_options(options: Dict[str, List[str]]) -> Dict[str, Any]:
     """Translate pyarrow s3 query options into s3fs ``storage_kwargs``.
 
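The TTL cache introduced above is the core of the new invalidation behavior. As a standalone illustration (this block mirrors the patch's `_get_cache`/`_put_cache` logic but is not itself part of the diff), the pattern is:

import time
from typing import Any, Dict, Optional, Tuple

_CACHE_VALIDITY_S = 300  # same 5-minute validity window as in the patch

# Map from cache_key --> (timestamp, cached value)
_cached: Dict[tuple, Tuple[float, Any]] = {}


def _put_cache(cache_key: tuple, value: Any) -> None:
    # Remember when the value was cached, using a monotonic clock so that
    # wall-clock adjustments cannot stretch or shrink the validity window.
    _cached[cache_key] = (time.monotonic(), value)


def _get_cache(cache_key: tuple) -> Optional[Any]:
    ts, value = _cached.get(cache_key, (0.0, None))
    if value is None:
        return None
    if time.monotonic() - ts >= _CACHE_VALIDITY_S:
        # Stale entry: evict it so the caller re-creates the object,
        # re-parsing any (possibly changed) URI options.
        _cached.pop(cache_key)
        return None
    return value
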
@@ -199,22 +234,38 @@ def _translate_s3_options(options: Dict[str, List[str]]) -> Dict[str, Any]:
     """
 
     # Map from s3 query keys --> botocore client arguments
+    # client_kwargs
     option_map = {
         "endpoint_override": "endpoint_url",
         "region": "region_name",
         "access_key": "aws_access_key_id",
         "secret_key": "aws_secret_access_key",
     }
+    client_kwargs = _translate_options(option_map, options)
 
-    client_kwargs = {}
-    for opt, target in option_map.items():
-        if opt in options:
-            client_kwargs[target] = options[opt][0]
+    # config_kwargs
+    option_map = {
+        "signature_version": "signature_version",
+    }
+    config_kwargs = _translate_options(option_map, options)
+
+    # s3_additional_kwargs
+    option_map = {
+        "ServerSideEncryption": "ServerSideEncryption",
+        "SSEKMSKeyId": "SSEKMSKeyId",
+        "GrantFullControl": "GrantFullControl",
+    }
+    s3_additional_kwargs = _translate_options(option_map, options)
 
     # s3fs directory cache does not work correctly, so we pass
     # `use_listings_cache` to disable it. See https://github.com/fsspec/s3fs/issues/657
     # We should keep this for s3fs versions <= 2023.4.0.
-    return {"client_kwargs": client_kwargs, "use_listings_cache": False}
+    return {
+        "use_listings_cache": False,
+        "client_kwargs": client_kwargs,
+        "config_kwargs": config_kwargs,
+        "s3_additional_kwargs": s3_additional_kwargs,
+    }
 
 
 def _translate_gcs_options(options: Dict[str, List[str]]) -> Dict[str, Any]:
@@ -234,10 +285,7 @@ def _translate_gcs_options(options: Dict[str, List[str]]) -> Dict[str, Any]:
         "endpoint_override": "endpoint_url",
     }
 
-    storage_kwargs = {}
-    for opt, target in option_map.items():
-        if opt in options:
-            storage_kwargs[target] = options[opt][0]
+    storage_kwargs = _translate_options(option_map, options)
 
     return storage_kwargs
 
@@ -281,9 +329,9 @@ def _get_fsspec_fs_and_path(uri: str) -> Optional["pyarrow.fs.FileSystem"]:
     parsed = urllib.parse.urlparse(uri)
 
     storage_kwargs = {}
-    if parsed.scheme in ["s3", "s3a"] and parsed.query:
+    if parsed.scheme in ["s3", "s3a"]:
         storage_kwargs = _translate_s3_options(urllib.parse.parse_qs(parsed.query))
-    elif parsed.scheme in ["gs", "gcs"] and parsed.query:
+    elif parsed.scheme in ["gs", "gcs"]:
         if not _has_compatible_gcsfs_version():
             # If gcsfs is incompatible, fallback to pyarrow.fs.
             return None
@@ -329,17 +377,17 @@ def get_fs_and_path(
     else:
         path = parsed.netloc + parsed.path
 
-    cache_key = (parsed.scheme, parsed.netloc)
+    cache_key = (parsed.scheme, parsed.netloc, parsed.query)
 
-    if cache_key in _cached_fs:
-        fs = _cached_fs[cache_key]
+    fs = _get_cache(cache_key)
+    if fs:
         return fs, path
 
     # Prefer fsspec over native pyarrow.
     if fsspec:
         fs = _get_fsspec_fs_and_path(uri)
         if fs:
-            _cached_fs[cache_key] = fs
+            _put_cache(cache_key, fs)
             return fs, path
 
     # In case of hdfs filesystem, if uri does not have the netloc part below, it will
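With the expanded translation above, SSE/KMS settings travel in the URI itself and are re-applied whenever the filesystem object is (re-)created. A sketch of the intended call pattern, assuming s3fs is installed so the fsspec code path is taken (bucket name and key ID below are placeholders):

from ray.air._internal.remote_storage import get_fs_and_path

# Placeholder bucket and KMS key ID, for illustration only.
uri = (
    "s3://my-bucket/checkpoints"
    "?ServerSideEncryption=aws:kms"
    "&SSEKMSKeyId=1234abcd-12ab-34cd-56ef-1234567890ab"
)
fs, path = get_fs_and_path(uri)
# The query string is parsed with urllib.parse.parse_qs and translated into
# s3fs storage kwargs; for this URI the two SSE options land in
# "s3_additional_kwargs", while "client_kwargs" and "config_kwargs" stay empty.
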
@@ -355,7 +403,7 @@
     # If no fsspec filesystem was found, use pyarrow native filesystem.
     try:
         fs, path = pyarrow.fs.FileSystem.from_uri(uri)
-        _cached_fs[cache_key] = fs
+        _put_cache(cache_key, fs)
         return fs, path
     except (pyarrow.lib.ArrowInvalid, pyarrow.lib.ArrowNotImplementedError):
         # Raised when URI not recognized
diff --git a/python/ray/air/tests/test_remote_storage.py b/python/ray/air/tests/test_remote_storage.py
index 99adcd8939f8e..ea499fddd6e9f 100644
--- a/python/ray/air/tests/test_remote_storage.py
+++ b/python/ray/air/tests/test_remote_storage.py
@@ -5,15 +5,20 @@
 import pytest
 import shutil
 import tempfile
+import urllib.parse
 
 from ray.air._internal.remote_storage import (
     upload_to_uri,
     download_from_uri,
     get_fs_and_path,
     _is_network_mount,
+    _translate_s3_options,
+    _CACHE_VALIDITY_S,
 )
 from ray.tune.utils.file_transfer import _get_recursive_files_and_stats
 
+from freezegun import freeze_time
+
 
 @pytest.fixture
 def temp_data_dirs():
@@ -235,6 +240,82 @@ def test_is_network_mount(tmp_path, monkeypatch):
     assert not _is_network_mount("")  # cwd
 
 
+def test_resolve_aws_kwargs():
+    def _uri_to_opt(uri: str):
+        parsed = urllib.parse.urlparse(uri)
+        return urllib.parse.parse_qs(parsed.query)
+
+    # client_kwargs
+    assert (
+        _translate_s3_options(_uri_to_opt("s3://some/where?endpoint_override=EP"))[
+            "client_kwargs"
+        ]["endpoint_url"]
+        == "EP"
+    )
+
+    # config_kwargs
+    assert (
+        _translate_s3_options(_uri_to_opt("s3://some/where?signature_version=abc"))[
+            "config_kwargs"
+        ]["signature_version"]
+        == "abc"
+    )
+
+    # s3_additional_kwargs
+    assert (
+        _translate_s3_options(_uri_to_opt("s3://some/where?SSEKMSKeyId=abc"))[
+            "s3_additional_kwargs"
+        ]["SSEKMSKeyId"]
+        == "abc"
+    )
+
+    # no kwargs
+    assert (
+        _translate_s3_options(_uri_to_opt("s3://some/where"))["s3_additional_kwargs"]
+        == {}
+    )
+
+
+def test_cache_time_eviction():
+    """We use a time-based cache for filesystem objects.
+
+    This test asserts that the cache is evicted after _CACHE_VALIDITY_S
+    seconds.
+    """
+    with freeze_time() as frozen:
+        fs, path = get_fs_and_path("s3://some/where")
+        fs2, path = get_fs_and_path("s3://some/where")
+
+        assert id(fs) == id(fs2)
+
+        frozen.tick(_CACHE_VALIDITY_S - 10)
+
+        # Cache not expired yet
+        fs2, path = get_fs_and_path("s3://some/where")
+        assert id(fs) == id(fs2)
+
+        frozen.tick(10)
+
+        # Cache expired
+        fs2, path = get_fs_and_path("s3://some/where")
+        assert id(fs) != id(fs2)
+
+
+def test_cache_uri_query():
+    """We cache fs objects, but different query parameters should lead to
+    different cached objects."""
+    fs, path = get_fs_and_path("s3://some/where?only=we")
+    fs2, path = get_fs_and_path("s3://some/where?only=we")
+
+    # Same query parameters, so same object
+    assert id(fs) == id(fs2)
+
+    fs3, path = get_fs_and_path("s3://some/where?we=know")
+
+    # Different query parameters, so different object
+    assert id(fs) != id(fs3)
+
+
 if __name__ == "__main__":
     import sys
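For reviewers unfamiliar with s3fs: the three dictionaries returned by _translate_s3_options map onto s3fs.S3FileSystem constructor arguments. client_kwargs is forwarded to the botocore client, config_kwargs to botocore.client.Config, and s3_additional_kwargs is attached to individual S3 API calls (e.g. uploads). A hand-built equivalent of what the translated URI produces, with all values as placeholders:

import s3fs

# All values below are placeholders for illustration.
fs = s3fs.S3FileSystem(
    use_listings_cache=False,  # works around fsspec/s3fs#657, as above
    client_kwargs={"endpoint_url": "https://s3.us-west-2.amazonaws.com"},
    config_kwargs={"signature_version": "s3v4"},
    s3_additional_kwargs={
        "ServerSideEncryption": "aws:kms",
        "SSEKMSKeyId": "1234abcd-12ab-34cd-56ef-1234567890ab",
    },
)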