ENH: Enable vectorization / Jax #515
One thing to think about would be to leave the numpy backend as is, i.e., write on scalars and call …
* Add the vectorization module translating source code into a version that can be called with arrays.
* Add appropriate tests.
* Fix most failures, though it is not quite ready to be enabled yet (see #515).

Co-authored-by: Hans-Martin von Gaudecker <hmgaudecker@gmail.com>
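To illustrate what "translating source code into a version that can be called with arrays" can mean, here is a toy sketch using Python's `ast` module. It only rewrites conditional expressions (`a if cond else b`) into `np.where(cond, a, b)`. The names `IfExpToWhere`, `vectorize_source`, and the `tax` rule are hypothetical; this is not GETTSIM's actual vectorization module.

```python
import ast
import textwrap

import numpy as np


class IfExpToWhere(ast.NodeTransformer):
    """Rewrite every `a if cond else b` expression into `np.where(cond, a, b)`."""

    def visit_IfExp(self, node: ast.IfExp) -> ast.AST:
        # Transform nested conditionals inside test/body/orelse first.
        self.generic_visit(node)
        where = ast.Attribute(
            value=ast.Name(id="np", ctx=ast.Load()), attr="where", ctx=ast.Load()
        )
        call = ast.Call(func=where, args=[node.test, node.body, node.orelse], keywords=[])
        return ast.copy_location(call, node)


def vectorize_source(source: str, name: str):
    """Parse scalar-style source, rewrite its conditionals, return the new function."""
    tree = ast.parse(textwrap.dedent(source))
    # New nodes (Attribute, Name) need line numbers before compiling.
    tree = ast.fix_missing_locations(IfExpToWhere().visit(tree))
    namespace = {"np": np}
    exec(compile(tree, filename="<vectorized>", mode="exec"), namespace)
    return namespace[name]


# Hypothetical scalar-style policy function, given as a string for the demo.
SCALAR_SOURCE = """
def tax(income, threshold):
    return 0.0 if income <= threshold else 0.2 * (income - threshold)
"""

tax_vectorized = vectorize_source(SCALAR_SOURCE, "tax")
print(tax_vectorized(np.array([5.0, 15.0]), 10.0))  # [0. 1.]
```

One caveat of this approach: `np.where` evaluates both branches eagerly, so a branch that is only valid under its condition (e.g., a division guarded by a zero check) may emit warnings once vectorized.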
I thought this was the plan. What would be the alternative?
Yes, sorry, that was misleading and only comprehensible if you had followed how my thinking evolved during the day. The way @timmens has implemented … In all likelihood, the code will run orders of magnitude faster with that setting. So we'll need to balance two goals: …
The solution will probably be to allow for both kinds of backends, keeping the current behavior as the default and making the numpy-vectorized-out-of-the-box option opt-in, just like the Jax backend. We'll just need to develop some useful terminology.
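A minimal sketch of the trade-off between the two kinds of backends, assuming (our assumption, not stated in the thread) that the scalar-style default wraps functions with `numpy.vectorize`. The one-bracket `tax` rule is made up for illustration.

```python
import numpy as np


def tax_scalar(income: float, threshold: float) -> float:
    """Scalar style: easy to read, uses plain Python control flow."""
    if income <= threshold:
        return 0.0
    return 0.2 * (income - threshold)


def tax_array(income: np.ndarray, threshold: float) -> np.ndarray:
    """Array style: same rule, expressed with np.where so it runs on whole arrays."""
    return np.where(income <= threshold, 0.0, 0.2 * (income - threshold))


# Scalar code plus np.vectorize: convenient, but essentially a Python-level loop.
tax_looped = np.vectorize(tax_scalar, otypes=[float])

incomes = np.array([5.0, 15.0, 30.0])
print(tax_looped(incomes, 10.0))  # [0. 1. 4.]
print(tax_array(incomes, 10.0))   # [0. 1. 4.]
```

Both produce the same numbers; the NumPy documentation describes `np.vectorize` as a convenience that is essentially a `for` loop, which is why the array-native style is expected to be much faster on large data.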
Once that is done, remove renewed casting to …
Is your feature request related to a problem?
We merged #425 before Jax was fully usable in GETTSIM. This issue is to remind ourselves of what still needs to be done.
Describe the solution you would like to see
The core change is to turn `src._gettsim.functions_loader.vectorize_func` into something like: …

Doing so currently leads to failures because some elements cannot be vectorized yet. Examples found so far:

* Calls such as `float(x)` do not work on arrays.
* `piecewise_polynomial` assumes scalar input for `x` (even though `numpy.searchsorted` is more general).
* The `piecewise_functions` module uses `numpy` instead of `numpy_or_jax` for now, which is unlikely to work with the Jax backend.

Once these things are fixed, the above version of `_vectorize_function` should work, and it will likely already be faster than the current implementation.

If, in addition, we knew which function arguments are to be vectorized over and which are not (see the ideas for a decorator in GEP 6), we could call `jax.vmap` and make it even faster. This is more of a reminder, probably to be tackled in a separate PR.
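For the `jax.vmap` idea: if we knew which arguments vary per row and which are shared policy parameters, `in_axes` can express that directly. A hedged sketch with a made-up `tax` rule: `income` is mapped over (axis 0) while `threshold` is broadcast (`None`). The function still uses `jnp.where` rather than a Python `if`, because branching on a traced value fails under `vmap`/`jit`, so some source-level rewriting would still be needed.

```python
import jax
import jax.numpy as jnp


def tax(income, threshold):
    # Written for a single person: `income` and `threshold` are scalars here.
    return jnp.where(income <= threshold, 0.0, 0.2 * (income - threshold))


# Map over `income` (axis 0); pass `threshold` unmapped to every call (None).
tax_batched = jax.jit(jax.vmap(tax, in_axes=(0, None)))

incomes = jnp.array([5.0, 15.0, 30.0])
print(tax_batched(incomes, 10.0))
```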