Skip to content

Commit

Permalink
take care of functions that output multiple tensors (say auxiliary lo…
Browse files Browse the repository at this point in the history
…sses)
  • Loading branch information
lucidrains committed Oct 26, 2021
1 parent c61d8d7 commit fdf94c1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions jax2torch/jax2torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,12 @@ def forward(ctx, *args):
return tree_j2t(y_)

@staticmethod
def backward(ctx, grad_y):
grads, *_ = ctx.fun_vjp(t2j(grad_y))
def backward(ctx, *grad_args):
if len(grad_args) > 1:
grad_args = tree_t2j(grad_args)
else:
grad_args = t2j(grad_args[0])
grads, *_ = ctx.fun_vjp(grad_args)
ret = tree_j2t(grads), *((None,) * (ctx.num_args - 1))
return ret

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'jax2torch',
packages = find_packages(exclude=[]),
version = '0.0.5',
version = '0.0.6',
license='MIT',
description = 'Jax 2 Torch',
author = 'Phil Wang',
Expand Down

0 comments on commit fdf94c1

Please sign in to comment.