ENH: Enable vectorization / Jax #515

Open
hmgaudecker opened this issue Feb 1, 2023 · 4 comments · May be fixed by #812
Labels
enhancement New feature or request


@hmgaudecker
Collaborator

hmgaudecker commented Feb 1, 2023

Is your feature request related to a problem?

We merged #425 before Jax was fully usable in GETTSIM. This issue is to remind ourselves of what still needs to be done.

Describe the solution you would like to see

The core change is to replace src._gettsim.functions_loader.vectorize_func with something like:

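def _vectorize_func(func):
    # USE_JAX and make_vectorizable are provided by GETTSIM itself.
    # Pick the array library the rewritten function should use.
    backend = "jax" if USE_JAX else "numpy"
    return make_vectorizable(func, backend)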

Doing so currently leads to failures because some constructs cannot be vectorized yet. Examples found so far:

  • calls to float(x) do not work on arrays
  • piecewise_polynomial assumes scalar input for x (even though numpy.searchsorted is more general)
  • The entire piecewise_functions module currently uses numpy instead of numpy_or_jax, which is unlikely to work with the jax backend.
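A minimal sketch of the first two problems, using a made-up piecewise-linear schedule (the thresholds and rates are hypothetical, and the two functions are not GETTSIM's piecewise_polynomial): the scalar version breaks on arrays because of float(x), while numpy.searchsorted itself already returns one segment index per element.

```python
import numpy as np

# Hypothetical piecewise-linear schedule: thresholds and the slope that
# applies above each threshold (made-up numbers, not GETTSIM parameters).
THRESHOLDS = np.array([0.0, 10_000.0, 50_000.0])
RATES = np.array([0.0, 0.2, 0.4])
# Function value at each threshold, so that the segments join continuously.
INTERCEPTS = np.concatenate(([0.0], np.cumsum(np.diff(THRESHOLDS) * RATES[:-1])))


def piecewise_linear_scalar(x):
    # float(x) only works for scalars (or single-element arrays);
    # on a longer array it raises a TypeError.
    x = float(x)
    idx = int(np.searchsorted(THRESHOLDS, x, side="right")) - 1
    return INTERCEPTS[idx] + RATES[idx] * (x - THRESHOLDS[idx])


def piecewise_linear_array(x):
    # np.asarray instead of float(x); searchsorted accepts arrays directly
    # and returns one segment index per element.
    x = np.asarray(x, dtype=float)
    idx = np.searchsorted(THRESHOLDS, x, side="right") - 1
    return INTERCEPTS[idx] + RATES[idx] * (x - THRESHOLDS[idx])
```

Both versions assume x is at least the first threshold; smaller values would get index -1 and silently wrap around to the last segment.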

Once these issues are fixed, the version of _vectorize_func above should work, and it will likely already be faster than the current implementation.

If, in addition, we knew which function arguments are to be vectorized over and which are not (see the ideas for a decorator in GEP 6), we could call jax.vmap and make it even faster. This is more of a reminder and should probably be tackled in a separate PR.
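What that information buys can be sketched with jax.vmap, whose in_axes argument marks each positional argument as mapped (an axis number) or passed through unchanged (None). The function benefit and its two-element params array are hypothetical, and the GEP 6 decorator that would supply in_axes is not shown.

```python
import jax
import jax.numpy as jnp


def benefit(income, params):
    # Hypothetical function: `income` varies across households,
    # `params` is one shared parameter vector (threshold, rate).
    threshold, rate = params[0], params[1]
    return jnp.where(income > threshold, 0.0, rate * (threshold - income))


# in_axes=(0, None): map over the first axis of `income`,
# pass `params` unchanged to every call.
benefit_vmapped = jax.vmap(benefit, in_axes=(0, None))

incomes = jnp.array([500.0, 1500.0, 2500.0])
params = jnp.array([2000.0, 0.5])
result = benefit_vmapped(incomes, params)
```

Without in_axes, vmap would try to map over params as well, which is exactly the distinction a decorator would have to record.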

hmgaudecker added the enhancement (New feature or request) label on Feb 1, 2023
hmgaudecker changed the title from "ENH: Enable Jax" to "ENH: Enable vectorization / Jax" on Feb 1, 2023
@hmgaudecker
Collaborator Author

One thing to think about would be to leave the numpy backend as is, i.e., write on scalars and call np.vectorize. Seems like this is easier on people not speaking Python as their lingua franca.
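A minimal sketch of the "write on scalars and call np.vectorize" approach; benefit_scalar and its parameters are hypothetical, not a GETTSIM function:

```python
import numpy as np


def benefit_scalar(income, threshold, rate):
    # Written for one household at a time: a plain Python `if` on a scalar.
    if income > threshold:
        return 0.0
    return rate * (threshold - income)


# np.vectorize loops over the array elements in Python and calls
# benefit_scalar once per element; `excluded` keeps the parameters scalar.
benefit = np.vectorize(benefit_scalar, excluded={"threshold", "rate"})

incomes = np.array([500.0, 1500.0, 2500.0])
result = benefit(incomes, threshold=2000.0, rate=0.5)
```

Calling benefit_scalar on the array directly would raise a ValueError, because `if income > threshold` needs a single truth value. np.vectorize avoids that by looping in Python, which is convenient to write but slow.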

hmgaudecker added a commit that referenced this issue Feb 1, 2023
* Add the vectorization module translating source code into a version that can be called with arrays.
* Add appropriate tests
* Fix most failures, though not quite there yet to enable it (See #515)

---------

Co-authored-by: Hans-Martin von Gaudecker <hmgaudecker@gmail.com>
@ChristianZimpelmann
Member

> One thing to think about would be to leave the numpy backend as is, i.e., write on scalars and call np.vectorize. Seems like this is easier on people not speaking Python as their lingua franca.

I thought this was the plan. What would be the alternative?

@hmgaudecker
Collaborator Author

Yes, sorry, that was misleading; it only made sense if you had followed how my thinking evolved during the day.

As @timmens has implemented it, make_vectorizable rewrites the source code so that all functions can be called natively with arrays. Users write functions as if for scalars and will not notice the difference, unless they use constructs that make_vectorizable does not support.

In all likelihood, the code will run orders of magnitude faster with that setting.
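To make the source-rewriting idea concrete, here is a toy transformer built on Python's ast module. It only handles conditional expressions (`a if cond else b` to `np.where`) and `max`/`min` (to `np.maximum`/`np.minimum`); it is an illustration under those assumptions, not the actual make_vectorizable, and vectorize_source, the rewrite rules, and the benefit function are hypothetical. It takes source text rather than a live function to stay self-contained; a real implementation would more likely obtain the source with inspect.getsource.

```python
import ast

import numpy as np

# Scalar-only builtins and their elementwise numpy counterparts.
_ELEMENTWISE = {"max": "maximum", "min": "minimum"}


def _np_attr(name):
    """Build the AST node for `np.<name>`."""
    return ast.Attribute(value=ast.Name(id="np", ctx=ast.Load()), attr=name, ctx=ast.Load())


class _ScalarToArray(ast.NodeTransformer):
    """Rewrite a few scalar-only constructs into numpy calls (illustrative subset)."""

    def visit_IfExp(self, node):
        # Rewrite children first, then `a if cond else b` -> `np.where(cond, a, b)`.
        self.generic_visit(node)
        return ast.Call(func=_np_attr("where"), args=[node.test, node.body, node.orelse], keywords=[])

    def visit_Call(self, node):
        # `max(a, b)` -> `np.maximum(a, b)`, likewise for `min`.
        self.generic_visit(node)
        if isinstance(node.func, ast.Name) and node.func.id in _ELEMENTWISE:
            node.func = _np_attr(_ELEMENTWISE[node.func.id])
        return node


def vectorize_source(source, name):
    """Parse `source`, rewrite it for arrays, and return the function `name`."""
    tree = _ScalarToArray().visit(ast.parse(source))
    ast.fix_missing_locations(tree)
    namespace = {"np": np}
    exec(compile(tree, filename="<vectorized>", mode="exec"), namespace)
    return namespace[name]


# Hypothetical scalar-style function, as a user would write it.
SCALAR_SOURCE = """
def benefit(income, threshold, rate, floor):
    return max(rate * (threshold - income), floor) if income < threshold else 0.0
"""

benefit = vectorize_source(SCALAR_SOURCE, "benefit")
incomes = np.array([500.0, 1500.0, 1900.0, 2500.0])
result = benefit(incomes, 2000.0, 0.5, 100.0)
```

np.where evaluates both branches for every element, so the rewritten code does more arithmetic than the scalar version, but it runs in numpy's compiled loops instead of a Python loop. Statement-level if/else, and/or, and calls such as float(x) would each need their own rule; constructs without one are what makes vectorization fail.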

So we'll need to balance two goals:

  • making it easy for people who develop GETTSIM but are not Python experts
  • making GETTSIM run reasonably fast out of the box

The solution will probably be to allow for both kinds of backends, keeping the current behavior as the default and making the option that vectorizes numpy code out of the box (via source rewriting) opt-in, just like the Jax backend. We'll just need to develop some useful terminology.

@hmgaudecker
Collaborator Author

Once that is done, remove the renewed casting to numpy.datetime64 in demographic_vars.py / alter_monate. It is currently required because np.vectorize() casts its inputs to object arrays:

> /mambaforge/envs/gettsim/lib/python3.11/site-packages/numpy/lib/function_base.py(2412)

2400         def _vectorize_call(self, func, args):
[...]
2409                 # Convert args to object arrays first
2410                 inputs = [asanyarray(a, dtype=object) for a in args]
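The effect of the object cast can be reproduced in a few lines. The function names and dates below are hypothetical, and the month arithmetic is only in the spirit of alter_monate, not a copy of it. Passing otypes makes np.vectorize skip its probe call on the first element, so every call goes through the object-array path shown above.

```python
import numpy as np

# Hypothetical data: birth dates with day resolution and a reference month.
birthdates = np.array(["1990-03-15", "2005-11-30"], dtype="datetime64[D]")
REFERENCE_MONTH = np.datetime64("2023-02", "M")

seen_types = []


def months_since_birth_scalar(birthdate):
    """Scalar-style function, called once per element by np.vectorize."""
    seen_types.append(type(birthdate))
    # After the object cast, `birthdate` is a datetime.date, so it has to be
    # cast back to numpy.datetime64 before doing month arithmetic.
    return int((REFERENCE_MONTH - np.datetime64(birthdate, "M")).astype(int))


# otypes skips the probe call on the first element; all calls then receive
# elements of the object array built by asanyarray(a, dtype=object).
months_loop = np.vectorize(months_since_birth_scalar, otypes=[int])(birthdates)


def months_since_birth_array(birthdates):
    """Array-native version: the dtype stays datetime64, no re-cast needed."""
    return (REFERENCE_MONTH - birthdates.astype("datetime64[M]")).astype(int)


months_native = months_since_birth_array(birthdates)
```

The scalar version receives datetime.date objects and must cast back to numpy.datetime64, while the array-native version keeps the datetime64 dtype throughout, which is why the extra cast could go once np.vectorize is no longer used.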

timmens linked a pull request on Jan 21, 2025 that will close this issue