Skip to content

Commit

Permalink
Addressed feedback and added unit test for load_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
minhtuevo committed Jul 3, 2024
1 parent a4fa2b9 commit 87b446d
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
18 changes: 11 additions & 7 deletions fiftyone/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
logger = logging.getLogger(__name__)


class DatasetNotFoundError(ValueError):
    """Exception raised when a dataset is not found.

    Args:
        name: the name of the dataset that could not be found
    """

    def __init__(self, name):
        # Keep the offending name around so handlers can inspect it
        self._dataset_name = name
        super().__init__(f"Dataset {name} not found")


# Alias for the exception's previous name, kept so that existing
# ``except MissingDatasetError`` handlers continue to work after the rename
MissingDatasetError = DatasetNotFoundError
Expand Down Expand Up @@ -130,7 +130,7 @@ def _validate_dataset_name(name, skip=None):
return slug


def load_dataset(name, create_if_missing=False):
def load_dataset(name, create_if_necessary=False):
"""Loads the FiftyOne dataset with the given name.
To create a new dataset, use the :class:`Dataset` constructor.
Expand All @@ -143,17 +143,21 @@ def load_dataset(name, create_if_missing=False):
Args:
name: the name of the dataset
create_if_missing (False): if no dataset exists, create empty one
create_if_necessary (False): if no dataset exists, create an empty one
Raises:
DatasetNotFoundError: if the dataset does not exist and
`create_if_necessary` is False
Returns:
a :class:`Dataset`
"""
if dataset_exists(name):
return Dataset(name, _create=False)
elif create_if_missing:
elif create_if_necessary:
return Dataset(name)
else:
raise MissingDatasetError(name)
raise DatasetNotFoundError(name)


def get_default_dataset_name():
Expand Down Expand Up @@ -7604,7 +7608,7 @@ def _do_load_dataset(obj, name):
db = foo.get_db_conn()
res = db.datasets.find_one({"name": name})
if not res:
raise MissingDatasetError(name)
raise DatasetNotFoundError(name)
dataset_doc = foo.DatasetDocument.from_dict(res)

sample_collection_name = dataset_doc.sample_collection_name
Expand Down
27 changes: 27 additions & 0 deletions tests/unittests/dataset_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,33 @@ def test_dataset_names(self):
self.assertEqual(dataset.name, name)
self.assertEqual(dataset.slug, slug)

@drop_datasets
def test_load_dataset(self):
    """Checks that ``load_dataset`` can create a missing dataset on request,
    returns the same object when loaded again by name, and raises
    ``DatasetNotFoundError`` for a name that does not exist.
    """
    new_dataset_name = "new-dataset-name"

    # validate that the dataset does not exist
    self.assertNotIn(new_dataset_name, fo.list_datasets())

    # create the dataset by attempting to load it
    dataset = fo.load_dataset(new_dataset_name, create_if_necessary=True)
    self.assertEqual(dataset.name, new_dataset_name)

    # loading again by name must hand back the very same object
    dataset2 = fo.load_dataset(new_dataset_name)
    self.assertIs(dataset, dataset2)

    # validate that the new dataset is in the list of datasets
    names = fo.list_datasets()
    self.assertIn(new_dataset_name, names)

    # validate that the second dataset does not exist
    new_dataset_name_2 = "new-dataset-name-2"
    self.assertNotIn(new_dataset_name_2, names)

    # validate that the correct exception is raised
    with self.assertRaises(fo.core.dataset.DatasetNotFoundError):
        fo.load_dataset(new_dataset_name_2)

@drop_datasets
def test_delete_dataset(self):
IGNORED_DATASET_NAMES = fo.list_datasets()
Expand Down
14 changes: 11 additions & 3 deletions tests/unittests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,9 +459,11 @@ def test_multiple_config_cleanup(self):


class TestLoadDataset(unittest.TestCase):
@patch("fiftyone.core.dataset.dataset_exists")
@patch("fiftyone.core.odm.get_db_conn")
@patch("fiftyone.core.dataset.Dataset")
def test_load_dataset_by_id(self, mock_dataset, mock_get_db_conn):
def test_load_dataset_by_id(self, mock_dataset, mock_get_db_conn,
dataset_exists):
# Setup
identifier = ObjectId()
mock_db = MagicMock()
Expand All @@ -470,6 +472,7 @@ def test_load_dataset_by_id(self, mock_dataset, mock_get_db_conn):
"_id": ObjectId(identifier),
"name": "test_dataset",
}
dataset_exists.return_value = True

# Test
result = load_dataset(id=identifier)
Expand All @@ -482,9 +485,11 @@ def test_load_dataset_by_id(self, mock_dataset, mock_get_db_conn):

self.assertEqual(result, mock_dataset.return_value)

@patch("fiftyone.core.dataset.dataset_exists")
@patch("fiftyone.core.odm.get_db_conn")
@patch("fiftyone.core.dataset.Dataset")
def test_load_dataset_by_alt_id(self, mock_dataset, mock_get_db_conn):
def test_load_dataset_by_alt_id(self, mock_dataset, mock_get_db_conn,
dataset_exists):
# Setup
identifier = "alt_id"
mock_db = MagicMock()
Expand All @@ -493,6 +498,7 @@ def test_load_dataset_by_alt_id(self, mock_dataset, mock_get_db_conn):
"_id": "identifier",
"name": "dataset_name",
}
dataset_exists.return_value = True

# Test
result = load_dataset(id=identifier)
Expand All @@ -504,11 +510,13 @@ def test_load_dataset_by_alt_id(self, mock_dataset, mock_get_db_conn):
)
self.assertEqual(result, mock_dataset.return_value)

@patch("fiftyone.core.dataset.dataset_exists")
@patch("fiftyone.core.dataset.Dataset")
def test_load_dataset_by_name(self, mock_dataset):
def test_load_dataset_by_name(self, mock_dataset, dataset_exists):
# Setup
identifier = "test_dataset"
mock_dataset.return_value = {"_id": ObjectId(), "name": identifier}
dataset_exists.return_value = True

# Test
result = load_dataset(name=identifier)
Expand Down

0 comments on commit 87b446d

Please sign in to comment.