Skip to content

Commit

Permalink
Support of take from pytorch (#1627)
Browse files Browse the repository at this point in the history
* take_dev

* assertion_err

* Keep test EngineAgnostic

* testTake skipping windows is removed

* combine get(NDArray(BOOLEAN) and NDArray(INT))

* add test

* add_test

* comb_get

* bug

* DataType check
  • Loading branch information
KexinFeng authored May 12, 2022
1 parent 2e89a74 commit c9d5412
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 12 deletions.
32 changes: 22 additions & 10 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,20 @@ default NDArray get(NDIndex index) {
return getNDArrayInternal().getIndexer().get(this, index);
}

/**
* Returns a partial {@code NDArray}.
*
* @param index the boolean or int {@code NDArray} that indicates what to get
* @return the partial {@code NDArray}
*/
default NDArray get(NDArray index) {
if (index.getDataType() == DataType.BOOLEAN) {
return get(new NDIndex().addBooleanIndex(index));
} else {
return take(index);
}
}

/**
* Returns a partial {@code NDArray}.
*
Expand All @@ -535,16 +549,6 @@ default NDArray get(long... indices) {
return get(new NDIndex(indices));
}

/**
* Returns a partial {@code NDArray}.
*
* @param index the boolean {@code NDArray} that indicates what to get
* @return the partial {@code NDArray}
*/
default NDArray get(NDArray index) {
return get(new NDIndex().addBooleanIndex(index));
}

/**
* Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray
* idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i,
Expand All @@ -556,6 +560,14 @@ default NDArray get(NDArray index) {
*/
NDArray gather(NDArray index, int axis);

/**
* Returns a partial {@code NDArray} pointed by the indexed array, according to linear indexing.
*
* @param index picks the elements of an NDArray to the same position as index
* @return the partial {@code NDArray} of the same shape as index
*/
NDArray take(NDArray index);

/**
* Returns a scalar {@code NDArray} corresponding to a single element.
*
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public void copyTo(NDArray ndArray) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@ public NDArray gather(NDArray index, int axis) {
return JniUtils.gather(this, (PtNDArray) index, axis);
}

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
if (!(index instanceof PtNDArray)) {
throw new IllegalArgumentException("Only PtNDArray is supported.");
}
return JniUtils.take(this, (PtNDArray) index);
}

/** {@inheritDoc} */
@Override
public void copyTo(NDArray array) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,23 @@ public static void set(PtNDArray self, ByteBuffer data) {
}

public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) {
if (index.getDataType() != DataType.INT64) {
index = index.toType(DataType.INT64, true);
}
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false));
}

public static PtNDArray take(PtNDArray ndArray, PtNDArray index) {
if (index.getDataType() != DataType.INT64) {
index = index.toType(DataType.INT64, true);
}
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle()));
}

public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
Shape indexShape = index.getShape();
Shape ndShape = ndArray.getShape();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ native void torchIndexPut(

native long torchGather(long handle, long index, long dim, boolean sparseGrad);

native long torchTake(long handle, long index);

native long torchMaskedSelect(long handle, long maskHandle);

native void torchMaskedPut(long handle, long valueHandle, long maskHandle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchGather(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTake(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jindex_handle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* index_ptr = reinterpret_cast<torch::Tensor*>(jindex_handle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->take(*index_ptr));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMaskedSelect(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jmasked_handle) {
API_BEGIN()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public void attach(NDManager manager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,24 @@ public void testPick() {
public void testGather() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(20f).reshape(-1, 4);
long[] idx = {0, 0, 2, 1, 1, 2};
NDArray index = manager.create(idx, new Shape(3, 2));
NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2));
NDArray actual = arr.gather(index, 1);
NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2));
Assert.assertEquals(actual, expected);
}
}

@Test
public void testTake() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(1, 7f).reshape(-1, 3);
NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
NDArray actual = arr.take(index);
NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));
Assert.assertEquals(actual, expected);
}
}

@Test
public void testGet() {
try (NDManager manager = NDManager.newBaseManager()) {
Expand Down Expand Up @@ -107,6 +117,13 @@ public void testGet() {
NDArray bool = manager.create(new boolean[] {true, false});
expected = manager.arange(5).reshape(1, 5);
Assert.assertEquals(original.get(bool), expected);

// get from int array
original = manager.arange(1, 7f).reshape(-1, 3);
NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
NDArray actual = original.get(index);
expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));
Assert.assertEquals(actual, expected);
}
}

Expand Down

0 comments on commit c9d5412

Please sign in to comment.