A simple JAX solver for the 2D Euler equations for benchmarking

pmocz/jax-euler-benchmarks

Simple Euler Equation JAX benchmarking

Philip Mocz (2024)

Flatiron Institute

Benchmarked on a MacBook (Apple M3 Max) and on rusty (Nvidia A100)

Files

  • euler.py: simple JAX version for a single node
  • euler_distributed.py: JAX version for distributed systems
  • euler_numpy.py: simple NumPy version (based on my blog tutorial)

Setup

  • Create a Python virtual environment and install the required modules:
python -m venv --system-site-packages $VENVDIR/my-jax-venv
source $VENVDIR/my-jax-venv/bin/activate
pip install -r requirements.txt
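
After installation, one way to confirm the environment works is to ask JAX which devices it can see. This is a minimal sketch; `describe_jax_devices` is a hypothetical helper name and is not part of this repository.

```python
def describe_jax_devices():
    """Return the devices JAX can see as strings, or None if JAX is not installed."""
    try:
        import jax  # should be importable inside the activated virtual environment
    except ImportError:
        return None
    # jax.devices() lists the devices of the default backend (CPU, GPU, or TPU)
    return [str(device) for device in jax.devices()]


if __name__ == "__main__":
    devices = describe_jax_devices()
    if devices is None:
        print("JAX is not installed in this environment")
    else:
        print(f"JAX sees {len(devices)} device(s): {devices}")
```

On a GPU node the list should contain GPU devices. A CPU-only list usually means a CPU-only JAX build was installed.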

Strong Scaling on macbook:

[Figure: strong scaling]

Weak Scaling on rusty:

[Figure: weak scaling]

Final Simulation Result

16384^2 resolution JAX simulation (single precision): 277300 iterations on 16 GPUs in 64.1 minutes

(For reference, my MacBook run (single precision) at 1024^2 resolution took 4.6 minutes for 15426 iterations.)

The GPU run had about 335x higher throughput (mcups, million cell updates per second)!
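
The throughput comparison can be checked from the numbers quoted above. The sketch below assumes mcups is computed as cells × iterations ÷ wall-clock seconds ÷ 10^6. The function name `mcups` is illustrative and is not taken from the repository.

```python
def mcups(resolution, iterations, minutes):
    """Million cell updates per second for a square resolution x resolution grid."""
    cells = resolution ** 2
    seconds = minutes * 60.0
    return cells * iterations / seconds / 1.0e6


# Figures quoted in this README (wall-clock times are rounded to 0.1 minute)
gpu_mcups = mcups(16384, 277300, 64.1)  # 16 GPUs on rusty
mac_mcups = mcups(1024, 15426, 4.6)     # MacBook, Apple M3 Max
ratio = gpu_mcups / mac_mcups

print(f"GPU run:     {gpu_mcups:,.0f} mcups")
print(f"MacBook run: {mac_mcups:,.1f} mcups")
print(f"ratio:       {ratio:.0f}x")
```

With the rounded timings this gives a ratio of roughly 330x. The small gap to the quoted 335x is plausibly due to rounding in the reported times.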

[Figure: final snapshot]
