Skip to content

Commit

Permalink
support multiple inputs that require grad
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 26, 2021
1 parent fdf94c1 commit 196f6fe
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## jax2torch

Use Jax functions in Pytorch with DLPack, as outlined <a href="https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9">in a gist</a> by <a href="https://github.com/mattjj">@mattjj</a>. Right now only supports one tensor input (with optional non-tensor input arguments) to one tensor output, for the purposes of <a href="https://github.com/spetti/SMURF">differentiable alignment</a>.
Use Jax functions in Pytorch with DLPack, as outlined <a href="https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9">in a gist</a> by <a href="https://github.com/mattjj">@mattjj</a>. The repository was made for the purposes of making the <a href="https://github.com/spetti/SMURF">differentiable alignment</a> work here interoperable with Pytorch.

## Install

Expand Down
5 changes: 3 additions & 2 deletions jax2torch/jax2torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ def backward(ctx, *grad_args):
grad_args = tree_t2j(grad_args)
else:
grad_args = t2j(grad_args[0])
grads, *_ = ctx.fun_vjp(grad_args)
ret = tree_j2t(grads), *((None,) * (ctx.num_args - 1))
grads = ctx.fun_vjp(grad_args)
grads = tuple(map(lambda t: t if isinstance(t, jnp.ndarray) else None, grads))
ret = tree_j2t(grads)
return ret

sig = signature(fn)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'jax2torch',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Jax 2 Torch',
author = 'Phil Wang',
Expand Down

0 comments on commit 196f6fe

Please sign in to comment.