diff --git a/eta/core/storage.py b/eta/core/storage.py index 02d77005..3ec961b5 100644 --- a/eta/core/storage.py +++ b/eta/core/storage.py @@ -274,6 +274,17 @@ def list_files_in_folder(self, remote_dir, recursive=True): "subclass must implement list_files_in_folder()" ) + def list_subfolders(self, remote_dir): + """Returns a list of the subfolders in the given remote directory. + + Args: + remote_dir: the remote directory + + Returns: + a list of full paths to the subfolders in the folder + """ + raise NotImplementedError("subclass must implement list_subfolders()") + def upload_dir( self, local_dir, remote_dir, recursive=True, skip_failures=False ): @@ -639,6 +650,17 @@ def list_files_in_folder(self, storage_dir, recursive=True): storage_dir, abs_paths=True, recursive=recursive ) + def list_subfolders(self, storage_dir): + """Returns a list of the subfolders in the given storage directory. + + Args: + storage_dir: the storage directory + + Returns: + a list of full paths to the subfolders in the folder + """ + return etau.list_subdirs(storage_dir, abs_paths=True) + class _BotoCredentialsError(Exception): def __init__(self, message): @@ -990,6 +1012,32 @@ def list_files_in_folder( return paths_or_metadata + def list_subfolders(self, cloud_folder): + """Returns a list of sub "folders" in the given cloud "folder". 
+ + Args: + cloud_folder: a cloud "folder" path + + Returns: + a list of full cloud paths for the subfolders in the folder + """ + bucket, folder_name = self._parse_path(cloud_folder) + if folder_name and not folder_name.endswith("/"): + folder_name += "/" + + prefix = self._get_prefix(cloud_folder) + bucket + "/" + paginator = self._client.get_paginator("list_objects_v2") + + # https://stackoverflow.com/q/14653694 + paths = set() + for page in paginator.paginate( + Bucket=bucket, Prefix=folder_name, Delimiter="/" + ).search("CommonPrefixes"): + if page is not None: + paths.add(page["Prefix"]) + + return [prefix + p for p in paths] + def generate_signed_url( self, cloud_path, method="GET", hours=24, content_type=None ): @@ -2130,13 +2178,41 @@ def list_files_in_folder( # Return paths for each file paths = [] - prefix = "gs://" + bucket_name + prefix = "gs://" + bucket_name + "/" for blob in blobs: if not blob.name.endswith("/"): - paths.append(prefix + "/" + blob.name) + paths.append(prefix + blob.name) return paths + def list_subfolders(self, cloud_folder): + """Returns a list of sub "folders" in the given "folder" in GCS. 
+ + Args: + cloud_folder: a string like `gs://<bucket-name>/<folder-path>` + + Returns: + a list of full cloud paths for the subfolders in the folder + """ + bucket_name, folder_name = self._parse_path(cloud_folder) + if folder_name and not folder_name.endswith("/"): + folder_name += "/" + + prefix = "gs://" + bucket_name + "/" + blobs = self._client.list_blobs( + bucket_name, + prefix=folder_name, + delimiter="/", + retry=self._retry, + ) + + # https://github.com/googleapis/google-cloud-python/issues/920 + paths = set() + for page in blobs.pages: + paths.update(page.prefixes) + + return [prefix + p for p in paths] + def generate_signed_url( self, cloud_path, method="GET", hours=24, content_type=None ): @@ -2912,6 +2988,34 @@ def list_files_in_folder( return paths + def list_subfolders(self, cloud_folder): + """Returns a list of sub "folders" in the given "folder" in Azure + Storage. + + Args: + cloud_folder: a string like + `https://<account>.blob.core.windows.net/<container>/<folder-path>` + + Returns: + a list of full cloud paths for the subfolders in the folder + """ + container_name, folder_name = self._parse_path(cloud_folder) + if folder_name and not folder_name.endswith("/"): + folder_name += "/" + + prefix = self._get_prefix(cloud_folder) + container_name + "/" + blobs = self._list_blobs( + container_name, prefix=folder_name, recursive=False + ) + + # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-list-python#use-a-hierarchical-listing + paths = set() + for blob in blobs: + if blob.name.endswith("/"): + paths.add(blob.name) + + return [prefix + p for p in paths] + def generate_signed_url( self, cloud_path, method="GET", hours=24, content_type=None ): @@ -2993,15 +3097,14 @@ def _get_blob_client(self, cloud_path): return self._client.get_blob_client(container_name, blob_name) def _parse_path(self, cloud_path): - try: - client = azb.BlobClient.from_blob_url(cloud_path) - return client.container_name, client.blob_name - except ValueError as e: - try: - client = 
azb.ContainerClient.from_container_url(cloud_path) - return client.container_name, "" - except: - raise e + _cloud_path = self._strip_prefix(cloud_path) + + chunks = _cloud_path.split("/", 1) + + if len(chunks) != 2: + return chunks[0], "" + + return chunks[0], chunks[1] def _get_prefix(self, cloud_path): return _get_prefix(cloud_path, self._prefixes) diff --git a/eta/core/utils.py b/eta/core/utils.py index e7868fa5..12f06583 100644 --- a/eta/core/utils.py +++ b/eta/core/utils.py @@ -3844,7 +3844,7 @@ def get_glob_matches(glob_patt): Returns: a list of file paths that match `glob_patt` """ - return sorted(glob.glob(glob_patt)) + return sorted(p for p in glob.glob(glob_patt) if not os.path.isdir(p)) def parse_glob_pattern(glob_patt): diff --git a/setup.py b/setup.py index 2272595f..2e4d42f4 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ from wheel.bdist_wheel import bdist_wheel -VERSION = "0.11.0" +VERSION = "0.12.0" class BdistWheelCustom(bdist_wheel):