Matrix support #6
Comments
@Sohl-Dickstein seems to be in the camp of wanting support for multiple tree axes, to compute things like Jacobians, Hessians and covariances/correlations between pytrees. |
Weighing in here at @shoyer's request -- I'd put myself in the multiple-tree-axes camp as well, I think. (I am saying that primarily as an end user rather than a developer for the feature, though... !) This starts to open up notions of "labelled tree axes". For example, performing the pytree-axes equivalent of For a thorny reference problem in which this kind of stuff might get quite useful (tree-matrix vs tree-ndarray or otherwise), I'd suggest the Ito version of Milstein's method available here: This is a pretty tricky implementation! It's very heavily annotated with comments describing the various tree-axes, normal-axes, and the way in which they interact. |
+1 to the @Sohl-Dickstein use case. Some more detail of where this would be handy: I recently needed to invert a Hessian of a function that took a pytree as its argument. The headache I ran into was that when I used jax.jacfwd(jax.jacrev(f))(x) to compute the Hessian, I got it as a pytree of pytrees, which turned out to be pretty complicated to flatten. It would be nice to be able to either transform a pytree of pytrees to and from a matrix of floats or to be able to perform matrix operations directly on the pytree of pytrees. |
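One workaround for the "pytree of pytrees" Hessian described above, available today without matrix support in tree-math, is to flatten the argument first with `jax.flatten_util.ravel_pytree` and differentiate in the flat coordinates. The sketch below is only an illustration: the objective `f` and the `params` values are made up, and it assumes the problem is small enough that a dense Hessian fits in memory.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(params):
    # Hypothetical scalar objective that takes a pytree (here a dict).
    w, b = params["w"], params["b"]
    return b * jnp.sum(w ** 2) + jnp.sum(w)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

# Flatten the pytree into a 1-D vector and keep the inverse mapping.
flat, unravel = ravel_pytree(params)


def f_flat(v):
    # Same objective, expressed on the flat vector.
    return f(unravel(v))


# Dense (n, n) Hessian and (n,) gradient in the flat coordinates.
H = jax.hessian(f_flat)(flat)
g = jax.grad(f_flat)(flat)

# Solve H @ step = g, then map the solution back to the pytree structure.
newton_step = unravel(jnp.linalg.solve(H, g))
```

This avoids the nested output of `jax.jacfwd(jax.jacrev(f))(x)` entirely, at the cost of working with a single flat axis rather than labelled tree axes.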
Just wanted to chime in and say that I'd love this feature, and for my use cases (which are primarily about numerical solvers for non-convex problems) a single axis is all I'd need, though I'm sure I'd find uses in multi-axis implementation if that does get developed. |
So it's not a documented feature, but Equinox actually has a tree-math-like sublibrary built in, which can be used to do this kind of multi-axis stuff. To set the scene, here is how it is used just to broadcast vector operations together:
from equinox.internal import ω
vector1 = [0, 1, (2, 3)]
vector2 = [4, 5, (6, 7)]
summed = (ω(vector1) + ω(vector2)).ω
# Alternate notation; I prefer this when doing pure arithmetic:
summed = (vector1**ω + vector2**ω).ω
print(summed) # [4, 6, (8, 10)]
But with a bit of thinking you can nest these to accomplish higher-order operations:
# matrix has shape (2, 3)
matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])])
# vector has shape (3,)
vector = ω([6, 7, 8])
# product (2, 3) @ (3,) -> (2,) ("call" applies the specified function to every leaf of its pytree)
matvec = matrix.call(lambda row: sum((row * vector).ω))
# unwrap
matvec = matvec.ω
print(matvec) # [23, 86]
The reason this works is that ω is not a PyTree. This means that Conversely, this does mean that you mustn't pass ω objects across JIT/grad/etc. API boundaries. (Whilst you can with tree-math.) ω is only meant to be used as a convenient syntax within the bounds of a single function. |
I do still think matrix support would be awesome to have, and I actually had a use case for this just last week. That said, at this point I'm relatively unlikely to work on it. If somebody else wants to give this a try, that would be very welcome! |
@geoff-davis, your statement about "pytrees of pytrees" gave me the idea to formulate this snippet:
from operator import add
from functools import partial
import jax
import jax.numpy as jnp
import tree_math as tm
def inner(x: tm.Vector, y: tm.Vector, x_row_ndim) -> tm.Vector:
    def product_sum(x, y):
        # Pad x with trailing singleton axes so it broadcasts against y, then
        # sum over the contracted axes (jnp expects a tuple of axes, not a range).
        return (x.reshape(x.shape + (1,) * (y.ndim - x.ndim + x_row_ndim)) * y).sum(axis=tuple(range(x_row_ndim, x.ndim)))
    def product_sum_map(x, y: tm.Vector):
        return tm.Vector(jax.tree.map(partial(product_sum, x), y.tree))
    z = jax.tree.map(product_sum_map, x.tree, y.tree, is_leaf=lambda x: isinstance(x, jnp.ndarray) or isinstance(x, tm.Vector))
    return jax.tree.reduce(add, z, is_leaf=lambda x: isinstance(x, tm.Vector))
def outer(a: tm.Vector, b: tm.Vector, a_row_ndim: tm.Vector) -> tm.Vector:
    in_leaves, a_outer_structure = jax.tree.flatten(a.tree, is_leaf=lambda x: isinstance(x, tm.Vector))
    row_ndims, _ = jax.tree.flatten(a_row_ndim.tree)
    out_leaves = []
    for in_leaf, row_ndim in zip(in_leaves, row_ndims):
        out_leaf = inner(in_leaf, b, row_ndim)
        out_leaves.append(out_leaf)
    return tm.Vector(jax.tree.unflatten(a_outer_structure, out_leaves))
You're limited to Does JAX have a way of accessing a leaf/subtree by its
Full code with tests: https://colab.research.google.com/drive/1AE_uTA0XuScd0TQdbN4v-Lzx8Dcrq8JX?usp=sharing
Update: I wrote a general version using a recursive breadth-first search to flatten the input VectorMixin based on a guide tree. It works great! Now I need to wrap it all in a
Update: I am ready to submit a pull request. The only thing left might be to have the product of |
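@deasmhumhna this looks very cool! Could you please share a docstring explaining the precise data model for `Matrix`/`MatrixMixin`? It would be nice to include a few simple examples (also to help me understand exactly what you've built), since it is a little more complex than `Vector`.
Right now, it looks like your implementation covers the "tree of trees" use case well, e.g., for representing Hessian matrices. The other use that is important to handle (possibly more common) is the "stack of trees" case, e.g., for representing a finite list of k vectors, such as the memory terms in L-BFGS <https://en.wikipedia.org/wiki/Limited-memory_BFGS> or a multi-step ODE solver <https://en.wikipedia.org/wiki/Linear_multistep_method>.
In principle I think the "stack of trees" should be a special case of your more general `Matrix`, with a trivial `outer_treedef` and `row_ndims=1`, but it would be good to verify this, and also perhaps worthwhile to add some helpers/constructors to ensure this works well. For example, it would be nice if `matrix[i, :]` returned a vector, and if there was a method to create a matrix of all zeros with k rows matching the structure of a given vector.
I agree that it would be a good idea for `matrix @ vector` to result in `Vector`. This would be a good idea to include in the initial pull request, because otherwise it will be a breaking change at some point in the future. |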
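To make the "stack of trees" request concrete, here is a minimal sketch in plain JAX of what the two helpers mentioned above could look like: a zeros constructor with k rows and a row accessor playing the role of `matrix[i, :]`. The function names `zeros_stack`, `get_row` and `set_row` are hypothetical and are not part of tree-math or of the proposed pull request; the sketch simply assumes each leaf carries a leading axis of length k.

```python
import jax
import jax.numpy as jnp


def zeros_stack(vector_tree, k):
    """Stack of k zero vectors: each leaf gains a leading axis of length k."""
    return jax.tree_util.tree_map(
        lambda leaf: jnp.zeros((k,) + jnp.shape(leaf), jnp.result_type(leaf)),
        vector_tree,
    )


def get_row(stack, i):
    """Analogue of matrix[i, :]: row i as a pytree shaped like a single vector."""
    return jax.tree_util.tree_map(lambda leaf: leaf[i], stack)


def set_row(stack, i, vector_tree):
    """Return a new stack in which row i is replaced by vector_tree."""
    return jax.tree_util.tree_map(
        lambda leaf, new: leaf.at[i].set(new), stack, vector_tree
    )


# Example: a memory buffer of k=4 vectors, as an L-BFGS-style solver might keep.
v = {"x": jnp.array([1.0, 2.0]), "y": jnp.array(3.0)}
memory = zeros_stack(v, k=4)
memory = set_row(memory, 2, v)
row = get_row(memory, 2)
```

Because every operation is a `tree_map` over leaves with a shared leading axis, the stack has the same tree structure as a single vector, which is what makes the "trivial outer structure" reading plausible.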
Stacking trees definitely works, both as a list or via batching dimensions
(`row_ndims`); there's an example at the end of the colab.
I'm currently focusing on getting the `MatrixMixin` class to work for
general custom classes. Essentially, I need a better way to deal with
pytree wrapping since custom nodes include all their children in a flat
list. So likely `outer_treedef` will become a list of all child outer
structures called `outer_treedefs`, which in the case of `Matrix` is just a
single pytree. I also need to figure out how transposition might work in
this case, since the outer structure is defined by the class fields while
the inner structure is arbitrary so the transposed structure cannot be held
within the same type of object. I think the easiest way is to have a R/C
flag (similar to numpy) and not bother with transposing the structures.
This would work for matvecs easily, and multiplication between matrices
with different row vs column representation would use either the row-column
dot product formulation or the outer product sum formulation of matrix
multiplication rather than the row/column product sum formulation I've
already written.
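For readers less familiar with the formulations named above, the following sketch shows three equivalent ways of computing a matrix product on ordinary dense arrays. It says nothing about how the pytree version is implemented; in particular, the third variant is only one guess at what "row/column product sum" refers to.

```python
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)
B = jnp.arange(12.0).reshape(3, 4)

# Row-column dot products: C[i, j] = <A[i, :], B[:, j]>.
C_dot = jnp.stack(
    [jnp.stack([jnp.dot(A[i], B[:, j]) for j in range(4)]) for i in range(2)]
)

# Sum of outer products: C = sum_k outer(A[:, k], B[k, :]).
C_outer = sum(jnp.outer(A[:, k], B[k]) for k in range(3))

# Row-wise product sum (assumed reading): C[i, :] = sum_k A[i, k] * B[k, :].
C_rows = jnp.stack([(A[i][:, None] * B).sum(axis=0) for i in range(2)])
```

All three agree with `A @ B`; they differ only in which operand is traversed by rows and which by columns, which is why a row/column flag can select between them without physically transposing the stored structure.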
I think a lot of the functionality you mentioned is possible by subtyping
the general matrix class, but it just depends on whether the goal of
`tree_math` is convenience or flexibility. General indexing is definitely
possible, but will be more complicated than the "flatten, index, unflatten"
process that will work for batched pytrees.
I think the best thing would be to return a `VectorMixin` of the same type
as the input vector, which will work for `Vector` but likely fail for
custom classes unless the matrix object has the same fields as the vector
object.
I'll work on this again this weekend and start working on docstrings too.
|
Do you have use-cases that come to mind for using MatrixMixin with general
custom classes? There are clear use-cases for VectorMixin (e.g., a custom
state vector for a particular model) but I'm less sure for matrices. If we
don't need it, then I wouldn't worry about it.
|
I can't think of any application the simpler implementation wouldn't handle. However, I'm sure someone else might. Regardless, I got the custom classes working using implicit transposition via a |
I'd like to add a `Matrix` class to complement `Vector`.
A key design question is what this needs to support. In particular: do we need to support multiple axes that correspond to flattened pytrees, or is only a single axis enough?
If we only need to support a single "tree axis", then most `Matrix` operations can be implemented essentially by calling `vmap` on a `Vector`, and the implementation only needs to keep track of whether the "tree axis" on the underlying pytree is at the start or the end. This would suffice for use-cases like implementing L-BFGS or GMRES, which keep track of some fixed number of state vectors in the form of a matrix.
In contrast, multiple "tree axes" would be required to fully support use cases where both the inputs and outputs of a linear map correspond to (possibly different) pytrees. For example, consider the outputs of `jax.jacobian` on a pytree -> pytree function. Here the implementation would need to be more complex to keep track of the separate tree definitions for inputs/outputs, similar to my first attempt at implementing a tree vectorizing transformation: jax-ml/jax#3263.
My inclination is to only implement the "single tree-axis" version of `Matrix`, with the reasoning being that it suffices to implement most "efficient" numerical algorithms on large-scale inputs, which cannot afford to use O(n^2) memory. On the other hand, it does preclude the interesting use-case of using tree-math to implement `jax.jacobian` (and variations).
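The "single tree axis" idea can be sketched in plain JAX without any new class. The example below assumes the non-tree axis is stored as a leading dimension of length k on every leaf; the helper names `tree_dot`, `stack_matvec` and `stack_rmatvec` are hypothetical and only illustrate how `vmap` over a vector-level operation yields matrix-level operations.

```python
import jax
import jax.numpy as jnp


def tree_dot(a, b):
    """Inner product of two pytrees with matching structure."""
    products = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(jax.tree_util.tree_leaves(products))


def stack_matvec(stack, vector):
    """(k, tree) @ (tree,) -> (k,): vmap the dot product over the leading axis."""
    return jax.vmap(tree_dot, in_axes=(0, None))(stack, vector)


def stack_rmatvec(stack, coeffs):
    """(tree, k) @ (k,) -> (tree,): linear combination of the k stacked vectors."""
    return jax.tree_util.tree_map(
        lambda leaf: jnp.tensordot(coeffs, leaf, axes=1), stack
    )


v = {"x": jnp.array([1.0, 2.0]), "y": jnp.array(3.0)}
# Two stacked state vectors: v and 2 * v.
stack = jax.tree_util.tree_map(lambda leaf: jnp.stack([leaf, 2 * leaf]), v)

inner_products = stack_matvec(stack, v)                    # shape (2,)
combination = stack_rmatvec(stack, jnp.array([1.0, 1.0]))  # same structure as v
```

These two products are the kind of operations that L-BFGS and GMRES need on their stored state vectors, and neither requires tracking a second tree definition.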