Matrix support #6

Open
shoyer opened this issue Dec 29, 2021 · 11 comments

@shoyer
Member

shoyer commented Dec 29, 2021

I'd like to add a Matrix class to complement Vector.

A key design question is what this needs to support. In particular: do we need to support multiple axes that correspond to flattened pytrees, or is only a single axis enough?

If we only need to support a single "tree axis", then most Matrix operations can be implemented essentially by calling vmap on a Vector, and the implementation only needs to keep track of whether the "tree axis" on the underlying pytree is at the start or the end. This would suffice for use-cases like implementing L-BFGS or GMRES, which keep track of some fixed number of state vectors in the form of a matrix.
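
A minimal sketch of the single-tree-axis idea, written against plain JAX rather than tree-math: the "matrix" is a pytree whose leaves carry an extra leading axis of length k, and both matrix-vector products reduce to `vmap` or a per-leaf contraction. The helper names (`tree_dot`, `stack_matvec`, `stack_rmatvec`) are hypothetical and not part of any library.

```python
import jax
import jax.numpy as jnp


def tree_dot(a, b):
    """Inner product of two pytrees with identical structure."""
    products = jax.tree_util.tree_map(lambda x, y: jnp.sum(x * y), a, b)
    return sum(jax.tree_util.tree_leaves(products))


def stack_matvec(stacked, v):
    """(k, n) @ (n,) -> (k,): dot every stacked row with the pytree v."""
    return jax.vmap(tree_dot, in_axes=(0, None))(stacked, v)


def stack_rmatvec(stacked, coeffs):
    """(n, k) @ (k,) -> (n,): linear combination of the k stacked rows."""
    return jax.tree_util.tree_map(
        lambda leaf: jnp.tensordot(coeffs, leaf, axes=1), stacked
    )


v = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
# Three rows: v, 2 * v and 3 * v, stacked along a new leading axis.
stacked = jax.tree_util.tree_map(lambda x: jnp.stack([x, 2 * x, 3 * x]), v)

out = stack_matvec(stacked, v)  # array of shape (3,)
combo = stack_rmatvec(stacked, jnp.array([1.0, 0.0, 1.0]))  # pytree like v
```

Here the tree axis is at the end of the "matrix" for `stack_matvec` and at the start for `stack_rmatvec`, which is the only bookkeeping this version needs.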

In contrast, multiple "tree axes" would be required to fully support use cases where both the inputs and outputs of a linear map correspond to (possibly different) pytrees. For example, consider the outputs of jax.jacobian on a pytree -> pytree function. Here the implementation would need to be more complex to keep track of the separate tree definitions for inputs/outputs, similar to my first attempt at implementing a tree vectorizing transformation: jax-ml/jax#3263.
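
To make the two-tree-axes case concrete, here is a small sketch of what jax.jacobian returns for a pytree -> pytree function. The function `f` and its keys are made up for illustration; the point is only that the result is an output tree whose leaves are input trees, with each array shaped `output_shape + input_shape`.

```python
import jax
import jax.numpy as jnp


def f(params):
    # Hypothetical pytree -> pytree function.
    return {"u": params["a"] * params["b"], "v": jnp.sum(params["a"])}


x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

jac = jax.jacobian(f)(x)

# Outer keys come from the output tree, inner keys from the input tree.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, jac)
```

A Matrix with multiple tree axes would have to track both tree definitions (here `{"u", "v"}` for rows and `{"a", "b"}` for columns) as well as which array axes belong to which.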

My inclination is to only implement the "single tree-axis" version of matrix, with the reasoning being that it suffices to implement most "efficient" numerical algorithms on large-scale inputs, which cannot afford to use O(n^2) memory. On the other hand, it does preclude the interesting use-case of using tree-math to implement jax.jacobian (and variations).

@shoyer
Member Author

shoyer commented Dec 30, 2021

@Sohl-Dickstein seems to be in the camp of wanting support for multiple tree axes, to compute things like Jacobians, Hessians and covariances/correlations between pytrees.

@shoyer shoyer mentioned this issue Dec 30, 2021
@patrick-kidger

Weighing in here at @shoyer's request --

I'd put myself in the multiple-tree-axes camp as well, I think. (I am saying that primarily as an end user rather than a developer for the feature, though... !)

This starts to open up notions of "labelled tree axes". For example, performing the pytree-axes equivalent of jax.vmap(jax.vmap(operator.mul, in_axes=(0, None)), in_axes=(None, 0)) (an outer product), in which two different BatchTraces interact.
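
For reference, a sketch of that nested vmap on ordinary arrays (not pytree axes), showing how the two batch axes end up in the result. The input values are arbitrary.

```python
import operator

import jax
import jax.numpy as jnp

# Inner vmap batches over x, outer vmap batches over y.
outer = jax.vmap(
    jax.vmap(operator.mul, in_axes=(0, None)),
    in_axes=(None, 0),
)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([10.0, 20.0])

result = outer(x, y)  # result[j, i] == x[i] * y[j]
```

The outermost vmap contributes the leading output axis, so the result matches `jnp.outer(y, x)` rather than `jnp.outer(x, y)`. A pytree version would need some way to label which tree structure each of those axes corresponds to.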

For a thorny reference problem in which this kind of stuff might get quite useful (tree-matrix vs tree-ndarray or otherwise), I'd suggest the Ito version of Milstein's method available here:

https://github.com/patrick-kidger/diffrax/blob/10b652e1d91518ac182e8d832ff309f7c199a9a0/diffrax/solver/milstein.py#L104

This is a pretty tricky implementation! It's very heavily annotated with comments describing the various tree-axes, normal-axes, and the way in which they interact.

@geoff-davis

+1 to the @Sohl-Dickstein use case. Some more detail of where this would be handy: I recently needed to invert a Hessian of a function that took a pytree as its argument. The headache I ran into was that when I used jax.jacfwd(jax.jacrev(f))(x) to compute the Hessian, I got it as a pytree of pytrees, which turned out to be pretty complicated to flatten. It would be nice to be able to either transform a pytree of pytrees to and from a matrix of floats or to be able to perform matrix operations directly on the pytree of pytrees.
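
One common workaround, sketched below, is to avoid the nested output altogether by flattening the argument with jax.flatten_util.ravel_pytree and differentiating with respect to the flat vector. This is an illustration under the assumption that a dense Hessian fits in memory, not necessarily what was done in the case described above; the function `f` is made up.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def f(p):
    # Hypothetical scalar function of a pytree argument.
    return jnp.sum(p["a"] ** 2) + p["a"][0] * p["b"] + p["b"] ** 2


x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

# flat_x is a 1-D array; unravel maps a flat array back to the pytree.
flat_x, unravel = ravel_pytree(x)


def f_flat(flat):
    return f(unravel(flat))


hess = jax.hessian(f_flat)(flat_x)  # dense (3, 3) matrix
grad = jax.grad(f_flat)(flat_x)     # flat (3,) gradient

# Solve in flat space, then map the result back to the original structure.
step = unravel(jnp.linalg.solve(hess, grad))
```

The cost is that the tree structure is lost while working with `hess`, which is exactly what a Matrix with two tree axes would preserve.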

@njwfish

njwfish commented May 18, 2023

Just wanted to chime in and say that I'd love this feature, and for my use cases (which are primarily about numerical solvers for non-convex problems) a single axis is all I'd need, though I'm sure I'd find uses for a multi-axis implementation if that does get developed.

@patrick-kidger

So it's not a documented feature, but Equinox actually has a tree-math like sublibrary built-in, which can be used to do this kind of multi-axis stuff.

To set the scene, here is how it is used just to broadcast vector operations together:

from equinox.internal import ω

vector1 = [0, 1, (2, 3)]
vector2 = [4, 5, (6, 7)]
summed = (ω(vector1) + ω(vector2)).ω
# Alternate notation; I prefer this when doing pure arithmetic:
summed = (vector1**ω + vector2**ω).ω
print(summed)  # [4, 6, (8, 10)]

But with a bit of thinking you can nest these to accomplish higher-order operations:

# matrix has shape (2, 3)
matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])])
# vector has shape (3,)
vector = ω([6, 7, 8])
# product (2, 3) @ (3,) -> (2,)     ("call" applies the specified function to every leaf of its pytree)
matvec = matrix.call(lambda row: sum((row * vector).ω))
# unwrap
matvec = matvec.ω
print(matvec)  # [23, 86]

The reason this works is that ω is not a PyTree. This means that matrix = ω([ω([0, 1, 2]), ω([3, 4, 5])]) doesn't have the outer ω looking inside the inner ωs. (I believe tree-math's Vector is a PyTree and that the same trick wouldn't work in this library, though.)
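
The underlying mechanism can be sketched without Equinox: JAX treats any object of an unregistered type as a single opaque leaf. The `Box` class below is a hypothetical stand-in for ω, used only to show that behaviour.

```python
import jax


class Box:
    """Hypothetical wrapper that is NOT registered as a pytree node."""

    def __init__(self, value):
        self.value = value


nested_lists = [[0, 1, 2], [3, 4, 5]]
boxed_rows = [Box([0, 1, 2]), Box([3, 4, 5])]

# Plain lists are pytree nodes, so JAX descends all the way to the ints.
n_plain = len(jax.tree_util.tree_leaves(nested_lists))

# Each Box is an opaque leaf, so the outer tree only sees two "rows".
n_boxed = len(jax.tree_util.tree_leaves(boxed_rows))

# tree_map therefore hands whole rows to the function, like matrix.call above.
row_sums = jax.tree_util.tree_map(lambda box: sum(box.value), boxed_rows)
```

This is also why such wrappers cannot cross jit/grad boundaries: an opaque leaf that is not an array is not a valid argument there.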

Conversely, this does mean that you mustn't pass ω objects across JIT/grad/etc. API boundaries. (Whilst you can with tree-math.) ω is only meant to be used as a convenient syntax within the bounds of a single function.

@shoyer
Member Author

shoyer commented May 19, 2023

I do still think matrix support would be awesome to have, and I actually had a use case for this just last week.

That said, at this point I'm relatively unlikely to work on it. If somebody else wants to give this a try, that would be very welcome!

@deasmhumhna

deasmhumhna commented Dec 27, 2024

@geoff-davis, your statement about "pytrees of pytrees" gave me the idea to formulate this snippet:

from functools import partial
from operator import add

import jax
import jax.numpy as jnp
import tree_math as tm

def inner(x: tm.Vector, y: tm.Vector, x_row_ndim) -> tm.Vector:

    def product_sum(x, y):
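        # Pad x with trailing singleton axes so it broadcasts against y, then contract the column axes.
        return (x.reshape(x.shape + (1,) * (y.ndim - x.ndim + x_row_ndim)) * y).sum(axis=tuple(range(x_row_ndim, x.ndim)))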

    def product_sum_map(x, y: tm.Vector):
        return tm.Vector(jax.tree.map(partial(product_sum, x), y.tree))
    
    z = jax.tree.map(product_sum_map, x.tree, y.tree, is_leaf=lambda x: isinstance(x, jnp.ndarray) or isinstance(x, tm.Vector))
    return jax.tree.reduce(add, z, is_leaf=lambda x: isinstance(x, tm.Vector))

def outer(a: tm.Vector, b: tm.Vector, a_row_ndim: tm.Vector) -> tm.Vector:
    in_leaves, a_outer_structure = jax.tree.flatten(a.tree, is_leaf=lambda x: isinstance(x, tm.Vector))
    row_ndims, _ = jax.tree.flatten(a_row_ndim.tree)
    out_leaves = []
    for in_leaf, row_ndim in zip(in_leaves, row_ndims):
        out_leaf = inner(in_leaf, b, row_ndim)
        out_leaves.append(out_leaf)
    return tm.Vector(jax.tree.unflatten(a_outer_structure, out_leaves))

You're limited to Vector/VectorMixin objects that only contain other kinds of non-Vector pytrees, which likely covers most needs, but more generality would be nice. Also, you need to explicitly specify the batching dimension (row_ndim), and I see no way around this. You could create a Matrix class instance from any qualifying VectorMixin of VectorMixins and a VectorMixin of batching dimensions with the same outer tree structure.

Does JAX have a way of accessing a leaf/subtree by its KeyPath? If so, you could convert general pytrees of pytrees to a Matrix as long as we specify both outer (row) and inner (column) tree structures, probably by specifying the row and column batching dimensions, for the sake of symmetry.
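
On the key-path question: recent JAX versions expose key paths through jax.tree_util.tree_flatten_with_path and jax.tree_util.keystr. The sketch below assumes a JAX version that includes those functions; `get_by_path` is a hypothetical helper written for this example, not a JAX API, and it only handles dict, sequence and attribute keys.

```python
import jax

tree = {"a": [1.0, 2.0], "b": (3.0,)}

# Each entry is (key_path, leaf); key_path is a tuple of key objects.
path_leaves, treedef = jax.tree_util.tree_flatten_with_path(tree)

# Index leaves by a printable form of their key path.
by_path = {jax.tree_util.keystr(path): leaf for path, leaf in path_leaves}


def get_by_path(node, path):
    """Hypothetical helper: walk a (possibly partial) key path to a subtree."""
    for key in path:
        if isinstance(key, jax.tree_util.DictKey):
            node = node[key.key]
        elif isinstance(key, jax.tree_util.SequenceKey):
            node = node[key.idx]
        elif isinstance(key, jax.tree_util.GetAttrKey):
            node = getattr(node, key.name)
        else:
            raise TypeError(f"unsupported key type: {type(key)}")
    return node


first_path = path_leaves[0][0]
leaf = get_by_path(tree, first_path)         # full path reaches a leaf
subtree = get_by_path(tree, first_path[:1])  # path prefix reaches a subtree
```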

Full code with tests: https://colab.research.google.com/drive/1AE_uTA0XuScd0TQdbN4v-Lzx8Dcrq8JX?usp=sharing

Update: I wrote a general version using a recursive breadth-first search to flatten the input VectorMixin based on a guide tree. It works great! Now I need to wrap it all in a MatrixMixin class that implements __matmul__. I think this should work for matvecs as well.

Update: I am ready to submit a pull request. The only thing left might be to have the product of MatrixMixin and a VectorMixin equal a VectorMixin, at least for wrapped objects. But since MatrixMixin inherits from VectorMixin it has all the functionality, plus access to a more general __matmul__ operator.

@shoyer
Member Author

shoyer commented Jan 3, 2025

@deasmhumhna this looks very cool!

Could you please share a docstring explaining the precise data model for Matrix/MatrixMixin? It would be nice to include a few simple examples (also to help me understand exactly what you've built), since it is a little more complex than Vector.

Right now, it looks like your implementation covers the "tree of trees" use case well, e.g., for representing Hessian matrices. The other use that is important to handle (possibly more common) is the "stack of trees" case, e.g., for representing a finite list of k vectors, such as the memory terms in L-BFGS or a multi-step ODE solver.

In principle I think the "stack of trees" should be a special case of your more general Matrix, with a trivial outer_treedef and row_ndims=1, but it would be good to verify this, and also perhaps worthwhile to add some helpers/constructors to ensure this works well. For example, it would be nice if matrix[i, :] returned a vector, and if there was a method to create a matrix of all zeros with k rows matching the structure of a given vector.
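
A rough sketch of what such helpers could look like for the "stack of trees" layout, using plain JAX pytrees with a leading row axis on every leaf. The names `zeros_stack`, `get_row` and `set_row` are hypothetical, and the sketch assumes every leaf is already a JAX array; it is not the implementation from the Colab.

```python
import jax
import jax.numpy as jnp


def zeros_stack(v, k):
    """All-zeros stack with k rows, each matching the structure of v."""
    return jax.tree_util.tree_map(
        lambda x: jnp.zeros((k,) + x.shape, x.dtype), v
    )


def get_row(stacked, i):
    """Analogue of matrix[i, :]: returns a pytree shaped like one row."""
    return jax.tree_util.tree_map(lambda x: x[i], stacked)


def set_row(stacked, i, v):
    """Functional update of row i, e.g. to store a new L-BFGS memory term."""
    return jax.tree_util.tree_map(lambda s, x: s.at[i].set(x), stacked, v)


v = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}

memory = zeros_stack(v, 4)
memory = set_row(memory, 1, v)
row = get_row(memory, 1)
```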

I agree that it would be a good idea for matrix @ vector to result in Vector. This would be a good idea to include in the initial pull request, because otherwise it will be a breaking change at some point in the future.

@deasmhumhna

deasmhumhna commented Jan 4, 2025 via email

@shoyer
Member Author

shoyer commented Jan 4, 2025 via email

@deasmhumhna

I can't think of any application the simpler implementation wouldn't handle. However, I'm sure someone else might. Regardless, I got the custom classes working using implicit transposition via a trans flag. I preserved the original method of structure transposition within the Matrix class and constrained trans to always be False. The custom class code feels somewhat inelegant and fragile compared to the specific case, and the way I'm forcing the result of MatrixMixin and VectorMixin products to be VectorMixin also seems prone to bugs, but it does work. You can see the updates in the Colab. I'll try to finish the pull request by sometime next month.
