Hi, I noticed that the current implementation of `_distinct_but_small` degrades the performance of Allegro from `allegro_jax`.
The following replacement works better for me and avoids using `lax.scan`:
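```python
import jax
import jax.numpy as jnp


def _distinct_but_small(x: jax.Array) -> jax.Array:
    """Maps the entries of x into integers from 0 to n-1 denoting unique values."""
    shape = x.shape
    x = x.ravel()
    sorted_idx = jnp.argsort(x)
    # Each segment of equal numbers gets a unique index
    new_group = jnp.concatenate(
        [jnp.zeros(1, dtype=jnp.int32), (jnp.diff(x[sorted_idx]) > 0).astype(jnp.int32)]
    )
    group_idx = jnp.cumsum(new_group)
    # Assigns each entry of x to its corresponding unique element
    # (written into a fresh integer buffer so the result is integer-valued)
    x = jnp.zeros_like(x, dtype=jnp.int32).at[sorted_idx].set(group_idx)
    return x.reshape(shape)
```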
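To make the sort/diff/cumsum/scatter steps concrete, here is a small trace on an illustrative input (the array `x` below is made up for the example, and the explicit `int32` dtypes are a choice made here, not something taken from `allegro_jax`). NumPy's `np.unique(..., return_inverse=True)` is used only as a reference labelling to compare against.

```python
import jax.numpy as jnp
import numpy as np

# Illustrative input: 30 appears twice, so there are 3 distinct values
x = jnp.array([30, 10, 30, 20])

sorted_idx = jnp.argsort(x)        # positions of x in ascending order
sorted_vals = x[sorted_idx]        # [10, 20, 30, 30]
# 1 where a sorted value differs from its predecessor, i.e. a new group starts
new_group = jnp.concatenate(
    [jnp.zeros(1, dtype=jnp.int32), (jnp.diff(sorted_vals) > 0).astype(jnp.int32)]
)                                  # [0, 1, 1, 0]
group_idx = jnp.cumsum(new_group)  # [0, 1, 2, 2]: group id of each sorted entry
# Scatter the group ids back to the original (unsorted) positions
labels = jnp.zeros_like(x, dtype=jnp.int32).at[sorted_idx].set(group_idx)
print(labels)  # [2 0 2 1]

# Reference: the inverse indices from np.unique give the same labelling
_, inverse = np.unique(np.asarray(x), return_inverse=True)
assert np.array_equal(np.asarray(labels), inverse)
```

Every operation here (argsort, diff, cumsum, scatter) works on fixed-size arrays, which is what lets this approach stay `jit`-friendly without a sequential `lax.scan` loop.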
This change could potentially also improve the performance problem reported in mariogeiger/allegro-jax#3.