Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: special: JAX support (non-jitted) #22256

Merged
merged 3 commits into from
Jan 7, 2025
Merged

Conversation

crusaderky
Copy link
Contributor

@@ -172,11 +171,10 @@ def _elements_and_indices_with_max_real(a, axis=-1, xp=None):
# simple, array-API compatible way of doing so that doesn't
# have a problem with `axis` being a tuple or None.
i = xp.reshape(xp.arange(xp_size(a)), a.shape)
i[~mask] = -1
i = xpx.at(i, ~mask).set(-1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails on jax.jit. My current intention is to change jax.jit itself to special-case arr.at[idx].set(value) when idx is a boolean mask and value is a scalar, so that it can be rewritten as jnp.where(idx, value, arr). Failing that, I can implement the same special case in array-api-extra.

@lucascolley lucascolley added the array types Items related to array API support and input array validation (see gh-18286) label Jan 6, 2025
Comment on lines -150 to +148
out[no_wrap] = x[no_wrap]
return out
return xp.where(no_wrap, x, wrapped)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for my understanding, could you explain why we use where here instead of xpx.at?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to ask the same thing. Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for forward compatibility with jitted jax, which doesn't support updates with bool masks. Performance implications are negligible in this case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, it ensures compatibility with immutable arrays – that is to say, this is more important than just supporting jitted JAX, in my opionion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, it ensures compatibility with immutable arrays

But isn't that what at.set is for?

Copy link
Contributor Author

@crusaderky crusaderky Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason to change this now is that in this specific case it is trivial to do so.
When it comes to xpx.at + jax.jit + boolean mask indices, there are three substantially different use cases to tackle:

  1. when the expression could be trivially rewritten as a xp.where with no performance loss (this case)
  2. when the value is a scalar, so we could somewhat straightforwardly implement a special case rewrite rule either in jax.jit itself or inside xpx.at that converts it, for jax only, to jnp.where(mask, scalar_value, x); after that change to jax/array_api_extra, the changes I'm doing now in scipy will automatically become compatible with jax.jit.
  3. when neither is true and ad-hoc work is required, e.g. _lazywhere

And how are we deciding when the performance impact is negligible?

Is there a benchmark suite we can run? I would be reluctant to run it manually for every PR though.

Copy link
Contributor

@mdhaber mdhaber Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rgommers can you clarify your opinion here - should the mutation be replaced by where on this line, and if so, why is that right to do that here and not elsewhere?

You wrote "The change is because xpx.at currently does not support __setitem__ with a bool array as index, only integer array." but AFAICT, https://github.com/scipy/scipy/pull/22256/files#r1904169337 also has a boolean array as the index. What is the difference between these two cases?

Is there a benchmark suite we can run? I would be reluctant to run it manually for every PR though.

It doesn't look like it, at least there wasn't last time that ran. Update: here are current special benchmarks, and there is not one for logsumexp. I wouldn't want to run it for every PR either; that's why I'm proposing that we try to maintain mutation for mutable arrays rather than deciding whether or not there is a performance cost to the copy. (TBH, I'm not particularly concerned about performance cost here; I just thought the whole purpose of changing everything to at.set was so we could get JAX working without resorting to where across the board.) If that's not possible with xpx.at.set as-written, then why not improve xpx.at.set before this sweep is done?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rgommers can you clarify your opinion here - should the mutation be replaced by where on this line, and if so, why is that right to do that here and not elsewhere?

I spent >5 minutes reading the surrounding code and while I suspect @crusaderky is right that it doesn't matter performance-wise, it's not completely obvious to me (where operands are arrays of the same size as the output, so it could slow things down by some %). So while I'm not attached to the final outcome for this particular case, I'd encourage fixing the design issue first so future x[mask] = ... instances can be replaced by something that is 100% semantically equivalent, so there's no need to think at all.

If that's not possible with xpx.at.set as-written, then why not improve xpx.at.set before this sweep is done?

I agree with this. I don't have a strong opinion of whether it must be inside at or in another function, but I think it's probably better for it to be a single thing, rather than the three cases 1-3 above.

Copy link
Contributor Author

@crusaderky crusaderky Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this. I don't have a strong opinion of whether it must be inside at or in another function, but I think it's probably better for it to be a single thing, rather than the three cases 1-3 above.

I don't see how the use case of _lazywhere could be made to work with a simple helper function, honestly.

As for use case 1, I could write something like

class at:
    def set_masked(self, value, copy=None):
        """Same as x[idx] = value[idx]"""

but I think it would both be less readable and slower compared to x = xp.where(idx, value, x) including on writeable backends. Benchmark here: #22246 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but AFAICT, https://github.com/scipy/scipy/pull/22256/files#r1904169337 also has a boolean array as the index. What is the difference between these two cases?

The difference is that https://github.com/scipy/scipy/pull/22256/files#r1904169337 is

a[mask] = scalar

which can be special-cased for JAX inside JAX itself or array-api-extra so that it works with the jit, and would be slower to rewrite in numpy as

a = xp.where(mask, scalar, a)

benchmark here: #22246 (comment)

This one instead is

a[mask] = b[mask]

which cannot be special-cased transparently; we could write a new method a = at(a)[mask].set_masked(b) which not only it would hurt readability, but it would also be at all times slower on numpy than just writing

a = xp.where(mask, b, a)

benchmark here: #22246 (comment)

max_i = xp.max(i, axis=axis, keepdims=True)
mask = i == max_i
a = xp_copy(a)
a[~mask] = 0
a = xpx.at(a, ~mask).set(0, copy=True)
Copy link
Contributor

@mdhaber mdhaber Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to stay current with this stuff even though I'm not working on array API ATM. Can you explain when at.set with copy=True would be preferred to just using xp.where? (ISTM this could/should have been written with xp.where initially, if it wasn't really modifying an existing array. It might have been this way to avoid jumping through hoops to preserve dtype with where, since the arguments have to be arrays ATM.)

Copy link
Contributor Author

@crusaderky crusaderky Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xpx.at(a, mask).set(value, copy=True) differs from xp.where when value is not a scalar.
I fully agree that xp.where is much more readable in this case; I'm updating it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a caveat regarding dtypes though:

xpx.at(a, ~mask).set(0, copy=True)

retains the dtype of a, whereas

xp.where(mask, a, 0)

broadcasts the dtypes. To retain the behaviour, you need

xp.where(mask, a, xp.zeros((), dtype=a.dtype))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that should be resolved with data-apis/array-api#860

Copy link
Contributor

@mdhaber mdhaber Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a caveat regarding dtypes though:

Right, that's what I was referring to with "It might have been this way to avoid jumping through hoops to preserve dtype with where." I see that your update jumped through the hoops.

Looking forward to http://data-apis/array-api#860...

But in what other way do the two differ when value is not a scalar? I would have though one could get the same result with xp.where as with at.set with copy=True (with care taken regarding dtypes).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a caveat regarding dtypes though:

Right, that's what I was referring to with "It might have been this way to avoid jumping through hoops to preserve dtype with where." I see that your update jumped through the hoops.

Looking forward to http://data-apis/array-api#860...

I can revert it and add a comment explaning the choice if it looks more readable to you.

Copy link
Contributor

@mdhaber mdhaber Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need; either way is fine. There are lots of times I've used where and specified dtypes explicitly like. They stick out to me as ugly, so I will notice them and clean them up when it's possible. Using at.set with copy=True also stands out as a reminder to replace if possible. So either are fine.

Just in the future, ISTM there should be one (and preferably only one) obvious way to do a where-like operation when copies are required. So I suppose the question is whether we'll even want to allow at.set with copy=True once where supports scalars?

@jakevdp
Copy link
Member

jakevdp commented Jan 6, 2025

special: JAX support (non-jitted)

A broad comment here: I don't think non-jitted JAX support is all that useful to users. From the changes here, it looks like "non-jitted" actually implies "non-traceable", which means these functions would be incompatible with any tranformation (jit, vmap, grad, shard_map, etc.) The main utility of JAX lies in these transformations, and code that's incompatible with them will be mostly unusable by JAX users.

If this change is intended as a partial step toward eventual support for traced execution, then that's great, and we should do it.

if this change is intended as an end in itself, then I don't think it's worth the effort.

@lucascolley
Copy link
Member

If this change is intended as a partial step toward eventual support for traced execution, then that's great, and we should do it.

yes: gh-22246

@crusaderky
Copy link
Contributor Author

All CI failures are unrelated

Copy link
Member

@lucascolley lucascolley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @crusaderky !

@lucascolley lucascolley merged commit 85c6c2a into scipy:main Jan 7, 2025
34 of 37 checks passed
@crusaderky crusaderky deleted the jax_special branch January 7, 2025 14:12
@lucascolley
Copy link
Member

@crusaderky I noticed that this skip remains:

if is_jax(xp) and f_name in {'stdtrit'}:
pytest.skip(f"`{f_name}` generic implementation require array mutation.")

Perhaps that could/should have been addressed here?

@crusaderky
Copy link
Contributor Author

@crusaderky I noticed that this skip remains:

if is_jax(xp) and f_name in {'stdtrit'}:
pytest.skip(f"`{f_name}` generic implementation require array mutation.")

Perhaps that could/should have been addressed here?

#22281

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy.special
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants