Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VR-4768: Implement __add__ for dataset blobs #938

Merged
merged 4 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions client/verta/tests/test_versioning/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,45 @@ def test_list_paths(self):
dataset = verta.dataset.S3("s3://{}".format(bucket))
assert set(dataset.list_paths()) == expected_paths

def test_concat(self):
    """S3 blobs concatenate with `+` and `+=`, in either operand order."""
    left = verta.dataset.S3("s3://verta-starter/")
    right = verta.dataset.S3("s3://verta-versioned-bucket/")
    expected = sorted(
        left.list_components() + right.list_components(),
        key=lambda component: component.path,
    )

    assert (left + right).list_components() == expected

    # addition is commutative
    assert (right + left).list_components() == expected

    # in-place addition behaves the same
    left += right
    assert left.list_components() == expected

def test_concat_intersect_error(self):
    """Concatenating S3 blobs with overlapping paths raises ValueError."""
    whole_bucket = verta.dataset.S3("s3://verta-starter/")
    single_file = verta.dataset.S3("s3://verta-starter/census-test.csv")

    # the overlap is rejected in both orders and for in-place addition
    with pytest.raises(ValueError):
        whole_bucket + single_file  # pylint: disable=pointless-statement

    with pytest.raises(ValueError):
        single_file + whole_bucket  # pylint: disable=pointless-statement

    with pytest.raises(ValueError):
        whole_bucket += single_file

def test_concat_type_mismatch_error(self):
    """Adding dataset blobs of different types raises TypeError."""
    s3_blob = verta.dataset.S3("s3://verta-starter/")
    path_blob = verta.dataset.Path("modelapi_hypothesis/")

    with pytest.raises(TypeError):
        s3_blob + path_blob  # pylint: disable=pointless-statement


class TestPath:
def test_dirpath(self):
Expand Down Expand Up @@ -300,6 +339,66 @@ def test_invalid_base_path_error(self, paths, base_path):
with pytest.raises(ValueError):
verta.dataset.Path(paths, base_path)

def test_concat(self):
    """Path blobs concatenate with `+` and `+=`, in either operand order."""
    left = verta.dataset.Path("modelapi_hypothesis/")
    right = verta.dataset.Path("test_versioning/")
    expected = sorted(
        left.list_components() + right.list_components(),
        key=lambda component: component.path,
    )

    assert (left + right).list_components() == expected

    # addition is commutative
    assert (right + left).list_components() == expected

    # in-place addition behaves the same
    left += right
    assert left.list_components() == expected

def test_concat_intersect_error(self):
    """Concatenating Path blobs with overlapping paths raises ValueError."""
    whole_dir = verta.dataset.Path("test_versioning/")
    single_file = verta.dataset.Path("test_versioning/test_dataset.py")

    # the overlap is rejected in both orders and for in-place addition
    with pytest.raises(ValueError):
        whole_dir + single_file  # pylint: disable=pointless-statement

    with pytest.raises(ValueError):
        single_file + whole_dir  # pylint: disable=pointless-statement

    with pytest.raises(ValueError):
        whole_dir += single_file

def test_concat_base_path(self):
    """Path blobs created with `base_path` can still be concatenated."""
    left = verta.dataset.Path(
        "modelapi_hypothesis/",
        base_path="modelapi_hypothesis/",
    )
    right = verta.dataset.Path(
        "test_versioning/",
        base_path="test_versioning/",
    )
    expected = sorted(
        left.list_components() + right.list_components(),
        key=lambda component: component.path,
    )

    assert (left + right).list_components() == expected

def test_concat_base_path_intersect_error(self):
    """Distinct files that collide after `base_path` stripping are rejected."""
    # both blobs resolve to the component path "__init__.py"
    top_level = verta.dataset.Path(
        "./__init__.py",
        base_path=".",
    )
    nested = verta.dataset.Path(
        "test_versioning/__init__.py",
        base_path="test_versioning",
    )

    with pytest.raises(ValueError):
        top_level + nested  # pylint: disable=pointless-statement


@pytest.mark.usefixtures("with_boto3", "in_tempdir")
class TestS3ManagedVersioning:
Expand Down Expand Up @@ -437,6 +536,57 @@ def test_download_all(self, commit):
assert os.path.isdir(dirpath)
assert_dirs_match(dirpath, reference_dir)

def test_concat(self, commit):
    """A concatenated pair of MDB-versioned S3 blobs round-trips via a commit."""
    s3 = pytest.importorskip("boto3").client('s3')

    sources = [
        ("verta-starter", "models/model.pkl"),
        ("verta-versioned-bucket", "tiny-files/tiny2.bin"),
    ]

    # download each file directly from S3 as the reference copy
    reference_dir = "reference"
    for bucket, key in sources:
        local_path = os.path.join(reference_dir, bucket, key)
        pathlib2.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
        s3.download_file(bucket, key, local_path)

    # build one blob per file, then concatenate them
    blobs = [
        verta.dataset.S3(
            "s3://{}/{}".format(bucket, key),
            enable_mdb_versioning=True,
        )
        for bucket, key in sources
    ]
    dataset = blobs[0] + blobs[1]

    blob_path = "data"
    commit.update(blob_path, dataset)
    commit.save("Version data.")
    dataset = commit.get(blob_path)

    assert_dirs_match(dataset.download(), reference_dir)

def test_concat_arg_mismatch_error(self):
    """S3 blobs that differ in `enable_mdb_versioning` cannot be concatenated."""
    versioned = verta.dataset.S3(
        "s3://verta-starter/",
        enable_mdb_versioning=True,
    )
    unversioned = verta.dataset.S3(
        "s3://verta-versioned-bucket/",
        enable_mdb_versioning=False,
    )

    with pytest.raises(ValueError):
        versioned + unversioned  # pylint: disable=pointless-statement


@pytest.mark.usefixtures("in_tempdir")
class TestPathManagedVersioning:
Expand Down Expand Up @@ -650,3 +800,51 @@ def test_base_path(self, commit):
dirpath = dataset.download()
assert os.path.abspath(dirpath) != os.path.abspath(reference_dir)
assert_dirs_match(dirpath, reference_dir)

def test_concat(self, commit):
    """A concatenated pair of MDB-versioned Path blobs round-trips via a commit."""
    reference_dir = "tiny-files/"
    os.mkdir(reference_dir)
    # populate tiny-files/ with two small random files
    for i in range(2):
        filename = "tiny{}.file".format(i)
        with open(os.path.join(reference_dir, filename), 'wb') as f:
            f.write(os.urandom(2**16))

    # build one blob per file, then concatenate them
    blobs = [
        verta.dataset.Path(
            "tiny-files/tiny{}.file".format(i),
            enable_mdb_versioning=True,
        )
        for i in range(2)
    ]
    dataset = blobs[0] + blobs[1]

    blob_path = "data"
    commit.update(blob_path, dataset)
    commit.save("Version data.")
    dataset = commit.get(blob_path)

    # the download recreates "tiny-files/" nested inside a new directory
    download_dir = dataset.download()
    assert_dirs_match(os.path.join(download_dir, reference_dir), reference_dir)

def test_concat_arg_mismatch_error(self):
    """Path blobs that differ in `enable_mdb_versioning` cannot be concatenated."""
    reference_dir = "tiny-files/"
    os.mkdir(reference_dir)
    # the blobs below need real files on disk to resolve
    for i in range(2):
        filename = "tiny{}.file".format(i)
        with open(os.path.join(reference_dir, filename), 'wb') as f:
            f.write(os.urandom(2**16))

    versioned = verta.dataset.Path(
        "tiny-files/tiny0.file",
        enable_mdb_versioning=True,
    )
    unversioned = verta.dataset.Path(
        "tiny-files/tiny1.file",
        enable_mdb_versioning=False,
    )

    with pytest.raises(ValueError):
        versioned + unversioned  # pylint: disable=pointless-statement
21 changes: 20 additions & 1 deletion client/verta/verta/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class _Dataset(blob.Blob):
Base class for dataset versioning. Not for human consumption.

"""
def __init__(self, enable_mdb_versioning=False):
def __init__(self, paths=None, enable_mdb_versioning=False):
super(_Dataset, self).__init__()

self._components_map = dict() # paths to Component objects
Expand All @@ -46,6 +46,25 @@ def __repr__(self):

return "\n ".join(lines)

def __add__(self, other):
    """Return a new dataset blob combining the components of both operands.

    Parameters
    ----------
    other : same type as `self`
        Dataset whose components will be merged with this dataset's.

    Returns
    -------
    Same type as `self`
        New dataset containing the union of both datasets' components.

    Raises
    ------
    ValueError
        If the datasets contain overlapping component paths, or if they
        differ in `enable_mdb_versioning`.

    """
    if not isinstance(other, type(self)):
        return NotImplemented

    # sort the overlapping paths so the error message is deterministic
    # (bare set iteration order varies across runs under hash randomization)
    intersection = sorted(
        set(self._components_map.keys()) & set(other._components_map.keys())
    )
    if intersection:
        raise ValueError("datasets contain overlapping paths: {}".format(intersection))

    if self._mdb_versioned != other._mdb_versioned:
        raise ValueError("datasets must have same value for `enable_mdb_versioning`")

    # build a fresh, empty dataset of the same type, then merge both maps
    new = self.__class__(paths=[], enable_mdb_versioning=self._mdb_versioned)
    new._components_map.update(self._components_map)
    new._components_map.update(other._components_map)

    return new

@abc.abstractmethod
def _prepare_components_to_upload(self):
    """Stage this dataset's components for upload; must be implemented by subclasses.

    NOTE(review): concrete behavior is defined in subclasses not visible
    in this chunk — confirm the expected contract there.

    """
    pass
Expand Down