From 24c4cb71ba20df27e02e28b9f0e58cdd4e03cb8f Mon Sep 17 00:00:00 2001 From: Nick Wall <46641379+walln@users.noreply.github.com> Date: Wed, 2 Oct 2024 14:31:46 -0500 Subject: [PATCH] chore: simplify imports --- src/loadax/__init__.py | 1 + tests/dataset/test_batch_mapped.py | 2 +- tests/dataset/test_combined.py | 3 +-- tests/dataset/test_huggingface_dataset.py | 2 +- tests/dataset/test_mapped.py | 2 +- tests/dataset/test_partial.py | 3 +-- tests/dataset/test_sampled.py | 3 +-- tests/dataset/test_sharded_dataset.py | 3 +-- tests/dataset/test_simple_dataset.py | 2 +- tests/test_dataloader.py | 3 +-- 10 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/loadax/__init__.py b/src/loadax/__init__.py index 4c6f040..b319471 100644 --- a/src/loadax/__init__.py +++ b/src/loadax/__init__.py @@ -12,6 +12,7 @@ from loadax.dataset.dataset import Dataset as Dataset from loadax.dataset.huggingface import HuggingFaceDataset as HuggingFaceDataset from loadax.dataset.partial_dataset import PartialDataset as PartialDataset +from loadax.dataset.sampled_dataset import SampledDataset as SampledDataset from loadax.dataset.sharded_dataset import ShardedDataset as ShardedDataset from loadax.dataset.shuffled_dataset import Shuffleable as Shuffleable from loadax.dataset.simple import SimpleDataset as SimpleDataset diff --git a/tests/dataset/test_batch_mapped.py b/tests/dataset/test_batch_mapped.py index 5c788a2..c4720e8 100644 --- a/tests/dataset/test_batch_mapped.py +++ b/tests/dataset/test_batch_mapped.py @@ -1,8 +1,8 @@ import jax.numpy as jnp import pytest +from loadax import SimpleDataset from loadax.dataset.dataset import MappedBatchDataset -from loadax.dataset.simple import SimpleDataset @pytest.fixture diff --git a/tests/dataset/test_combined.py b/tests/dataset/test_combined.py index 0e4e5d6..a9f994b 100644 --- a/tests/dataset/test_combined.py +++ b/tests/dataset/test_combined.py @@ -1,7 +1,6 @@ import pytest -from loadax.dataset.combined_dataset import CombinedDataset -from loadax.dataset.simple import SimpleDataset +from loadax import CombinedDataset, SimpleDataset @pytest.fixture diff --git a/tests/dataset/test_huggingface_dataset.py b/tests/dataset/test_huggingface_dataset.py index 135520e..3c2ef5e 100644 --- a/tests/dataset/test_huggingface_dataset.py +++ b/tests/dataset/test_huggingface_dataset.py @@ -1,7 +1,7 @@ import pytest from datasets import Dataset as HFDataset -from loadax.dataset.huggingface import HuggingFaceDataset +from loadax import HuggingFaceDataset @pytest.fixture diff --git a/tests/dataset/test_mapped.py b/tests/dataset/test_mapped.py index decb99e..94de093 100644 --- a/tests/dataset/test_mapped.py +++ b/tests/dataset/test_mapped.py @@ -1,8 +1,8 @@ import jax.numpy as jnp import pytest +from loadax import SimpleDataset from loadax.dataset.dataset import MappedDataset -from loadax.dataset.simple import SimpleDataset @pytest.fixture diff --git a/tests/dataset/test_partial.py b/tests/dataset/test_partial.py index f67701a..71f7100 100644 --- a/tests/dataset/test_partial.py +++ b/tests/dataset/test_partial.py @@ -1,7 +1,6 @@ import pytest -from loadax.dataset.partial_dataset import PartialDataset -from loadax.dataset.simple import SimpleDataset +from loadax import PartialDataset, SimpleDataset @pytest.fixture diff --git a/tests/dataset/test_sampled.py b/tests/dataset/test_sampled.py index cdf6fae..5492f72 100644 --- a/tests/dataset/test_sampled.py +++ b/tests/dataset/test_sampled.py @@ -1,8 +1,7 @@ import jax import pytest -from loadax.dataset.sampled_dataset import SampledDataset -from loadax.dataset.simple import SimpleDataset +from loadax import SampledDataset, SimpleDataset @pytest.fixture diff --git a/tests/dataset/test_sharded_dataset.py b/tests/dataset/test_sharded_dataset.py index 82ef5e5..e28e4d7 100644 --- a/tests/dataset/test_sharded_dataset.py +++ b/tests/dataset/test_sharded_dataset.py @@ -4,8 +4,7 @@ import pytest -from loadax.dataset.sharded_dataset import ShardedDataset -from loadax.dataset.simple import SimpleDataset +from loadax import ShardedDataset, SimpleDataset def compute_expected_boundaries( diff --git a/tests/dataset/test_simple_dataset.py b/tests/dataset/test_simple_dataset.py index 9757904..e996c86 100644 --- a/tests/dataset/test_simple_dataset.py +++ b/tests/dataset/test_simple_dataset.py @@ -1,7 +1,7 @@ import jax import pytest -from loadax.dataset.simple import SimpleDataset +from loadax import SimpleDataset @pytest.fixture diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 509369a..1c2b8d8 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -1,7 +1,6 @@ import pytest -from loadax.dataloader.loader import Dataloader -from loadax.dataset.simple import SimpleDataset +from loadax import Dataloader, SimpleDataset @pytest.fixture