
Performance Issues with _distinct_but_small #99

Open
pl-fuchs opened this issue Feb 6, 2025 · 0 comments


pl-fuchs commented Feb 6, 2025

Hi, I noticed that the current implementation of _distinct_but_small slows down Allegro from allegro_jax.

The following replacement is faster in my setup and avoids lax.scan entirely. Instead, it uses a sort followed by a cumulative sum:

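import jax
import jax.numpy as jnp


def _distinct_but_small(x: jax.Array) -> jax.Array:
    """Maps the entries of x to integers from 0 to n-1, where n is the number of unique values."""

    shape = x.shape
    x = x.ravel()
    sorted_idx = jnp.argsort(x)

    # Each run of equal values in the sorted array gets a unique index:
    # mark where a new value starts, then take the running count
    new_group = jnp.concatenate(
        [jnp.zeros(1, dtype=jnp.int32), (jnp.diff(x[sorted_idx]) > 0).astype(jnp.int32)]
    )
    group_idx = jnp.cumsum(new_group)

    # Scatter each group index back to the original (unsorted) position of its entry;
    # writing into a fresh integer array keeps the output integer even for float inputs
    out = jnp.zeros_like(sorted_idx).at[sorted_idx].set(group_idx)
    return out.reshape(shape)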
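Tracing the same steps on a small, made-up input shows how the sort, diff, and cumsum combine into a dense ranking (smallest value maps to 0, next distinct value to 1, and so on). The values in the comments were worked out by hand for this particular input. The final line cross-checks the result against jnp.unique with return_inverse=True, whose inverse indices are the same dense ranking; the inverse is reshaped as a precaution, since its shape convention has changed across NumPy/JAX versions.

```python
import jax.numpy as jnp

x = jnp.array([7, 3, 7, 10, 3, 5])

sorted_idx = jnp.argsort(x)        # positions of x in ascending order of value
x_sorted = x[sorted_idx]           # [3, 3, 5, 7, 7, 10]
starts = jnp.diff(x_sorted) > 0    # [F, T, T, F, T]: True where a new value begins
new_group = jnp.concatenate(
    [jnp.zeros(1, dtype=jnp.int32), starts.astype(jnp.int32)]
)                                  # [0, 0, 1, 1, 0, 1]
group_idx = jnp.cumsum(new_group)  # [0, 0, 1, 2, 2, 3]

# Tied entries share a group index, so the tie order inside argsort does not matter
out = jnp.zeros_like(sorted_idx).at[sorted_idx].set(group_idx)  # [2, 0, 2, 3, 0, 1]

# Cross-check: jnp.unique's inverse indices give the same dense ranking
_, inv = jnp.unique(x, return_inverse=True)
assert jnp.array_equal(out, inv.reshape(x.shape))
```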

This change might also help with the performance problem reported in mariogeiger/allegro-jax#3.
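To check whether the change helps in a given setup, the jitted function can be timed after a warm-up call, using block_until_ready so that JAX's asynchronous dispatch does not hide the real cost. This is a minimal sketch: the time_fn helper and the input (100 000 integers drawn from 50 labels) are hypothetical stand-ins, not the data or benchmark used in Allegro. The existing scan-based implementation, which is not reproduced here, could be passed to the same helper for a side-by-side comparison.

```python
import time

import jax
import jax.numpy as jnp


def _distinct_but_small(x: jax.Array) -> jax.Array:
    # Sort-based replacement from this issue, returning integer indices
    shape = x.shape
    x = x.ravel()
    sorted_idx = jnp.argsort(x)
    starts = (jnp.diff(x[sorted_idx]) > 0).astype(jnp.int32)
    group_idx = jnp.cumsum(jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), starts]))
    return jnp.zeros_like(sorted_idx).at[sorted_idx].set(group_idx).reshape(shape)


def time_fn(fn, x, n_iter=100):
    """Hypothetical helper: average wall-clock seconds per call, excluding compilation."""
    fn(x).block_until_ready()  # warm-up call triggers JIT compilation
    start = time.perf_counter()
    for _ in range(n_iter):
        out = fn(x)
    out.block_until_ready()  # wait for all dispatched work to finish
    return (time.perf_counter() - start) / n_iter


# Hypothetical input: many repeated integer labels
x = jax.random.randint(jax.random.PRNGKey(0), (100_000,), 0, 50)
fast = jax.jit(_distinct_but_small)
print(f"{time_fn(fast, x) * 1e6:.1f} us per call")
```

Timings depend heavily on the backend (CPU vs GPU) and input size, so any comparison is only meaningful on the hardware and shapes that Allegro actually uses.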
