-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement graph.vectorize
and Blockwise
Op
#306
Conversation
5776292
to
4280c20
Compare
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #306 +/- ##
==========================================
+ Coverage 80.39% 80.40% +0.01%
==========================================
Files 156 158 +2
Lines 45397 45652 +255
Branches 11106 11181 +75
==========================================
+ Hits 36497 36708 +211
- Misses 6694 6722 +28
- Partials 2206 2222 +16
|
6660ea3
to
0103154
Compare
44843a4
to
bedc48b
Compare
e13a69c
to
ff1f7c5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see issues with the design approach. The rewrites become more and more involved and should take blockwise and non blockwise ops in account. I think a discussion over this issue is required before any steps forward.
@ferrine I think the sanest thing is for rewrites to only target Blockwise graphs, the same way that our rewrites only target Elemwise graphs and never the equivalent Scalar graphs (see some discussion in #107). Users would now always work with Blockwise (that's what we expose by default when you do pt.foo), the same way they always work with Elemwise. During compilation we remove Blockwise when unnecessary the same way that #107 suggests to remove Elemwise when unnecessary, but only near the end of compilation. What problem do you see with this? |
Let's get more reviews |
5c7cb7f
to
c88f74c
Compare
Last commits introduces a If people disagree too much with the API I will drop this commit so we can discuss it elsewhere. I would like to get the Blockwise functionality merged regardless of what the final See new example in top comment |
I think in pymc4 we tried to use functionality like this to sample multiple chains in parallel (in TF). Sounds like we could just take a pymc model logp and vectorize it along the dim of chains. Now the question is whether that would provide any benefits. |
I am not sure what numpyro vectorize does, but that sounds a bit strange? The sampler probably wants to go different tree_depths for different chains / points? I see this as being more useful for actual model building, where you define the core case and then vectorize batched dims. Also useful for marginalization of discrete variables via enumeration that Might also be useful for stuff like |
Right, for NUTS it's a pain, as @ColCarroll can attest to. I think that's where some of the new sampler developments might come in handy by Matt Hoffman et al.
Interesting potential applications. In any case, I think the API of vectorize is so simple and elegant that I find it hard to imagine someone to take issue with it. It's also very analogous to the existing numpy functionality. |
d5a5674
to
02858ac
Compare
46ece42
to
d6b8777
Compare
Inspired by: aesara-devs/aesara#1215 Co-authored-by: Brandon T. Willard <brandonwillard@users.noreply.github.com> Co-authored-by: Purna Chandra Mansingh <purnachandramansingh135@gmail.com> Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com3> Co-authored-by: Kaustubh <ckaustubhm06@gmail.com>
d6b8777
to
f19f95e
Compare
Tests are passing again |
Merge? |
I think so. Nothing like trying it out there to see if the design is good |
graph.vectorize
and Blockwise
Op
Critical for pymc-devs/pymc#5383
This PR was done almost completely from scratch, although the general principles and some of the code were obviously inspired by aesara-devs/aesara#1215
For fairness, all authors of the commits in that PR were added as co-authors here.
CC @purna135
Some design questions:
Important follow-up:
Create a PyTensor "vectorize" function that vectorizes the whole graph, similar to what theIncluded as the last commitL_op
method has to doImportant to test: