ENH: special: JAX support (non-jitted) #22256
crusaderky commented Jan 6, 2025
- Part of ENH: tracking issue for JAX support #22246
- See also EHN: cluster: JAX support (non-jitted) #22255
@@ -172,11 +171,10 @@ def _elements_and_indices_with_max_real(a, axis=-1, xp=None):
# simple, array-API compatible way of doing so that doesn't
# have a problem with `axis` being a tuple or None.
i = xp.reshape(xp.arange(xp_size(a)), a.shape)
i[~mask] = -1
i = xpx.at(i, ~mask).set(-1)
This fails on jax.jit. My current intention is to change jax.jit itself to special-case arr.at[idx].set(value) when idx is a boolean mask and value is a scalar, so that it can be rewritten as jnp.where(idx, value, arr). Failing that, I can implement the same special case in array-api-extra.
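The proposed rewrite rests on a simple equivalence, sketched here with NumPy standing in for the backend (an assumption made so the snippet runs anywhere; the thread itself is about jax.jit). The helper name set_scalar_where is hypothetical and is not part of JAX or array-api-extra.

```python
import numpy as np


def set_scalar_where(arr, mask, value):
    """Hypothetical helper: out-of-place equivalent of ``arr[mask] = value``
    for a boolean ``mask`` and a scalar ``value``."""
    # Wrap the scalar in arr's dtype so the result keeps that dtype.
    fill = np.asarray(value, dtype=arr.dtype)
    return np.where(mask, fill, arr)


arr = np.arange(6).reshape(2, 3)
mask = arr % 2 == 0

# Reference result: in-place mutation of a copy.
expected = arr.copy()
expected[mask] = -1

# Same values, but no array is ever mutated.
result = set_scalar_where(arr, mask, -1)
```

Because the output has the same static shape as the input, a rewrite of this form needs no data-dependent indexing.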
5a4ed4a to 17c4aac
out[no_wrap] = x[no_wrap]
return out
return xp.where(no_wrap, x, wrapped)
For my understanding, could you explain why we use where here instead of xpx.at?
I was going to ask the same thing. Thanks!
This is for forward compatibility with jitted jax, which doesn't support updates with bool masks. Performance implications are negligible in this case.
In general, it ensures compatibility with immutable arrays – that is to say, this is more important than just supporting jitted JAX, in my opinion.
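A minimal sketch of the immutability point, assuming a read-only NumPy array as a stand-in for an immutable backend. The variable names mirror the diff above, but the data and the mask condition are made up for illustration.

```python
import numpy as np

x = np.array([0.5, 7.0, -3.0, 2.0])
wrapped = x % 2.0   # stand-in for the "wrapped" values
no_wrap = x < 3.0   # stand-in mask: [True, False, True, True]

# Simulate an immutable array by clearing the writeable flag.
out = wrapped.copy()
out.flags.writeable = False

try:
    out[no_wrap] = x[no_wrap]   # in-place mutation
    mutation_failed = False
except ValueError:              # NumPy: "assignment destination is read-only"
    mutation_failed = True

# The where-based form never writes into an existing array.
result = np.where(no_wrap, x, wrapped)
```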
In general, it ensures compatibility with immutable arrays
But isn't that what at.set is for?
The reason to change this now is that in this specific case it is trivial to do so.
When it comes to xpx.at + jax.jit + boolean mask indices, there are three substantially different use cases to tackle:
- when the expression could be trivially rewritten as an xp.where with no performance loss (this case)
- when the value is a scalar, so we could somewhat straightforwardly implement a special case rewrite rule either in jax.jit itself or inside xpx.at that converts it, for jax only, to jnp.where(mask, scalar_value, x); after that change to jax/array_api_extra, the changes I'm doing now in scipy will automatically become compatible with jax.jit.
- when neither is true and ad-hoc work is required, e.g. _lazywhere
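To illustrate why the third case resists a mechanical rewrite, here is a hedged sketch of a _lazywhere-style operation in NumPy. The function lazy_where_sketch is hypothetical and is not SciPy's implementation; it only shows the pattern of evaluating a function on the selected elements alone.

```python
import numpy as np


def lazy_where_sketch(cond, x, f, fill_value):
    """Hypothetical sketch: apply ``f`` only where ``cond`` is True."""
    out = np.full(x.shape, fill_value, dtype=float)
    # x[cond] has as many elements as cond has True entries, so its
    # shape depends on the data, not only on the shape of x.
    out[cond] = f(x[cond])
    return out


x = np.array([4.0, -1.0, 9.0, -16.0])
cond = x >= 0

lazy = lazy_where_sketch(cond, x, np.sqrt, np.nan)

# Eager alternative: np.where evaluates sqrt on every element,
# including the negative ones, before selecting.
with np.errstate(invalid="ignore"):
    eager = np.where(cond, np.sqrt(x), np.nan)
```

Both forms give the same values here. The lazy form avoids evaluating f on invalid inputs, but relies on a data-dependent shape, which, as far as I know, cannot be traced under jax.jit.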
And how are we deciding when the performance impact is negligible?
Is there a benchmark suite we can run? I would be reluctant to run it manually for every PR though.
@rgommers can you clarify your opinion here - should the mutation be replaced by where on this line, and if so, why is that right to do that here and not elsewhere?
You wrote "The change is because xpx.at currently does not support __setitem__ with a bool array as index, only integer array." but AFAICT, https://github.com/scipy/scipy/pull/22256/files#r1904169337 also has a boolean array as the index. What is the difference between these two cases?
Is there a benchmark suite we can run? I would be reluctant to run it manually for every PR though.
It doesn't look like it, at least there wasn't last time that ran. Update: here are current special benchmarks, and there is not one for logsumexp. I wouldn't want to run it for every PR either; that's why I'm proposing that we try to maintain mutation for mutable arrays rather than deciding whether or not there is a performance cost to the copy. (TBH, I'm not particularly concerned about performance cost here; I just thought the whole purpose of changing everything to at.set was so we could get JAX working without resorting to where across the board.) If that's not possible with xpx.at.set as-written, then why not improve xpx.at.set before this sweep is done?
@rgommers can you clarify your opinion here - should the mutation be replaced by where on this line, and if so, why is that right to do that here and not elsewhere?
I spent >5 minutes reading the surrounding code and while I suspect @crusaderky is right that it doesn't matter performance-wise, it's not completely obvious to me (where operands are arrays of the same size as the output, so it could slow things down by some %). So while I'm not attached to the final outcome for this particular case, I'd encourage fixing the design issue first so future x[mask] = ... instances can be replaced by something that is 100% semantically equivalent, so there's no need to think at all.
If that's not possible with xpx.at.set as-written, then why not improve xpx.at.set before this sweep is done?
I agree with this. I don't have a strong opinion of whether it must be inside at or in another function, but I think it's probably better for it to be a single thing, rather than the three cases 1-3 above.
I agree with this. I don't have a strong opinion of whether it must be inside at or in another function, but I think it's probably better for it to be a single thing, rather than the three cases 1-3 above.
I don't see how the use case of _lazywhere could be made to work with a simple helper function, honestly.
As for use case 1, I could write something like
class at:
    def set_masked(self, value, copy=None):
        """Same as x[idx] = value[idx]"""
but I think it would both be less readable and slower compared to x = xp.where(idx, value, x) including on writeable backends. Benchmark here: #22246 (comment)
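For comparison, a minimal sketch of what such a method might do, written as a free function in NumPy. The name set_masked_sketch and its copy default are assumptions for illustration; no such method is claimed to exist in array-api-extra.

```python
import numpy as np


def set_masked_sketch(x, idx, value, copy=True):
    """Hypothetical stand-in for ``at(x)[idx].set_masked(value)``:
    same result as ``x[idx] = value[idx]``."""
    out = x.copy() if copy else x
    out[idx] = value[idx]
    return out


x = np.array([1.0, 2.0, 3.0, 4.0])
value = np.array([10.0, 20.0, 30.0, 40.0])
idx = np.array([True, False, True, False])

via_helper = set_masked_sketch(x, idx, value)

# The one-liner the comment argues for gives the same values.
via_where = np.where(idx, value, x)
```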
but AFAICT, https://github.com/scipy/scipy/pull/22256/files#r1904169337 also has a boolean array as the index. What is the difference between these two cases?
The difference is that https://github.com/scipy/scipy/pull/22256/files#r1904169337 is
a[mask] = scalar
which can be special-cased for JAX inside JAX itself or array-api-extra so that it works with the jit, and would be slower to rewrite in numpy as
a = xp.where(mask, scalar, a)
benchmark here: #22246 (comment)
This one instead is
a[mask] = b[mask]
which cannot be special-cased transparently; we could write a new method a = at(a)[mask].set_masked(b), which would not only hurt readability but would also always be slower on numpy than just writing
a = xp.where(mask, b, a)
benchmark here: #22246 (comment)
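The linked benchmarks are not reproduced here. As a rough way to check such claims locally, the following timeit sketch compares the two assignment forms with their where rewrites in NumPy. Array size, repeat counts and function names are arbitrary choices of this sketch, and the timings will vary by machine and NumPy version.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
a = rng.random(100_000)
b = rng.random(100_000)
mask = a > 0.5


def scalar_setitem():
    out = a.copy()
    out[mask] = 0.0
    return out


def scalar_where():
    return np.where(mask, 0.0, a)


def array_setitem():
    out = a.copy()
    out[mask] = b[mask]
    return out


def array_where():
    return np.where(mask, b, a)


def bench(fn, number=20):
    """Best-of-3 average time per call, in seconds."""
    return min(timeit.repeat(fn, number=number, repeat=3)) / number


timings = {fn.__name__: bench(fn) for fn in
           (scalar_setitem, scalar_where, array_setitem, array_where)}
for name, seconds in timings.items():
    print(f"{name:>15}: {seconds * 1e6:8.1f} us")
```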
scipy/special/_logsumexp.py
Outdated
max_i = xp.max(i, axis=axis, keepdims=True)
mask = i == max_i
a = xp_copy(a)
a[~mask] = 0
a = xpx.at(a, ~mask).set(0, copy=True)
I'd like to stay current with this stuff even though I'm not working on array API ATM. Can you explain when at.set with copy=True would be preferred to just using xp.where? (ISTM this could/should have been written with xp.where initially, if it wasn't really modifying an existing array. It might have been this way to avoid jumping through hoops to preserve dtype with where, since the arguments have to be arrays ATM.)
xpx.at(a, mask).set(value, copy=True) differs from xp.where when value is not a scalar.
I fully agree that xp.where is much more readable in this case; I'm updating it.
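One possible reading of that difference, sketched in NumPy under the assumption that xpx.at(a, mask).set(value) follows a[mask] = value semantics: with a non-scalar value, masked assignment accepts one entry per True element, whereas where needs operands that broadcast against the mask.

```python
import numpy as np

a = np.array([1, 2, 3, 4, 5])
mask = np.array([False, True, False, True, False])

# One value per True entry of the mask: shape (2,), not a.shape.
packed = np.array([20, 40])

out = a.copy()
out[mask] = packed

# np.where broadcasts its operands against the mask, so the packed
# values cannot be passed to it directly.
try:
    np.where(mask, packed, a)
    where_accepts_packed = True
except ValueError:
    where_accepts_packed = False
```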
There's a caveat regarding dtypes though:
xpx.at(a, ~mask).set(0, copy=True)
retains the dtype of a, whereas
xp.where(mask, a, 0)
broadcasts the dtypes. To retain the behaviour, you need
xp.where(mask, a, xp.zeros((), dtype=a.dtype))
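The dtype caveat can be checked with a short NumPy sketch. The float32 input and the float64 fill are assumptions chosen to make the promotion visible; other array libraries may follow different promotion rules.

```python
import numpy as np

a = np.array([1.5, 2.5, 3.5], dtype=np.float32)
mask = np.array([True, False, True])

# Masked assignment on a copy keeps a's dtype.
assigned = a.copy()
assigned[~mask] = 0

# A fill array with its own dtype takes part in type promotion.
# A 1-d fill is used so the outcome does not depend on how a given
# NumPy version treats 0-d arrays.
fill64 = np.zeros(1, dtype=np.float64)
promoted = np.where(mask, a, fill64)

# A fill created with a's dtype preserves it.
preserved = np.where(mask, a, np.zeros((), dtype=a.dtype))
```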
that should be resolved with data-apis/array-api#860
There's a caveat regarding dtypes though:
Right, that's what I was referring to with "It might have been this way to avoid jumping through hoops to preserve dtype with where." I see that your update jumped through the hoops.
Looking forward to data-apis/array-api#860...
But in what other way do the two differ when value is not a scalar? I would have thought one could get the same result with xp.where as with at.set with copy=True (with care taken regarding dtypes).
There's a caveat regarding dtypes though:
Right, that's what I was referring to with "It might have been this way to avoid jumping through hoops to preserve dtype with where." I see that your update jumped through the hoops. Looking forward to data-apis/array-api#860...
I can revert it and add a comment explaining the choice if it looks more readable to you.
No need; either way is fine. There are lots of times I've used where and specified dtypes explicitly like that. They stick out to me as ugly, so I will notice them and clean them up when it's possible. Using at.set with copy=True also stands out as a reminder to replace if possible. So either is fine.
Just in the future, ISTM there should be one (and preferably only one) obvious way to do a where-like operation when copies are required. So I suppose the question is whether we'll even want to allow at.set with copy=True once where supports scalars?
A broad comment here: I don't think non-jitted JAX support is all that useful to users. From the changes here, it looks like "non-jitted" actually implies "non-traceable", which means these functions would be incompatible with any transformation ( If this change is intended as a partial step toward eventual support for traced execution, then that's great, and we should do it. If this change is intended as an end in itself, then I don't think it's worth the effort.
yes: gh-22246
All CI failures are unrelated
193cd1d to 1e98660
1e98660 to a3dc469
thanks @crusaderky !
@crusaderky I noticed that this skip remains: scipy/scipy/special/tests/test_support_alternative_backends.py Lines 68 to 69 in 9166b6f
Perhaps that could/should have been addressed here?