Skip to content

Commit

Permalink
Avoid depending on JAX internals, which are about to change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689326802
  • Loading branch information
dougalm authored and pax authors committed Oct 24, 2024
1 parent 572b571 commit 81aa951
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions praxis/base_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,7 @@
# = ShardedDeviceArray([ True], dtype=bool)
def is_running_under_pmap() -> bool:
"""Whether currently running under pmap with PMAP_PARALLEL_AXIS_NAME."""
try:
_ = jax.lax.axis_index(PMAP_PARALLEL_AXIS_NAME)
return True
except NameError:
return False
return PMAP_PARALLEL_AXIS_NAME in jax.core.unsafe_get_axis_names_DO_NOT_USE()


class WeightHParamsCollection:
Expand Down

0 comments on commit 81aa951

Please sign in to comment.