fix: readme imports
walln committed Sep 17, 2024
1 parent eca2277 commit f3274dd
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions README.md
@@ -17,11 +17,11 @@ pip install loadax
Loadax provides a simple interface for loading data into your training loop. Here is an example of loading data from a list of items:

```python
-from loadax import DataLoader, InMemoryDataset, Batcher
+from loadax import Dataloader, InMemoryDataset, Batcher

dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
batcher = Batcher(lambda x: x)
-loader = DataLoader(batcher).batch_size(2).build(dataset)
+loader = Dataloader(batcher).batch_size(2).build(dataset)

for batch in loader:
    print(batch)
@@ -37,16 +37,16 @@ for batch in loader:
A dataloader is a definition of how to load data from a dataset. It is itself stateless, enabling you to define multiple dataloaders for the same dataset, and even multiple iterators for the same dataloader.

```python
-dataloader = DataLoader(batcher).batch_size(2).build(dataset)
+dataloader = Dataloader(batcher).batch_size(2).build(dataset)

-fast_iterator = iter(dataloader)
-slow_iterator = iter(dataloader)
+iter_a = iter(dataloader)
+iter_b = iter(dataloader)

-val = next(fast_iterator)
+val = next(iter_a)
print(val)
# Output: [1, 2]

-val = next(slow_iterator)
+val = next(iter_b)
print(val)
# Output: [1, 2]
```
@@ -58,11 +58,11 @@ In the above examples we create an object called a batcher. A batcher is an inte
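The batcher sentence above is cut off by the diff view. For reference only, and based solely on the `Batcher(lambda x: x)` and `Batcher(lambda x: jnp.stack(x))` usages that appear elsewhere in this README, a batcher simply wraps the function that collates the list of items making up a batch; a minimal sketch:

```python
import jax.numpy as jnp
from loadax import Batcher

# Identity batcher: hands the raw list of items straight to the training loop.
identity_batcher = Batcher(lambda items: items)

# Stacking batcher: collates the items of a batch into a single jnp array.
stack_batcher = Batcher(lambda items: jnp.stack(items))
```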
When training models, it is essential to ensure that you are not blocking the training loop, and especially your accelerator(s), with IO-bound tasks. Loadax provides a simple interface for prefetching data into a cache using background worker(s).

```python
-from loadax import DataLoader, InMemoryDataset, Batcher
+from loadax import Dataloader, InMemoryDataset, Batcher

dataset = InMemoryDataset([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
batcher = Batcher(lambda x: x)
-loader = DataLoader(batcher).batch_size(2).prefetch(3).build(dataset)
+loader = Dataloader(batcher).batch_size(2).prefetch(3).build(dataset)

for batch in loader:
    print(batch)
@@ -84,15 +84,15 @@ In the same way that the dataloader can be used to prefetch data, it can also of
In the following example we have a dataset that is slow to load an individual item due to some pre-processing. Ignore the details of the MappedDataset, as we will get to that later; for now, just know that it lazily transforms the data from the source dataset.

```python
-from loadax import DataLoader, RangeDataset, MappedDataset, Batcher
+from loadax import Dataloader, RangeDataset, MappedDataset, Batcher
import time

def slow_fn(x):
    time.sleep(0.1)
    return x * 2

dataset = MappedDataset(RangeDataset(0, 10), slow_fn)
batcher = Batcher(lambda x: x)
-loader = DataLoader(batcher).batch_size(2).workers(2).build(dataset)
+loader = Dataloader(batcher).batch_size(2).workers(2).build(dataset)

for batch in loader:
    print(batch)
@@ -116,7 +116,7 @@ Loadax also supports distributed data loading. This means that you can easily sh
With the inter-node distribution handled for you, it is now trivial to build advanced distributed training loops with paradigms such as model and data parallelism.

```python
-from loadax import DataLoader, InMemoryDataset, Batcher
+from loadax import Dataloader, InMemoryDataset, Batcher
from loadax.sharding_utilities import fsdp_sharding
from jax.sharding import Mesh, PartitionSpec, NamedSharding
import jax.numpy as jnp
@@ -135,7 +135,7 @@ dataset = InMemoryDataset(list(range(dataset_size)))
batcher = Batcher(lambda x: jnp.stack(x))

dataloader = (
-    DataLoader(batcher)
+    Dataloader(batcher)
    .batch_size(batch_size)
    .workers(2)
    .prefetch(2)
@@ -169,7 +169,7 @@ The sharding primitives that Loadax provides are powerful as they declare the wa
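The Loadax-specific sharding configuration is collapsed in this diff, so as a point of reference only, here is a plain-JAX sketch, not Loadax's API, of what declaring a batch layout over a device mesh looks like; `per_device` and the toy batch are purely illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())                       # 1D mesh over all local devices
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))   # split the leading (batch) axis

per_device = 4                                          # illustrative per-device batch size
batch = jnp.arange(per_device * jax.device_count() * 2).reshape(-1, 2)
global_batch = jax.device_put(batch, sharding)          # each device now holds its slice of the batch
```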
Another benefit of Loadax is that the underlying shape of your data is passed through all the way into your training loop. This means you can use type hints to ensure that your data is the correct shape.

```python
-from loadax import DataLoader, RangeDataset, Batcher
+from loadax import Dataloader, RangeDataset, Batcher

# RangeDataset has a DatasetItem type of Int; this is a generic argument that can be supplied to any dataset
# type. We can look more into this when we get to datasets.
@@ -180,7 +180,7 @@ def my_fn(x: list[int]) -> int:
    return sum(x)

batcher = Batcher(my_fn)
-loader = DataLoader(batcher).batch_size(2).build(dataset)
+loader = Dataloader(batcher).batch_size(2).build(dataset)

for batch in loader:
    print(batch)
@@ -224,6 +224,10 @@ dataset = MappedDataset(base_dataset, slow_fn)

When iterating through `dataset`, the slow_fn will be applied lazily to the underlying dataset, which is itself lazily shuffling the range dataset. This composable pattern allows you to build complex data loading pipelines.
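The shuffled base dataset in that pipeline is collapsed in this diff, so the sketch below substitutes a second MappedDataset to show the same composition idea, using only the dataset types that appear elsewhere in this README:

```python
from loadax import Dataloader, RangeDataset, MappedDataset, Batcher

# Each wrapper is lazy: nothing is computed until the loader pulls an item through.
base = RangeDataset(0, 10)
doubled = MappedDataset(base, lambda x: x * 2)     # first lazy transform
shifted = MappedDataset(doubled, lambda x: x + 1)  # second lazy transform, stacked on the first

batcher = Batcher(lambda x: x)
loader = Dataloader(batcher).batch_size(2).build(shifted)

for batch in loader:
    print(batch)
```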

+### More Features
+
+This was just a quick tour of what Loadax has to offer. For more information, please see the [documentation](https://walln.github.io/loadax/).

#### Dataset Integrations

Loadax has a few common dataset sources on the roadmap, including:
@@ -233,7 +237,3 @@
- HuggingFaceDataset

Feel free to open an issue if you have a use case that you would like to see included.

-### Batchers
-
-Batchers are used to define how to collate your data into batches.
