Skip to content

Commit

Permalink
take_dev
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed May 6, 2022
1 parent 46d3479 commit dc1b7a8
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 2 deletions.
8 changes: 8 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,14 @@ default NDArray get(NDArray index) {
*/
NDArray gather(NDArray index, int axis);

/**
* Returns a partial {@code NDArray} pointed by the indexed array, according to linear indexing.
*
* @param index picks the elements of an NDArray to the same position as index
* @return the partial {@code NDArray} of the same shape as index
*/
NDArray take(NDArray index);

/**
* Returns a scalar {@code NDArray} corresponding to a single element.
*
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
throw new UnsupportedOperationException(UNSUPPORTED_MSG);
}

/** {@inheritDoc} */
@Override
public void set(Buffer data) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public void copyTo(NDArray ndArray) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,15 @@ public NDArray gather(NDArray index, int axis) {
return JniUtils.gather(this, (PtNDArray) index, axis);
}

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
if (!(index instanceof PtNDArray)) {
throw new IllegalArgumentException("Only PtNDArray is supported.");
}
return JniUtils.take(this, (PtNDArray) index);
}

/** {@inheritDoc} */
@Override
public void copyTo(NDArray array) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,12 @@ public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) {
PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false));
}

public static PtNDArray take(PtNDArray ndArray, PtNDArray index) {
return new PtNDArray(
ndArray.getManager(),
PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle()));
}

public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
Shape indexShape = index.getShape();
Shape ndShape = ndArray.getShape();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ native void torchIndexPut(

native long torchGather(long handle, long index, long dim, boolean sparseGrad);

native long torchTake(long handle, long index);

native long torchMaskedSelect(long handle, long maskHandle);

native void torchMaskedPut(long handle, long valueHandle, long maskHandle);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchGather(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTake(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jindex_handle) {
API_BEGIN()
const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
const auto* index_ptr = reinterpret_cast<torch::Tensor*>(jindex_handle);
const auto* result_ptr = new torch::Tensor(tensor_ptr->take(*index_ptr));
return reinterpret_cast<uintptr_t>(result_ptr);
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMaskedSelect(
JNIEnv* env, jobject jthis, jlong jhandle, jlong jmasked_handle) {
API_BEGIN()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,12 @@ public NDArray gather(NDArray index, int axis) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public void attach(NDManager manager) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.integration.tests.ndarray;

import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
Expand Down Expand Up @@ -61,14 +62,25 @@ public void testGather() {
TestRequirements.notWindows();
try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(20f).reshape(-1, 4);
long[] idx = {0, 0, 2, 1, 1, 2};
NDArray index = manager.create(idx, new Shape(3, 2));
NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2));
NDArray actual = arr.gather(index, 1);
NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2));
Assert.assertEquals(actual, expected);
}
}

@Test
public void testTake() {
Engine engine = Engine.getEngine("PyTorch");
try (NDManager manager = engine.newBaseManager()) {
NDArray arr = manager.arange(6f).reshape(-1, 3);
NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
NDArray actual = arr.take(index);
NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));
Assert.assertEquals(actual, expected);
}
}

@Test
public void testGet() {
try (NDManager manager = NDManager.newBaseManager()) {
Expand Down

0 comments on commit dc1b7a8

Please sign in to comment.