Support of take from pytorch #1627

Merged: 10 commits, May 12, 2022
32 changes: 22 additions & 10 deletions api/src/main/java/ai/djl/ndarray/NDArray.java
@@ -511,6 +511,20 @@ default NDArray get(NDIndex index) {
        return getNDArrayInternal().getIndexer().get(this, index);
    }

    /**
     * Returns a partial {@code NDArray}.
     *
     * @param index the boolean mask or integer index {@code NDArray} that indicates which
     *     elements to get
     * @return the partial {@code NDArray}
     */
    default NDArray get(NDArray index) {
        if (index.getDataType() == DataType.BOOLEAN) {
            return get(new NDIndex().addBooleanIndex(index));
        } else {
            return take(index);
        }
    }
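With this overload the index dtype selects the behavior: a boolean NDArray acts as a mask, while any integer NDArray is routed to the new take. A minimal usage sketch, not part of this diff (the class name is illustrative only), assuming the default engine implements take (currently the PyTorch engine):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class GetByNDArrayExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray arr = manager.arange(1, 7f).reshape(-1, 3); // [[1, 2, 3], [4, 5, 6]]

            // Boolean index: mask along axis 0 -> shape (1, 3), values [[1, 2, 3]]
            NDArray mask = manager.create(new boolean[] {true, false});
            System.out.println(arr.get(mask));

            // Integer index: linear indexing into the flattened array -> [1, 5]
            NDArray idx = manager.create(new long[] {0, 4});
            System.out.println(arr.get(idx));
        }
    }
}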

    /**
     * Returns a partial {@code NDArray}.
     *
@@ -535,16 +549,6 @@ default NDArray get(long... indices) {
        return get(new NDIndex(indices));
    }

    /**
     * Returns a partial {@code NDArray}.
     *
     * @param index the boolean {@code NDArray} that indicates what to get
     * @return the partial {@code NDArray}
     */
    default NDArray get(NDArray index) {
        return get(new NDIndex().addBooleanIndex(index));
    }

    /**
     * Returns a partial {@code NDArray} pointed by the indexed array. Given NDArray arr, NDArray
     * idx, and long axis, the output is out_{ijk} = arr_{idx_{ijk}, j, k} if axis=0 or arr_{i,
@@ -556,6 +560,14 @@ default NDArray get(NDArray index) {
     */
    NDArray gather(NDArray index, int axis);
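For contrast with the new take below, gather indexes along a single axis: with axis=1 the output is out_{ij} = arr_{i, idx_{ij}}, so the index supplies per-row column positions and the result has the index's shape. A small sketch, not part of this diff (the class name is illustrative only), mirroring the gather case in the integration test further down:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class GatherExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray arr = manager.arange(20f).reshape(-1, 4); // 5 x 4
            NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2));
            // Row 0 picks columns 0, 0 -> 0, 0; row 1 picks 2, 1 -> 6, 5; row 2 picks 1, 2 -> 9, 10
            System.out.println(arr.gather(index, 1)); // [[0, 0], [6, 5], [9, 10]]
        }
    }
}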

    /**
     * Returns a partial {@code NDArray} selected by the index array, using linear (flattened)
     * indexing.
     *
     * @param index the linear indices of the elements to pick from the flattened {@code NDArray}
     * @return the partial {@code NDArray} with the same shape as {@code index}
     */
    NDArray take(NDArray index);
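take treats the array as flattened in row-major order and picks elements at the given linear positions, so the result takes the shape of the index. A minimal sketch, not part of this diff (the class name is illustrative only), assuming an engine that implements take (currently PyTorch):

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class TakeExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray arr = manager.arange(1, 7f).reshape(-1, 3); // flattened: [1, 2, 3, 4, 5, 6]
            NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
            // Linear positions 0, 4, 1, 2 -> values 1, 5, 2, 3, returned in the index shape (2, 2)
            System.out.println(arr.take(index)); // [[1, 5], [2, 3]]
        }
    }
}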

    /**
     * Returns a scalar {@code NDArray} corresponding to a single element.
     *
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
@@ -182,6 +182,12 @@ public NDArray gather(NDArray index, int axis) {
        throw new UnsupportedOperationException(UNSUPPORTED_MSG);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray take(NDArray index) {
        throw new UnsupportedOperationException(UNSUPPORTED_MSG);
    }

    /** {@inheritDoc} */
    @Override
    public void set(Buffer data) {
@@ -312,6 +312,12 @@ public NDArray gather(NDArray index, int axis) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray take(NDArray index) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** {@inheritDoc} */
    @Override
    public void copyTo(NDArray ndArray) {
@@ -248,6 +248,15 @@ public NDArray gather(NDArray index, int axis) {
        return JniUtils.gather(this, (PtNDArray) index, axis);
    }

    /** {@inheritDoc} */
    @Override
    public NDArray take(NDArray index) {
        if (!(index instanceof PtNDArray)) {
            throw new IllegalArgumentException("Only PtNDArray is supported.");
        }
        return JniUtils.take(this, (PtNDArray) index);
    }

    /** {@inheritDoc} */
    @Override
    public void copyTo(NDArray array) {
@@ -347,11 +347,23 @@ public static void set(PtNDArray self, ByteBuffer data) {
    }

    public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim) {
        if (index.getDataType() != DataType.INT64) {
            index = index.toType(DataType.INT64, true);
        }
        return new PtNDArray(
                ndArray.getManager(),
                PyTorchLibrary.LIB.torchGather(ndArray.getHandle(), index.getHandle(), dim, false));
    }

    public static PtNDArray take(PtNDArray ndArray, PtNDArray index) {
        // torch.take requires a LongTensor index, so convert any other integer type first
        if (index.getDataType() != DataType.INT64) {
            index = index.toType(DataType.INT64, true);
        }
        return new PtNDArray(
                ndArray.getManager(),
                PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle()));
    }
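Because the index is normalized to INT64 here before crossing JNI, callers do not need to build a long index themselves; any integer NDArray works. A small usage sketch, not part of this diff (the class name is illustrative only), assuming the PyTorch engine is the active engine:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;

public class TakeIndexTypeExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray arr = manager.arange(10f);
            // An INT32 index is accepted: it is converted to INT64 before torchTake is called
            NDArray index = manager.create(new int[] {7, 3, 0});
            System.out.println(arr.take(index)); // [7., 3., 0.]
        }
    }
}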

    public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim) {
        Shape indexShape = index.getShape();
        Shape ndShape = ndArray.getShape();
@@ -198,6 +198,8 @@ native void torchIndexPut(

    native long torchGather(long handle, long index, long dim, boolean sparseGrad);

    native long torchTake(long handle, long index);

    native long torchMaskedSelect(long handle, long maskHandle);

    native void torchMaskedPut(long handle, long valueHandle, long maskHandle);
@@ -176,6 +176,16 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchGather(
  API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchTake(
    JNIEnv* env, jobject jthis, jlong jhandle, jlong jindex_handle) {
  API_BEGIN()
  const auto* tensor_ptr = reinterpret_cast<torch::Tensor*>(jhandle);
  const auto* index_ptr = reinterpret_cast<torch::Tensor*>(jindex_handle);
  // torch::Tensor::take flattens the tensor and picks elements at the given linear indices.
  // The result is heap-allocated; its handle is owned and later freed by the Java PtNDArray.
  const auto* result_ptr = new torch::Tensor(tensor_ptr->take(*index_ptr));
  return reinterpret_cast<uintptr_t>(result_ptr);
  API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchMaskedSelect(
    JNIEnv* env, jobject jthis, jlong jhandle, jlong jmasked_handle) {
  API_BEGIN()
@@ -212,6 +212,12 @@ public NDArray gather(NDArray index, int axis) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** {@inheritDoc} */
    @Override
    public NDArray take(NDArray index) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** {@inheritDoc} */
    @Override
    public void attach(NDManager manager) {
@@ -16,7 +16,6 @@
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;
import ai.djl.testing.TestRequirements;
import org.testng.Assert;
import org.testng.annotations.Test;

@@ -55,20 +54,26 @@ public void testPick() {

    @Test
    public void testGather() {
        // Currently in windows gradle cannot find all the engines.
        // In the dependencies, changing runtimeOnly to api however will remedy the problem.
        // TODO: remove this when gradle problem is fixed.
        TestRequirements.notWindows();
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray arr = manager.arange(20f).reshape(-1, 4);
            long[] idx = {0, 0, 2, 1, 1, 2};
            NDArray index = manager.create(idx, new Shape(3, 2));
            NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2));
            NDArray actual = arr.gather(index, 1);
            NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2));
            Assert.assertEquals(actual, expected);
        }
    }

    @Test
    public void testTake() {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray arr = manager.arange(1, 7f).reshape(-1, 3);
            // arr flattens to {1, 2, 3, 4, 5, 6}; linear indices 0, 4, 1, 2 pick {1, 5, 2, 3}
            NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
            NDArray actual = arr.take(index);
            NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));
            Assert.assertEquals(actual, expected);
        }
    }

    @Test
    public void testGet() {
        try (NDManager manager = NDManager.newBaseManager()) {
@@ -112,6 +117,13 @@ public void testGet() {
            NDArray bool = manager.create(new boolean[] {true, false});
            expected = manager.arange(5).reshape(1, 5);
            Assert.assertEquals(original.get(bool), expected);

            // get from an integer index array (routed to take, i.e. linear indexing)
            original = manager.arange(1, 7f).reshape(-1, 3);
            NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
            NDArray actual = original.get(index);
            expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));
            Assert.assertEquals(actual, expected);
        }
    }
