Skip to content

Commit

Permalink
make sure dlpack inits correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 26, 2021
1 parent e9bcd4f commit ce785c7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions jax2torch/jax2torch.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9

import torch
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack

import jax
import jax.numpy as jnp
from jax.tree_util import tree_map

def j2t(x_jax):
x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))
x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))
return x_torch

def t2j(x_torch):
x_torch = x_torch.contiguous()
x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))
x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
return x_jax

def tree_t2j(x_torch):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'jax2torch',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'Jax 2 Torch',
author = 'Phil Wang',
Expand Down

0 comments on commit ce785c7

Please sign in to comment.