Update public README with project info and include new docs pointers.
PiperOrigin-RevId: 705944447
1 parent 9f6b6a5, commit eff605a

Showing 5 changed files with 183 additions and 135 deletions.
@@ -1,11 +1,56 @@
# Grain - Feeding JAX Models

Grain is a library for reading data for training and evaluating JAX models. It's
open source, fast and deterministic.

[Tests](https://github.com/google/grain/actions/workflows/tests.yaml)
[PyPI](https://pypi.org/project/grain/)

* Installation: `pip install grain`
* [Docs](https://github.com/google/grain/tree/main/docs)
* Grain is used by [MaxText](https://github.com/google/maxtext/tree/main), a simple, performant and scalable JAX codebase for LLMs.

Check out [`tutorials/`](./tutorials) for more information on how to use Grain!

[**Installation**](#installation)
| [**Quickstart**](#quickstart)
| [**Reference docs**](https://google-grain.readthedocs.io/en/latest/)

Grain is a Python library for reading data for training and evaluating JAX
models. It is flexible, fast and deterministic.

Grain lets you define data processing steps in a simple, declarative way:

```python
import grain.python as grain

dataset = (
    grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .shuffle(seed=10)  # Shuffles elements globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)  # Batches consecutive elements.
)

for batch in dataset:
    pass  # Training step goes here.
```

Grain is designed to work with JAX models, but it does not require JAX to run
and can be used with other frameworks as well.
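
To make the semantics of the shuffle, map and batch steps concrete without installing anything, here is a minimal plain-Python sketch. It is not Grain's implementation or API: the helper names `shuffle_globally`, `batch_consecutive` and `run_pipeline` are hypothetical, keeping a smaller final batch is a choice of this sketch, and the element order will not match what Grain produces for the same seed. It only illustrates that a seeded global shuffle, an element-wise map and batching of consecutive elements yield the same batches on every run.

```python
import random
from typing import Callable, List, Sequence


def shuffle_globally(items: Sequence[int], seed: int) -> List[int]:
    """Returns a seeded permutation of all items (hypothetical helper)."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    return shuffled


def batch_consecutive(items: Sequence[int], batch_size: int) -> List[List[int]]:
    """Groups consecutive items; in this sketch the last batch may be smaller."""
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


def run_pipeline(
    source: Sequence[int],
    seed: int,
    fn: Callable[[int], int],
    batch_size: int,
) -> List[List[int]]:
    """Applies shuffle, then map, then batch, mirroring the order of the steps."""
    shuffled = shuffle_globally(source, seed)
    mapped = [fn(x) for x in shuffled]
    return batch_consecutive(mapped, batch_size)


batches = run_pipeline(list(range(11)), seed=10, fn=lambda x: x + 1, batch_size=2)
for batch in batches:
    print(batch)
```

Calling `run_pipeline` again with the same seed returns identical batches, which is the property that makes a training input pipeline reproducible.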

## Installation

Grain is available on [PyPI](https://pypi.org/project/grain/) and can be
installed with `pip install grain`.
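
After installing, one way to confirm that the package is visible to the current interpreter is to query its distribution metadata. This is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical, and the distribution name `grain` is taken from the `pip install grain` command above.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Returns the installed version of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


grain_version = installed_version("grain")
if grain_version is None:
    print("grain is not installed; run `pip install grain`.")
else:
    print(f"grain {grain_version} is installed.")
```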

### Supported platforms

Grain does not directly use GPUs or TPUs in its transformations; processing
within Grain is done on the CPU by default.

|         | Linux   | Mac     | Windows |
|---------|---------|---------|---------|
| x86_64  | yes     | WIP     | no      |
| aarch64 | yes     | WIP     | n/a     |
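
The table above can be encoded as a small lookup for a pre-install check. This is a sketch under stated assumptions: `SUPPORT`, `MACHINE_ALIASES`, `normalize_machine` and `support_status` are hypothetical names, and mapping `platform.machine()` values such as `AMD64` and `arm64` onto the table's `x86_64` and `aarch64` rows is an assumption of this sketch, not something the README specifies.

```python
import platform

# Support matrix transcribed from the table above.
# Keys are (architecture, platform.system() name); "Darwin" is macOS.
SUPPORT = {
    ("x86_64", "Linux"): "yes",
    ("x86_64", "Darwin"): "WIP",
    ("x86_64", "Windows"): "no",
    ("aarch64", "Linux"): "yes",
    ("aarch64", "Darwin"): "WIP",
    ("aarch64", "Windows"): "n/a",
}

# Assumption: common platform.machine() spellings mapped onto the table rows.
MACHINE_ALIASES = {
    "x86_64": "x86_64",
    "amd64": "x86_64",
    "aarch64": "aarch64",
    "arm64": "aarch64",
}


def normalize_machine(machine: str) -> str:
    """Maps a platform.machine() value onto a row name, if one is known."""
    lowered = machine.lower()
    return MACHINE_ALIASES.get(lowered, lowered)


def support_status(system: str, machine: str) -> str:
    """Returns the table entry for a platform, or "unknown" if it is not listed."""
    return SUPPORT.get((normalize_machine(machine), system), "unknown")


print(support_status(platform.system(), platform.machine()))
```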

## Quickstart

- [Basic `Dataset` tutorial](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_basic_tutorial.html)

## Existing users

Grain is used by [MaxText](https://github.com/google/maxtext/tree/main),
[kauldron](https://github.com/google-research/kauldron) and multiple internal
Google projects.