Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify UCObjectStore.list_objects to lists all files recursively #2959

Merged
merged 9 commits into from
Feb 3, 2024
34 changes: 23 additions & 11 deletions composer/utils/object_store/uc_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,6 @@ def get_object_size(self, object_name: str) -> int:
def list_objects(self, prefix: Optional[str]) -> List[str]:
"""List all objects in the object store with the given prefix.

.. note::

This function removes the directories from the returned list.

Args:
prefix (str): The prefix to search for.

Expand All @@ -234,14 +230,30 @@ def list_objects(self, prefix: Optional[str]) -> List[str]:

from databricks.sdk.core import DatabricksError
try:
data = json.dumps({'path': self._get_object_path(prefix)})
# NOTE: This API is in preview and should not be directly used outside of this instance
resp = self.client.api_client.do(method='GET',
path=self._UC_VOLUME_LIST_API_ENDPOINT,
data=data,
headers={'Source': 'mosaicml/composer'})
assert isinstance(resp, dict)
return [f['path'] for f in resp.get('files', []) if not f['is_dir']]
logging.warn('UCObjectStore.list_objects is experimental.')
max_recursion_depth = 4

def get_uc_files(dir_path: str, recursion_depth: int = 0) -> list[str]:
if recursion_depth == max_recursion_depth:
raise Exception(
f'Objects at {dir_path} cannot be downloaded. Please reduce the' +
' level of folder nesting from {prefix} in UC Volumes to under {max_recursion_depth}.')
resp = self.client.api_client.do(method='GET',
path=self._UC_VOLUME_LIST_API_ENDPOINT,
data=json.dumps({'path': self._get_object_path(dir_path)}),
headers={'Source': 'mosaicml/composer'})
assert isinstance(resp, dict)
files = []
for f in resp.get('files', []):
fpath = f['path']
if f['is_dir']:
files.extend(get_uc_files(fpath, recursion_depth=recursion_depth + 1))
else:
files.append(fpath)
return files

return get_uc_files(prefix)
except DatabricksError as e:
_wrap_errors(self.get_uri(prefix), e)
return []
Loading