Philip Mocz (2024)
Flatiron Institute
Benchmarking on a MacBook (Apple M3 Max) and rusty (Nvidia A100)
- euler.py: simple JAX version on a single node
- euler_distributed.py: JAX version for distributed systems
- euler_numpy.py: simple NumPy version (based on my blog tutorial)
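Because the NumPy and JAX versions cover the same problem, it can help to write array code against a swappable array module. The sketch below is a hypothetical illustration, not code taken from euler.py or euler_numpy.py: `periodic_gradient` and its `xp` parameter are names made up here. It computes a central-difference gradient on a periodic 2D grid using `roll`, which both `numpy` and `jax.numpy` provide, and it assumes axis 0 is x and axis 1 is y.

```python
import numpy as np


def periodic_gradient(f, dx, xp=np):
    """Central-difference gradient of a 2D field with periodic boundaries.

    Hypothetical helper for illustration. `xp` is the array module: numpy by
    default; jax.numpy should work too since it offers the same `roll`.
    Assumes axis 0 is x and axis 1 is y, with equal spacing dx on both axes.
    """
    # roll(f, -1) shifts so entry i holds f[i+1]; roll(f, 1) gives f[i-1]
    f_dx = (xp.roll(f, -1, axis=0) - xp.roll(f, 1, axis=0)) / (2 * dx)
    f_dy = (xp.roll(f, -1, axis=1) - xp.roll(f, 1, axis=1)) / (2 * dx)
    return f_dx, f_dy


if __name__ == "__main__":
    n = 64
    dx = 1.0 / n
    x = (np.arange(n) + 0.5) * dx  # cell-centered coordinates on [0, 1)
    X, Y = np.meshgrid(x, x, indexing="ij")
    rho = np.sin(2 * np.pi * X)  # varies only along x
    drho_dx, drho_dy = periodic_gradient(rho, dx)
    # Compare against the exact derivative 2*pi*cos(2*pi*x)
    print(np.max(np.abs(drho_dx - 2 * np.pi * np.cos(2 * np.pi * X))))
```

Swapping in JAX would mean passing `xp=jax.numpy` (and JAX arrays); the function body stays the same, which is the point of the design.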
- Create a Python virtual environment and install the required modules:
python -m venv --system-site-packages $VENVDIR/my-jax-venv
source $VENVDIR/my-jax-venv/bin/activate
pip install -r requirements.txt
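After installing, a quick sanity check can confirm that JAX sees the expected devices (for example the A100 GPUs on rusty) and that it runs in single precision, as in the runs reported below. This is a minimal sketch; `describe_jax_setup` is a hypothetical helper, while `jax.default_backend`, `jax.devices`, and `jax.device_count` are standard JAX functions.

```python
import importlib.util


def describe_jax_setup():
    """Report JAX backend, devices, and default float dtype.

    Hypothetical helper. Returns None if JAX is not installed in the
    active environment.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    import jax.numpy as jnp

    return {
        "backend": jax.default_backend(),  # e.g. "cpu" or "gpu"
        "device_count": jax.device_count(),
        "devices": [str(d) for d in jax.devices()],
        # JAX defaults to 32-bit floats unless jax_enable_x64 is enabled
        "default_float": str(jnp.zeros(1).dtype),
    }


if __name__ == "__main__":
    print(describe_jax_setup())
```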
16384^2-resolution JAX simulation (single precision) after 277,300 iterations on 16 GPUs, which took 64.1 minutes
(For reference, my single-precision MacBook run at 1024^2 resolution took 4.6 minutes for 15,426 iterations.)
The GPU calculation had 335x higher throughput, measured in Mcups (million cell updates per second)!
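Throughput in Mcups is the number of cells times the number of iterations, divided by wall-clock seconds and by one million. The sketch below (with a hypothetical `mcups` helper) reproduces the comparison from the numbers quoted above, assuming a 2D grid and counting the GPU figure as the total across all 16 GPUs.

```python
def mcups(resolution, iterations, minutes, ndim=2):
    """Million cell updates per second for a resolution^ndim grid."""
    cell_updates = resolution**ndim * iterations
    return cell_updates / (minutes * 60.0) / 1e6


gpu = mcups(16384, 277300, 64.1)  # 16 GPUs combined
mac = mcups(1024, 15426, 4.6)     # MacBook (Apple M3 Max)
print(f"GPU: {gpu:.0f} Mcups, MacBook: {mac:.1f} Mcups, ratio: {gpu / mac:.0f}x")
```

With these rounded run times the ratio comes out to about 330x, slightly below the quoted 335x, which may have been computed from unrounded timings.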