
Implement graph.vectorize and Blockwise Op #306

Merged (7 commits) Sep 6, 2023

Conversation

@ricardoV94 (Member) commented May 17, 2023

Critical for pymc-devs/pymc#5383

This PR was done almost completely from scratch, although the general principles and some of the code were obviously inspired by aesara-devs/aesara#1215

For fairness, all authors of the commits in that PR were added as co-authors here.

CC @purna135

Some design questions:

  1. Do we want to make Blockwise the "default" form of an Op (in terms of rewrites), and only specialize later when the batched form is not needed? **Decided: yes.**
  2. Do we want to fuse consecutive Blockwises with the same batched dims, similarly to how we fuse consecutive Elemwises? **Not for now.**
  3. Can Rop be implemented just as the batched core Rop, like L_op?
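For intuition, Blockwise generalizes an Op's core signature over any leading batch dimensions, much like a NumPy gufunc. A minimal sketch of that behavior using plain NumPy's signature-based `np.vectorize` (an illustrative analogy only, not the PyTensor API):

```python
import numpy as np

# Analogy: the core operation has signature "(m,n),(n)->(m)" (matrix-vector
# product); np.vectorize maps it over leading batch dimensions, the same way
# Blockwise maps a core Op over batched inputs.
batched_matvec = np.vectorize(np.matmul, signature="(m,n),(n)->(m)")

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3, 2))  # batch of four (3, 2) matrices
v = rng.normal(size=(4, 2))     # batch of four length-2 vectors

out = batched_matvec(A, v)
print(out.shape)  # (4, 3): batch dim preserved, core dims follow the signature
```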

Important follow-up:

  1. Dispatch to JAX and Numba
  2. Create a PyTensor "vectorize" function that vectorizes a whole graph, similar to what the L_op method has to do. **Included as the last commit:**

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize

# Original graph: softmax of a vector
x = pt.vector("x")
y = pt.exp(x) / pt.sum(pt.exp(x))

# Vectorize the graph by substituting a batched (matrix) input
new_x = pt.matrix("new_x")
[new_y] = vectorize([y], {x: new_x})

fn = pytensor.function([new_x], new_y)
fn([[0, 1, 2], [2, 1, 0]])
# array([[0.09003057, 0.24472847, 0.66524096],
#        [0.66524096, 0.24472847, 0.09003057]])
```

Important to test:

  1. Whether this works with Ops with inner graphs (Scan, Composite)

@ricardoV94 ricardoV94 force-pushed the blockwise branch 2 times, most recently from 5776292 to 4280c20 on May 18, 2023 11:32
@codecov-commenter commented May 18, 2023

Codecov Report

Merging #306 (ff1f7c5) into main (8112576) will increase coverage by 0.01%.
The diff coverage is 81.34%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #306      +/-   ##
==========================================
+ Coverage   80.39%   80.40%   +0.01%
==========================================
  Files         156      158       +2
  Lines       45397    45652     +255
  Branches    11106    11181      +75
==========================================
+ Hits        36497    36708     +211
- Misses       6694     6722      +28
- Partials     2206     2222      +16
```

| Impacted Files | Coverage Δ |
| --- | --- |
| pytensor/link/numba/dispatch/nlinalg.py | 100.00% <ø> (ø) |
| pytensor/tensor/basic.py | 90.77% <ø> (+0.01%) ⬆️ |
| pytensor/tensor/rewriting/linalg.py | 76.92% <64.91%> (+4.70%) ⬆️ |
| pytensor/tensor/blockwise.py | 79.03% <79.03%> (ø) |
| pytensor/tensor/slinalg.py | 92.30% <88.63%> (-1.25%) ⬇️ |
| pytensor/tensor/elemwise.py | 88.44% <92.00%> (+0.42%) ⬆️ |
| pytensor/compile/mode.py | 84.47% <100.00%> (ø) |
| pytensor/tensor/nlinalg.py | 98.07% <100.00%> (+0.23%) ⬆️ |
| pytensor/tensor/random/op.py | 97.51% <100.00%> (+0.04%) ⬆️ |
| pytensor/tensor/rewriting/blockwise.py | 100.00% <100.00%> (ø) |

... and 7 files with indirect coverage changes

@ricardoV94 ricardoV94 marked this pull request as ready for review May 18, 2023 12:45
@ricardoV94 ricardoV94 marked this pull request as draft May 25, 2023 08:31
@ricardoV94 ricardoV94 force-pushed the blockwise branch 2 times, most recently from 6660ea3 to 0103154 on June 22, 2023 09:39
@ricardoV94 ricardoV94 marked this pull request as ready for review June 22, 2023 09:40
@ricardoV94 ricardoV94 force-pushed the blockwise branch 5 times, most recently from 44843a4 to bedc48b on June 23, 2023 10:17
@ricardoV94 ricardoV94 added enhancement New feature or request major Op implementation linalg Linear algebra labels Jun 23, 2023
@ricardoV94 ricardoV94 force-pushed the blockwise branch 3 times, most recently from e13a69c to ff1f7c5 on June 23, 2023 13:41
@ferrine (Member) left a comment
I see issues with the design approach. The rewrites become increasingly involved and have to take both Blockwise and non-Blockwise Ops into account. I think a discussion of this issue is required before any further steps.

Review comments (now resolved) were left on: pytensor/tensor/blockwise.py, pytensor/tensor/elemwise.py, pytensor/tensor/nlinalg.py, pytensor/tensor/rewriting/linalg.py
@ricardoV94 (Member, Author) commented Jul 4, 2023

@ferrine I think the sanest thing is for rewrites to only target Blockwise graphs, the same way that our rewrites only target Elemwise graphs and never the equivalent Scalar graphs (see some discussion in #107).

Users would now always work with Blockwise (that's what we expose by default when you do pt.foo), the same way they always work with Elemwise. During compilation we remove Blockwise when unnecessary the same way that #107 suggests to remove Elemwise when unnecessary, but only near the end of compilation.

What problem do you see with this?

@ferrine (Member) commented Jul 6, 2023

Let's get more reviews

@ricardoV94 ricardoV94 force-pushed the blockwise branch 3 times, most recently from 5c7cb7f to c88f74c on August 25, 2023 09:53
@ricardoV94 (Member, Author) commented Aug 25, 2023

The last commit introduces a vectorize utility. This should be the PyTensor equivalent of np.vectorize or jax.vectorize.

If people disagree too much with the API I will drop this commit so we can discuss it elsewhere. I would like to get the Blockwise functionality merged regardless of what the final vectorize may look like, because this functionality is sorely missing.

See new example in top comment
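For comparison, the batched softmax from the top comment can be reproduced with NumPy's signature-based `np.vectorize` (an analogy illustrating the stated equivalence, not the PyTensor implementation):

```python
import numpy as np

# Analogy only: the core vector -> vector softmax is mapped over leading
# batch dimensions via a gufunc-style signature, as vectorize does for graphs.
def softmax(x):
    e = np.exp(x - x.max())  # shift for numerical stability; result unchanged
    return e / e.sum()

batched_softmax = np.vectorize(softmax, signature="(n)->(n)")
print(batched_softmax([[0, 1, 2], [2, 1, 0]]))
# array([[0.09003057, 0.24472847, 0.66524096],
#        [0.66524096, 0.24472847, 0.09003057]])
```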

@twiecki (Member) commented Aug 25, 2023

I think in pymc4 we tried to use functionality like this to sample multiple chains in parallel (in TF). Sounds like we could just take a pymc model logp and vectorize it along the dim of chains. Now the question is whether that would provide any benefits.

@ricardoV94 (Member, Author) commented:

> I think in pymc4 we tried to use functionality like this to sample multiple chains in parallel (in TF). Sounds like we could just take a pymc model logp and vectorize it along the dim of chains. Now the question is whether that would provide any benefits.

I am not sure what numpyro vectorize does, but that sounds a bit strange? The sampler probably wants to go different tree_depths for different chains / points?

I see this as being more useful for actual model building, where you define the core case and then vectorize batched dims.

Also useful for marginalization of discrete variables via enumeration that pymc_experimental.marginal_model does.

Might also be useful for stuff like compute_log_likelihood and prior and posterior predictive sampling where we now resort to iterating over chains/draws, although I am not sure what cases would actually be faster.

@twiecki (Member) commented Aug 25, 2023

> I am not sure what numpyro vectorize does, but that sounds a bit strange? The sampler probably wants to go different tree_depths for different chains / points?

Right, for NUTS it's a pain, as @ColCarroll can attest. I think that's where some of the new sampler developments by Matt Hoffman et al. might come in handy.

> I see this as being more useful for actual model building, where you define the core case and then vectorize batched dims.
>
> Also useful for marginalization of discrete variables via enumeration that pymc_experimental.marginal_model does.
>
> Might also be useful for stuff like compute_log_likelihood and prior and posterior predictive sampling where we now resort to iterating over chains/draws, although I am not sure what cases would actually be faster.

Interesting potential applications.

In any case, I think the API of vectorize is so simple and elegant that I find it hard to imagine anyone taking issue with it. It's also very analogous to the existing NumPy functionality.

@Armavica Armavica self-requested a review September 1, 2023 17:22
@ricardoV94 ricardoV94 force-pushed the blockwise branch 3 times, most recently from 46ece42 to d6b8777 on September 5, 2023 13:16
ricardoV94 and others added 3 commits September 5, 2023 16:07
Inspired by: aesara-devs/aesara#1215

Co-authored-by: Brandon T. Willard <brandonwillard@users.noreply.github.com>
Co-authored-by: Purna Chandra Mansingh <purnachandramansingh135@gmail.com>
Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com3>
Co-authored-by: Kaustubh <ckaustubhm06@gmail.com>
@ricardoV94 (Member, Author) commented:

Tests are passing again

@twiecki (Member) commented Sep 5, 2023

Merge?

@ricardoV94 (Member, Author) commented:

> Merge?

I think so. Nothing like trying it out there to see if the design is good.

@twiecki twiecki merged commit d611395 into pymc-devs:main Sep 6, 2023
@ricardoV94 ricardoV94 changed the title Implement Blockwise Op to vectorize existing Ops Implement graph.vectorize and Blockwise Op Sep 7, 2023
@ricardoV94 ricardoV94 deleted the blockwise branch October 12, 2023 08:00