diff --git a/client/verta/tests/test_versioning/test_dataset.py b/client/verta/tests/test_versioning/test_dataset.py index ac6bfc3aac..0879b89969 100644 --- a/client/verta/tests/test_versioning/test_dataset.py +++ b/client/verta/tests/test_versioning/test_dataset.py @@ -201,6 +201,45 @@ def test_list_paths(self): dataset = verta.dataset.S3("s3://{}".format(bucket)) assert set(dataset.list_paths()) == expected_paths + def test_concat(self): + dataset1 = verta.dataset.S3("s3://verta-starter/") + dataset2 = verta.dataset.S3("s3://verta-versioned-bucket/") + components = dataset1.list_components() + dataset2.list_components() + components = list(sorted(components, key=lambda component: component.path)) + + dataset = dataset1 + dataset2 + assert dataset.list_components() == components + + # commutative + dataset = dataset2 + dataset1 + assert dataset.list_components() == components + + # assignment + dataset1 += dataset2 + assert dataset1.list_components() == components + + def test_concat_intersect_error(self): + dataset1 = verta.dataset.S3("s3://verta-starter/") + dataset2 = verta.dataset.S3("s3://verta-starter/census-test.csv") + + with pytest.raises(ValueError): + dataset1 + dataset2 # pylint: disable=pointless-statement + + # commutative + with pytest.raises(ValueError): + dataset2 + dataset1 # pylint: disable=pointless-statement + + # assignment + with pytest.raises(ValueError): + dataset1 += dataset2 + + def test_concat_type_mismatch_error(self): + dataset1 = verta.dataset.S3("s3://verta-starter/") + dataset2 = verta.dataset.Path("modelapi_hypothesis/") + + with pytest.raises(TypeError): + dataset1 + dataset2 # pylint: disable=pointless-statement + class TestPath: def test_dirpath(self): @@ -300,6 +339,66 @@ def test_invalid_base_path_error(self, paths, base_path): with pytest.raises(ValueError): verta.dataset.Path(paths, base_path) + def test_concat(self): + dataset1 = verta.dataset.Path("modelapi_hypothesis/") + dataset2 = verta.dataset.Path("test_versioning/") + components = dataset1.list_components() + dataset2.list_components() + components = list(sorted(components, key=lambda component: component.path)) + + dataset = dataset1 + dataset2 + assert dataset.list_components() == components + + # commutative + dataset = dataset2 + dataset1 + assert dataset.list_components() == components + + # assignment + dataset1 += dataset2 + assert dataset1.list_components() == components + + def test_concat_intersect_error(self): + dataset1 = verta.dataset.Path("test_versioning/") + dataset2 = verta.dataset.Path("test_versioning/test_dataset.py") + + with pytest.raises(ValueError): + dataset1 + dataset2 # pylint: disable=pointless-statement + + # commutative + with pytest.raises(ValueError): + dataset2 + dataset1 # pylint: disable=pointless-statement + + # assignment + with pytest.raises(ValueError): + dataset1 += dataset2 + + def test_concat_base_path(self): + dataset1 = verta.dataset.Path( + "modelapi_hypothesis/", + base_path="modelapi_hypothesis/", + ) + dataset2 = verta.dataset.Path( + "test_versioning/", + base_path="test_versioning/", + ) + components = dataset1.list_components() + dataset2.list_components() + components = list(sorted(components, key=lambda component: component.path)) + + dataset = dataset1 + dataset2 + assert dataset.list_components() == components + + def test_concat_base_path_intersect_error(self): + dataset1 = verta.dataset.Path( + "./__init__.py", + base_path=".", + ) + dataset2 = verta.dataset.Path( + "test_versioning/__init__.py", + base_path="test_versioning", + ) + + with pytest.raises(ValueError): + dataset1 + dataset2 # pylint: disable=pointless-statement + @pytest.mark.usefixtures("with_boto3", "in_tempdir") class TestS3ManagedVersioning: @@ -437,6 +536,57 @@ def test_download_all(self, commit): assert os.path.isdir(dirpath) assert_dirs_match(dirpath, reference_dir) + def test_concat(self, commit): + s3 = pytest.importorskip("boto3").client('s3') + + bucket1 = "verta-starter" + key1 = "models/model.pkl" + bucket2 = "verta-versioned-bucket" + key2 = "tiny-files/tiny2.bin" + + # create dir for reference files + reference_dir = "reference" + filepath1 = os.path.join(reference_dir, bucket1, key1) + pathlib2.Path(filepath1).parent.mkdir(parents=True, exist_ok=True) + filepath2 = os.path.join(reference_dir, bucket2, key2) + pathlib2.Path(filepath2).parent.mkdir(parents=True, exist_ok=True) + + # download files directly from S3 for reference + s3.download_file(bucket1, key1, filepath1) + s3.download_file(bucket2, key2, filepath2) + + # create and concatenate datasets + dataset1 = verta.dataset.S3( + "s3://{}/{}".format(bucket1, key1), + enable_mdb_versioning=True, + ) + dataset2 = verta.dataset.S3( + "s3://{}/{}".format(bucket2, key2), + enable_mdb_versioning=True, + ) + dataset = dataset1 + dataset2 + + blob_path = "data" + commit.update(blob_path, dataset) + commit.save("Version data.") + dataset = commit.get(blob_path) + + dirpath = dataset.download() + assert_dirs_match(dirpath, reference_dir) + + def test_concat_arg_mismatch_error(self): + dataset1 = verta.dataset.S3( + "s3://verta-starter/", + enable_mdb_versioning=True, + ) + dataset2 = verta.dataset.S3( + "s3://verta-versioned-bucket/", + enable_mdb_versioning=False, + ) + + with pytest.raises(ValueError): + dataset1 + dataset2 # pylint: disable=pointless-statement + @pytest.mark.usefixtures("in_tempdir") class TestPathManagedVersioning: @@ -650,3 +800,51 @@ def test_base_path(self, commit): dirpath = dataset.download() assert os.path.abspath(dirpath) != os.path.abspath(reference_dir) assert_dirs_match(dirpath, reference_dir) + + def test_concat(self, commit): + reference_dir = "tiny-files/" + os.mkdir(reference_dir) + # two .file files in tiny-files/ + for filename in ["tiny{}.file".format(i) for i in range(2)]: + with open(os.path.join(reference_dir, filename), 'wb') as f: + f.write(os.urandom(2**16)) + + # create and concatenate datasets + dataset1 = verta.dataset.Path( + "tiny-files/tiny0.file", + enable_mdb_versioning=True, + ) + dataset2 = verta.dataset.Path( + "tiny-files/tiny1.file", + enable_mdb_versioning=True, + ) + dataset = dataset1 + dataset2 + + blob_path = "data" + commit.update(blob_path, dataset) + commit.save("Version data.") + dataset = commit.get(blob_path) + + dirpath = dataset.download() + dirpath = os.path.join(dirpath, reference_dir) # "tiny-files/" nested in new dir + assert_dirs_match(dirpath, reference_dir) + + def test_concat_arg_mismatch_error(self): + reference_dir = "tiny-files/" + os.mkdir(reference_dir) + # two .file files in tiny-files/ + for filename in ["tiny{}.file".format(i) for i in range(2)]: + with open(os.path.join(reference_dir, filename), 'wb') as f: + f.write(os.urandom(2**16)) + + dataset1 = verta.dataset.Path( + "tiny-files/tiny0.file", + enable_mdb_versioning=True, + ) + dataset2 = verta.dataset.Path( + "tiny-files/tiny1.file", + enable_mdb_versioning=False, + ) + + with pytest.raises(ValueError): + dataset1 + dataset2 # pylint: disable=pointless-statement diff --git a/client/verta/verta/dataset/_dataset.py b/client/verta/verta/dataset/_dataset.py index 5cb804d948..8485de7eef 100644 --- a/client/verta/verta/dataset/_dataset.py +++ b/client/verta/verta/dataset/_dataset.py @@ -25,7 +25,7 @@ class _Dataset(blob.Blob): Base class for dataset versioning. Not for human consumption. """ - def __init__(self, enable_mdb_versioning=False): + def __init__(self, paths=None, enable_mdb_versioning=False): super(_Dataset, self).__init__() self._components_map = dict() # paths to Component objects @@ -46,6 +46,25 @@ def __repr__(self): return "\n ".join(lines) + def __add__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + + self_keys = set(self._components_map.keys()) + other_keys = set(other._components_map.keys()) + intersection = list(self_keys & other_keys) + if intersection: + raise ValueError("datasets contain overlapping paths: {}".format(intersection)) + + if self._mdb_versioned != other._mdb_versioned: + raise ValueError("datasets must have same value for `enable_mdb_versioning`") + + new = self.__class__(paths=[], enable_mdb_versioning=self._mdb_versioned) + new._components_map.update(self._components_map) + new._components_map.update(other._components_map) + + return new + @abc.abstractmethod def _prepare_components_to_upload(self): pass