diff --git a/README.md b/README.md
index 637a9a0..e5321e5 100644
--- a/README.md
+++ b/README.md
@@ -17,11 +17,11 @@ pip install loadax
 Loadax provides a simple interface for loading data into your training loop. Here is an example of loading data from a list of items:
 
 ```python
-from loadax import DataLoader, InMemoryDataset, Batcher
+from loadax import Dataloader, InMemoryDataset, Batcher
 
 dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
 batcher = Batcher(lambda x: x)
-loader = DataLoader(batcher).batch_size(2).build(dataset)
+loader = Dataloader(batcher).batch_size(2).build(dataset)
 
 for batch in loader:
     print(batch)
@@ -37,16 +37,16 @@
 A dataloader is a definition of how to load data from a dataset. It is itself stateless, enabling you to define multiple dataloaders for the same dataset, and even multiple iterators for the same dataloader.
 
 ```python
-dataloader = DataLoader(batcher).batch_size(2).build(dataset)
+dataloader = Dataloader(batcher).batch_size(2).build(dataset)
 
-fast_iterator = iter(dataloader)
-slow_iterator = iter(dataloader)
+iter_a = iter(dataloader)
+iter_b = iter(dataloader)
 
-val = next(fast_iterator)
+val = next(iter_a)
 print(val)
 # Output: 1
 
-val = next(slow_iterator)
+val = next(iter_b)
 print(val)
 # Output: 1
 ```
@@ -58,11 +58,11 @@ In the above examples we create an object called a batcher. A batcher is an inte
 When training models, it is essential to ensure that you are not blocking the training loop, and especially your accelerator(s), with IO-bound tasks. Loadax provides a simple interface for prefetching data into a cache using background worker(s).
 
 ```python
-from loadax import DataLoader, InMemoryDataset, Batcher
+from loadax import Dataloader, InMemoryDataset, Batcher
 
 dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
 batcher = Batcher(lambda x: x)
-loader = DataLoader(batcher).batch_size(2).prefetch(3).build(dataset)
+loader = Dataloader(batcher).batch_size(2).prefetch(3).build(dataset)
 
 for batch in loader:
     print(batch)
@@ -84,7 +84,7 @@ In the same way that the dataloader can be used to prefetch data, it can also of
 In the following example we have a dataset that is slow to load an individual item due to some pre-processing. Ignore the details of the MappedDataset as we will get to that later; for now just know that it lazily transforms the data from the source dataset.
 
 ```python
-from loadax import DataLoader, RangeDataset, MappedDataset, Batcher
+from loadax import Dataloader, RangeDataset, MappedDataset, Batcher
 
 def slow_fn(x):
     time.sleep(0.1)
@@ -92,7 +92,7 @@ def slow_fn(x):
 
 dataset = MappedDataset(RangeDataset(0, 10), slow_fn)
 batcher = Batcher(lambda x: x)
-loader = DataLoader(batcher).batch_size(2).workers(2).build(dataset)
+loader = Dataloader(batcher).batch_size(2).workers(2).build(dataset)
 
 for batch in loader:
     print(batch)
@@ -116,7 +116,7 @@ Loadax also supports distributed data loading. This means that you can easily sh
 With the inter-node distribution handled for you, it is now trivial to build advanced distributed training loops with paradigms such as model and data parallelism.
 
 ```python
-from loadax import DataLoader, InMemoryDataset, Batcher
+from loadax import Dataloader, InMemoryDataset, Batcher
 from loadax.sharding_utilities import fsdp_sharding
 from jax.sharding import Mesh, PartitionSpec, NamedSharding
 import jax.numpy as jnp
@@ -135,7 +135,7 @@
 dataset = InMemoryDataset(list(range(dataset_size)))
 batcher = Batcher(lambda x: jnp.stack(x))
 dataloader = (
-    DataLoader(batcher)
+    Dataloader(batcher)
     .batch_size(batch_size)
     .workers(2)
     .prefetch(2)
@@ -169,7 +169,7 @@ The sharding primitives that Loadax provides are powerful as they declare the wa
 Another benefit of Loadax is that the underlying shape of your data is passed through all the way into your training loop. This means you can use type hints to ensure that your data is the correct shape.
 
 ```python
-from loadax import DataLoader, RangeDataset, Batcher
+from loadax import Dataloader, RangeDataset, Batcher
 
 # RangeDataset has a DatasetItem type of Int; this is a generic argument that can be supplied to any dataset
 # type. We can look more into this when we get to datasets.
@@ -180,7 +180,7 @@ def my_fn(x: list[int]) -> int:
     return sum(x)
 
 batcher = Batcher(my_fn)
-loader = DataLoader(batcher).batch_size(2).build(dataset)
+loader = Dataloader(batcher).batch_size(2).build(dataset)
 
 for batch in loader:
     print(batch)
@@ -224,6 +224,10 @@ dataset = MappedDataset(base_dataset, slow_fn)
 
 When iterating through `dataset`, the `slow_fn` will be applied lazily to the underlying dataset, which is itself lazily shuffling the range dataset. This composable pattern allows you to build complex dataloading pipelines.
 
+### More Features
+
+This was just a quick tour of what Loadax has to offer. For more information, please see the [documentation](https://walln.github.io/loadax/).
+
 #### Dataset Integrations
 
 Loadax has a few common dataset sources on the roadmap, including:
@@ -233,7 +237,3 @@ Loadax has a few common dataset sources on the roadmap, including:
 - HuggingFaceDataset
 
 Feel free to open an issue if you have a use case that you would like to see included.
-
-### Batchers
-
-Batchers are used to define how to collate your data into batches.