From d35ee3f8c360e1fd9bceb7b826bf9175bd81d9b5 Mon Sep 17 00:00:00 2001 From: Kexin Date: Wed, 15 Jun 2022 22:07:50 -0700 Subject: [PATCH 01/15] mixed index getter on pytorch draft --- .../ai/djl/ndarray/index/NDArrayIndexer.java | 29 +---- .../java/ai/djl/ndarray/index/NDIndex.java | 17 ++- .../ai/djl/ndarray/index/dim/NDIndexNone.java | 23 ++++ .../ai/djl/ndarray/index/dim/NDIndexPick.java | 16 +-- .../ndarray/index/full/NDIndexFullPick.java | 2 +- .../ai/djl/mxnet/engine/MxNDArrayIndexer.java | 35 ++++++ .../djl/pytorch/engine/PtNDArrayIndexer.java | 7 ++ .../java/ai/djl/pytorch/jni/JniUtils.java | 51 +++++++++ .../ai/djl/pytorch/jni/PyTorchLibrary.java | 18 +++ ...i_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 105 +++++++++++++++++- .../tensorflow/engine/TfNDArrayIndexer.java | 7 ++ 11 files changed, 263 insertions(+), 47 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java diff --git a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java index 47994b22cda..d71a9b9acd7 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java @@ -20,7 +20,6 @@ import ai.djl.ndarray.index.full.NDIndexFullSlice; import java.util.List; -import java.util.Optional; /** A helper class for {@link NDArray} implementations for operations with an {@link NDIndex}. */ public abstract class NDArrayIndexer { @@ -50,33 +49,7 @@ public abstract class NDArrayIndexer { * @param index the index to get * @return the subarray */ - public NDArray get(NDArray array, NDIndex index) { - if (index.getRank() == 0 && array.getShape().isScalar()) { - return array.duplicate(); - } - - // use booleanMask for NDIndexBooleans case - List indices = index.getIndices(); - if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) { - if (indices.size() != 1) { - throw new IllegalArgumentException( - "get() currently doesn't support more that one boolean NDArray"); - } - return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex()); - } - - Optional fullPick = NDIndexFullPick.fromIndex(index, array.getShape()); - if (fullPick.isPresent()) { - return get(array, fullPick.get()); - } - - Optional fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape()); - if (fullSlice.isPresent()) { - return get(array, fullSlice.get()); - } - throw new UnsupportedOperationException( - "get() currently supports all, fixed, and slices indices"); - } + public abstract NDArray get(NDArray array, NDIndex index); /** * Sets the values of the array at the fullSlice with an array. diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index 5eef36e456a..f09f9564a1c 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -13,12 +13,7 @@ package ai.djl.ndarray.index; import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.index.dim.NDIndexAll; -import ai.djl.ndarray.index.dim.NDIndexBooleans; -import ai.djl.ndarray.index.dim.NDIndexElement; -import ai.djl.ndarray.index.dim.NDIndexFixed; -import ai.djl.ndarray.index.dim.NDIndexPick; -import ai.djl.ndarray.index.dim.NDIndexSlice; +import ai.djl.ndarray.index.dim.*; import ai.djl.ndarray.types.DataType; import java.util.ArrayList; @@ -50,7 +45,7 @@ public class NDIndex { /* Android regex requires escape } char as well */ private static final Pattern ITEM_PATTERN = Pattern.compile( - "(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})"); + "(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})|None"); private int rank; private List indices; @@ -193,7 +188,7 @@ public List getIndices() { * @return the updated {@link NDIndex} * @see #NDIndex(String, Object...) */ - public final NDIndex addIndices(String indices, Object... args) { + public final void addIndices(String indices, Object... args) { String[] indexItems = indices.split(","); rank += indexItems.length; int argIndex = 0; @@ -215,7 +210,6 @@ public final NDIndex addIndices(String indices, Object... args) { if (argIndex != args.length) { throw new IllegalArgumentException("Incorrect number of index arguments"); } - return this; } /** @@ -335,6 +329,11 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) { if (!m.matches()) { throw new IllegalArgumentException("Invalid argument index: " + indexItem); } + // None + if (indexItem.equals("None")) { + indices.add(new NDIndexNone()); + return argIndex; + } // "*" case String star = m.group(1); if (star != null) { diff --git a/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java new file mode 100644 index 00000000000..1cfb32dcda6 --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java @@ -0,0 +1,23 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.index.dim; + +/** An {@code NDIndexElement} to return all values in a particular dimension. */ +public class NDIndexNone implements NDIndexElement { + + /** {@inheritDoc} */ + @Override + public int getRank() { + return 1; + } +} diff --git a/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexPick.java b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexPick.java index 4786353dcea..ccd13b20ed7 100644 --- a/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexPick.java +++ b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexPick.java @@ -17,15 +17,15 @@ /** An {@link NDIndexElement} that gets elements by index in the specified axis. */ public class NDIndexPick implements NDIndexElement { - private NDArray indices; + private NDArray index; /** * Constructs a pick. * - * @param indices the indices to pick + * @param index the index to pick */ - public NDIndexPick(NDArray indices) { - this.indices = indices; + public NDIndexPick(NDArray index) { + this.index = index; } /** {@inheritDoc} */ @@ -35,11 +35,11 @@ public int getRank() { } /** - * Returns the indices to pick. + * Returns the index to pick. * - * @return the indices to pick + * @return the index to pick */ - public NDArray getIndices() { - return indices; + public NDArray getIndex() { + return index; } } diff --git a/api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java b/api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java index 0136e895de5..30d67bbba00 100644 --- a/api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java +++ b/api/src/main/java/ai/djl/ndarray/index/full/NDIndexFullPick.java @@ -53,7 +53,7 @@ public static Optional fromIndex(NDIndex index, Shape target) { axis++; } else if (el instanceof NDIndexPick) { if (fullPick == null) { - fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndices(), axis); + fullPick = new NDIndexFullPick(((NDIndexPick) el).getIndex(), axis); } else { // Don't support multiple picks throw new UnsupportedOperationException( diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java index 5ac0274844e..df7d404a745 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java @@ -15,10 +15,15 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDArrayIndexer; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.index.dim.NDIndexBooleans; +import ai.djl.ndarray.index.dim.NDIndexElement; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; import ai.djl.ndarray.types.Shape; +import java.util.List; +import java.util.Optional; import java.util.Stack; /** The {@link NDArrayIndexer} used by the {@link MxNDArray}. */ @@ -61,6 +66,36 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { return result; } + /** {@inheritDoc} */ + @Override + public NDArray get(NDArray array, NDIndex index) { + if (index.getRank() == 0 && array.getShape().isScalar()) { + return array.duplicate(); + } + + // use booleanMask for NDIndexBooleans case + List indices = index.getIndices(); + if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) { + if (indices.size() != 1) { + throw new IllegalArgumentException( + "get() currently doesn't support more that one boolean NDArray"); + } + return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex()); + } + + Optional fullPick = NDIndexFullPick.fromIndex(index, array.getShape()); + if (fullPick.isPresent()) { + return get(array, fullPick.get()); + } + + Optional fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape()); + if (fullSlice.isPresent()) { + return get(array, fullSlice.get()); + } + throw new UnsupportedOperationException( + "get() currently supports all, fixed, and slices indices in MXNet engine"); + } + /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index 9ffd96c127b..378c2d114ee 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; +import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; @@ -49,6 +50,12 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { } } + /** {@inheritDoc} */ + @Override + public NDArray get(NDArray array, NDIndex index) { + return JniUtils.indexAdv(manager.from(array), index); + } + /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index ebd5a998dec..433094b3660 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -14,6 +14,8 @@ import ai.djl.Device; import ai.djl.ndarray.NDList; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.index.dim.*; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; @@ -337,6 +339,54 @@ public static PtNDArray index( ndArray.getHandle(), minIndices, maxIndices, stepIndices)); } + public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) { + List indices = index.getIndices(); + long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); + ListIterator it = indices.listIterator(); + while (it.hasNext()) { + if (it.nextIndex() == index.getEllipsisIndex()) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); + } + + NDIndexElement elem = it.next(); + if (elem instanceof NDIndexNone) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false); + } else if (elem instanceof NDIndexSlice) { + Long min = ((NDIndexSlice) elem).getMin(); + Long max = ((NDIndexSlice) elem).getMax(); + Long step = ((NDIndexSlice) elem).getStep(); + int null_slice_bin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); + // null_slice_bin encodes whether (min, max) is null: + // is_null == 1, ! is_null == 0; + // 0b11 == 3, 0b10 = 2, ... + PyTorchLibrary.LIB.torchIndexAppendSlice( + torchIndexHandle, + min == null ? 0 : min, + max == null ? 0 : max, + step == null ? 1 : step, + null_slice_bin); + } else if (elem instanceof NDIndexAll) { + PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3); + } else if (elem instanceof NDIndexFixed) { + PyTorchLibrary.LIB.torchIndexAppendFixed( + torchIndexHandle, ((NDIndexFixed) elem).getIndex()); + } else if (elem instanceof NDIndexBooleans) { + PtNDArray index_arr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); + PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, index_arr.getHandle()); + } else if (elem instanceof NDIndexPick) { + PtNDArray index_arr = (PtNDArray) ((NDIndexPick) elem).getIndex(); + PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, index_arr.getHandle()); + } + } + if (indices.size() == index.getEllipsisIndex()) { + PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, true); + } + + return new PtNDArray( + ndArray.getManager(), + PyTorchLibrary.LIB.torchIndexReturn(ndArray.getHandle(), torchIndexHandle)); + } + public static void indexSet( PtNDArray ndArray, PtNDArray value, @@ -365,6 +415,7 @@ public static PtNDArray take(PtNDArray ndArray, PtNDArray index) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } + System.out.println(PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 7bee9f7abd1..0262cf6f2e1 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -12,9 +12,12 @@ */ package ai.djl.pytorch.jni; +import ai.djl.ndarray.index.dim.NDIndexElement; + import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; +import java.util.List; import java.util.Set; /** A class containing utilities to interact with the PyTorch Engine's JNI layer. */ @@ -185,6 +188,8 @@ native long torchFromBlob( native long torchIndex(long handle, long[] minIndices, long[] maxIndices, long[] stepIndices); + native long torchIndexAdv(long handle, List indices); + native void torchIndexPut( long handle, long valueHandle, @@ -600,4 +605,17 @@ native void sgdUpdate( native long torchNorm(long handle, int ord, long[] axis, boolean keepDims); native long torchNonZeros(long handle); + + native long torchIndexInit(int size); + + native long torchIndexReturn(long handle, long torchIndexHandle); + + native void torchIndexAppendNoneEllipsis(long torchIndexHandle, boolean is_ellipsis); + + native void torchIndexAppendSlice( + long torchIndexHandle, long min, long max, long step, int null_slice_bin); + + native void torchIndexAppendFixed(long torchIndexHandle, long idx); + + native void torchIndexAppendArray(long torchIndexHandle, long arrayHandle); } diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index ea48f69f199..d9148ccfc5d 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -116,11 +116,114 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndex(JNIEnv API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); auto indices = utils::CreateTensorIndex(env, jmin_indices, jmax_indices, jstep_indices); - const auto* result_ptr = new torch::Tensor(tensor_ptr->index(indices)); return reinterpret_cast(result_ptr); API_END_RETURN() } +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAdv(JNIEnv* env, jobject jthis, jlong jhandle, jobject index) { + API_BEGIN() + std::cout << "Reached here!" << std::endl; + const auto* tensor_ptr = reinterpret_cast(jhandle); + using namespace std; + + jclass java_util_ArrayList = static_cast(env->FindClass("java/util/ArrayList")); + cout << "temp check:" << java_util_ArrayList << endl; + jmethodID java_util_ArrayList_get = env->GetMethodID(java_util_ArrayList, "get", "(I)Ljava/lang/Object;"); + jmethodID java_util_ArrayList_size = env->GetMethodID(java_util_ArrayList, "size", "()I"); + cout << "Give me correct size() methodID:" << java_util_ArrayList_get << endl; + cout << "jclass: " << java_util_ArrayList << endl; + + jint size = 0; + cout << "size:" << size << endl; +// size = env->CallIntMethod(java_util_ArrayList, java_util_ArrayList_size); + jobject index_element = env->CallObjectMethod(java_util_ArrayList, java_util_ArrayList_get, 0); +// cout << "size" << size << endl; + cout << index_element << endl; + + cout << "Dev ends here!" << endl; + return reinterpret_cast(tensor_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexInit(JNIEnv* env, jobject jthis, jint jsize) { + API_BEGIN() + //const auto* tensor_ptr = reinterpret_cast(jhandle); + using namespace std; + using namespace torch::indexing; + + std::vector *index_ptr = new std::vector; + index_ptr->reserve(jsize); + + index_ptr->emplace_back(Slice(0, 2)); + cout << (*index_ptr)[0] << endl; + //cout << tensor_ptr->index((*index_ptr)[0]) << endl; + + return reinterpret_cast(index_ptr); + API_END_RETURN() +} + +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexReturn(JNIEnv* env, jobject jthis, + jlong jhandle, jlong jtorch_index_handle) { + API_BEGIN() + const auto* tensor_ptr = reinterpret_cast(jhandle); + using namespace std; + using namespace torch::indexing; + + std::cout << "IndexReturn:" << jtorch_index_handle << std::endl; + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + std::cout << "*index_ptr[0] " << sizeof(index_ptr->at(0)) << " " << index_ptr->at(0) << std::endl; + std::cout << "*index_ptr " << sizeof(*index_ptr) << " " << *index_ptr << std::endl; + cout << "output: " << tensor_ptr->index(index_ptr->at(0)) << endl; + torch::Tensor* ret_ptr = new torch::Tensor(tensor_ptr->index(index_ptr->at(0))); + return reinterpret_cast(ret_ptr); + API_END_RETURN() +} + +JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendNoneEllipsis(JNIEnv* env, jobject jthis, + jlong jtorch_index_handle, jboolean jis_ellipsis) { + API_BEGIN() + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + if (jis_ellipsis) { + index_ptr->emplace_back(torch::indexing::Ellipsis); + } else { + index_ptr->emplace_back(torch::indexing::None); + } + API_END() +} + +JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendSlice(JNIEnv* env, jobject jthis, + jlong jtorch_index_handle, jlong jmin, jlong jmax, jlong jstep, jint jnull_slice_bin) { + API_BEGIN() + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + if (jnull_slice_bin == 0) { + index_ptr->emplace_back(torch::indexing::Slice(jmin, jmax, jstep)); + } else if (jnull_slice_bin == 1) { + index_ptr->emplace_back(torch::indexing::Slice(jmin, torch::indexing::None, jstep)); + } else if (jnull_slice_bin == 2) { + index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, jmax, jstep)); + } else if (jnull_slice_bin == 3) { + index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, torch::indexing::None, jstep)); + } + API_END() +} + +JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendFixed(JNIEnv* env, jobject jthis, + jlong jtorch_index_handle, jlong jidx) { + API_BEGIN() + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + index_ptr->emplace_back((long) jidx); + API_END() +} + +JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendArray(JNIEnv* env, jobject jthis, + jlong jtorch_index_handle, jlong jarray) { +API_BEGIN() + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* array_ptr = reinterpret_cast(jarray); + index_ptr->emplace_back(*array_ptr); +API_END() +} + JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexPut(JNIEnv* env, jobject jthis, jlong jhandle, jlong jvalue_handle, jlongArray jmin_indices, jlongArray jmax_indices, jlongArray jstep_indices) { API_BEGIN() diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java index 73b3a4c5d54..3562a50dc7b 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java @@ -14,6 +14,7 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; +import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; @@ -56,6 +57,12 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { } } + /** {@inheritDoc} */ + @Override + public NDArray get(NDArray array, NDIndex index) { + throw new UnsupportedOperationException("Not implemented"); + } + /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { From 5902971e2f2cf1aa4e70807810c4fb155980fd17 Mon Sep 17 00:00:00 2001 From: Kexin Date: Thu, 16 Jun 2022 09:21:29 -0700 Subject: [PATCH 02/15] code cleaning --- .../ai/djl/mxnet/engine/MxNDArrayIndexer.java | 2 +- .../ai/djl/pytorch/jni/PyTorchLibrary.java | 2 -- ...i_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 25 ------------------- 3 files changed, 1 insertion(+), 28 deletions(-) diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java index df7d404a745..4465a9526ab 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java @@ -93,7 +93,7 @@ public NDArray get(NDArray array, NDIndex index) { return get(array, fullSlice.get()); } throw new UnsupportedOperationException( - "get() currently supports all, fixed, and slices indices in MXNet engine"); + "get() currently supports only all, fixed, and slices indices in MXNet engine"); } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 0262cf6f2e1..331c989781d 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -188,8 +188,6 @@ native long torchFromBlob( native long torchIndex(long handle, long[] minIndices, long[] maxIndices, long[] stepIndices); - native long torchIndexAdv(long handle, List indices); - native void torchIndexPut( long handle, long valueHandle, diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index d9148ccfc5d..1354607d284 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -120,31 +120,6 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndex(JNIEnv API_END_RETURN() } -JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAdv(JNIEnv* env, jobject jthis, jlong jhandle, jobject index) { - API_BEGIN() - std::cout << "Reached here!" << std::endl; - const auto* tensor_ptr = reinterpret_cast(jhandle); - using namespace std; - - jclass java_util_ArrayList = static_cast(env->FindClass("java/util/ArrayList")); - cout << "temp check:" << java_util_ArrayList << endl; - jmethodID java_util_ArrayList_get = env->GetMethodID(java_util_ArrayList, "get", "(I)Ljava/lang/Object;"); - jmethodID java_util_ArrayList_size = env->GetMethodID(java_util_ArrayList, "size", "()I"); - cout << "Give me correct size() methodID:" << java_util_ArrayList_get << endl; - cout << "jclass: " << java_util_ArrayList << endl; - - jint size = 0; - cout << "size:" << size << endl; -// size = env->CallIntMethod(java_util_ArrayList, java_util_ArrayList_size); - jobject index_element = env->CallObjectMethod(java_util_ArrayList, java_util_ArrayList_get, 0); -// cout << "size" << size << endl; - cout << index_element << endl; - - cout << "Dev ends here!" << endl; - return reinterpret_cast(tensor_ptr); - API_END_RETURN() -} - JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexInit(JNIEnv* env, jobject jthis, jint jsize) { API_BEGIN() //const auto* tensor_ptr = reinterpret_cast(jhandle); From cc36b48f6e65392745c307c7e08e655815181b54 Mon Sep 17 00:00:00 2001 From: Kexin Date: Thu, 16 Jun 2022 09:37:43 -0700 Subject: [PATCH 03/15] code cleaning 2 --- .../src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 331c989781d..6ba9b1838bf 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -12,12 +12,10 @@ */ package ai.djl.pytorch.jni; -import ai.djl.ndarray.index.dim.NDIndexElement; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; -import java.util.List; import java.util.Set; /** A class containing utilities to interact with the PyTorch Engine's JNI layer. */ From 478989c5c97fce9dd562d6253218721733f9efa6 Mon Sep 17 00:00:00 2001 From: Kexin Date: Thu, 16 Jun 2022 16:26:44 -0700 Subject: [PATCH 04/15] feed std::vector<> to tensor.index(ArrayRef<>) --- ...i_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 25 ++--- .../tests/ndarray/NDIndexTest.java | 91 ++++++++++--------- 2 files changed, 54 insertions(+), 62 deletions(-) diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index 1354607d284..8bae7bb772a 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -116,23 +116,15 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndex(JNIEnv API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); auto indices = utils::CreateTensorIndex(env, jmin_indices, jmax_indices, jstep_indices); + const auto* result_ptr = new torch::Tensor(tensor_ptr->index(indices)); return reinterpret_cast(result_ptr); API_END_RETURN() } JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexInit(JNIEnv* env, jobject jthis, jint jsize) { API_BEGIN() - //const auto* tensor_ptr = reinterpret_cast(jhandle); - using namespace std; - using namespace torch::indexing; - std::vector *index_ptr = new std::vector; index_ptr->reserve(jsize); - - index_ptr->emplace_back(Slice(0, 2)); - cout << (*index_ptr)[0] << endl; - //cout << tensor_ptr->index((*index_ptr)[0]) << endl; - return reinterpret_cast(index_ptr); API_END_RETURN() } @@ -141,15 +133,8 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexReturn( jlong jhandle, jlong jtorch_index_handle) { API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); - using namespace std; - using namespace torch::indexing; - - std::cout << "IndexReturn:" << jtorch_index_handle << std::endl; auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); - std::cout << "*index_ptr[0] " << sizeof(index_ptr->at(0)) << " " << index_ptr->at(0) << std::endl; - std::cout << "*index_ptr " << sizeof(*index_ptr) << " " << *index_ptr << std::endl; - cout << "output: " << tensor_ptr->index(index_ptr->at(0)) << endl; - torch::Tensor* ret_ptr = new torch::Tensor(tensor_ptr->index(index_ptr->at(0))); + torch::Tensor* ret_ptr = new torch::Tensor(tensor_ptr->index(*index_ptr)); return reinterpret_cast(ret_ptr); API_END_RETURN() } @@ -186,7 +171,11 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendFi jlong jtorch_index_handle, jlong jidx) { API_BEGIN() auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); - index_ptr->emplace_back((long) jidx); + index_ptr->emplace_back((int) jidx); + using namespace std; + cout << "DEBUG" << endl; + cout << index_ptr->at(0) << endl; + cout << index_ptr->size() << endl; API_END() } diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 6e4368760fe..e879293f54f 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -81,50 +81,53 @@ public void testGet() { NDArray original = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2)); Assert.assertEquals(original.get(new NDIndex()), original); - NDArray getAt = original.get(0); - NDArray expected = manager.create(new float[] {1f, 2f}); - Assert.assertEquals(getAt, expected); - - Assert.assertEquals(original.get("0,:"), expected); - Assert.assertEquals(original.get("0,*"), expected); - - NDArray getSlice = original.get("1:"); - expected = manager.create(new float[] {3f, 4f}, new Shape(1, 2)); - Assert.assertEquals(getSlice, expected); - - NDArray getStepSlice = original.get("1::2"); - Assert.assertEquals(getStepSlice, expected); - - original = manager.arange(120).reshape(2, 3, 4, 5); - NDArray getEllipsis = original.get("0,2, ... "); - expected = manager.arange(40, 60).reshape(4, 5); - Assert.assertEquals(getEllipsis, expected); - - getEllipsis = original.get("...,0:2,2"); - expected = - manager.create(new int[] {2, 7, 22, 27, 42, 47, 62, 67, 82, 87, 102, 107}) - .reshape(2, 3, 2); - Assert.assertEquals(getEllipsis, expected); - - getEllipsis = original.get("1,...,2,3:5:2"); - expected = manager.create(new int[] {73, 93, 113}).reshape(3, 1); - Assert.assertEquals(getEllipsis, expected); - - getEllipsis = original.get("..."); - Assert.assertEquals(getEllipsis, original); - - // get from boolean array - original = manager.arange(10).reshape(2, 5); - NDArray bool = manager.create(new boolean[] {true, false}); - expected = manager.arange(5).reshape(1, 5); - Assert.assertEquals(original.get(bool), expected); - - // get from int array - original = manager.arange(1, 7f).reshape(-1, 3); - NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); - NDArray actual = original.get(index); - expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); - Assert.assertEquals(actual, expected); + System.out.println("Engine is: " + manager.getEngine()); + + // NDArray getAt = original.get(0); + // NDArray expected = manager.create(new float[] {1f, 2f}); + // Assert.assertEquals(getAt, expected); + + // Assert.assertEquals(original.get("0,:"), expected); + // Assert.assertEquals(original.get("0,*"), expected); + // + // NDArray getSlice = original.get("1:"); + // expected = manager.create(new float[] {3f, 4f}, new Shape(1, 2)); + // Assert.assertEquals(getSlice, expected); + // + // NDArray getStepSlice = original.get("1::2"); + // Assert.assertEquals(getStepSlice, expected); + // + // original = manager.arange(120).reshape(2, 3, 4, 5); + // NDArray getEllipsis = original.get("0,2, ... "); + // expected = manager.arange(40, 60).reshape(4, 5); + // Assert.assertEquals(getEllipsis, expected); + // + // getEllipsis = original.get("...,0:2,2"); + // expected = + // manager.create(new int[] {2, 7, 22, 27, 42, 47, 62, 67, 82, 87, + // 102, 107}) + // .reshape(2, 3, 2); + // Assert.assertEquals(getEllipsis, expected); + // + // getEllipsis = original.get("1,...,2,3:5:2"); + // expected = manager.create(new int[] {73, 93, 113}).reshape(3, 1); + // Assert.assertEquals(getEllipsis, expected); + // + // getEllipsis = original.get("..."); + // Assert.assertEquals(getEllipsis, original); + // + // // get from boolean array + // original = manager.arange(10).reshape(2, 5); + // NDArray bool = manager.create(new boolean[] {true, false}); + // expected = manager.arange(5).reshape(1, 5); + // Assert.assertEquals(original.get(bool), expected); + // + // // get from int array + // original = manager.arange(1, 7f).reshape(-1, 3); + // NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + // NDArray actual = original.get(index); + // expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); + // Assert.assertEquals(actual, expected); } } From 39bd01262552b6a020bdc5d61d384bd491d1c798 Mon Sep 17 00:00:00 2001 From: Kexin Date: Thu, 16 Jun 2022 17:20:59 -0700 Subject: [PATCH 05/15] bug fixed --- .../java/ai/djl/ndarray/index/NDIndex.java | 11 ++- .../djl/pytorch/engine/PtNDArrayIndexer.java | 6 ++ ...i_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 4 - .../tests/ndarray/NDIndexTest.java | 91 +++++++++---------- 4 files changed, 58 insertions(+), 54 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index f09f9564a1c..d604d0a941d 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -13,7 +13,13 @@ package ai.djl.ndarray.index; import ai.djl.ndarray.NDArray; -import ai.djl.ndarray.index.dim.*; +import ai.djl.ndarray.index.dim.NDIndexAll; +import ai.djl.ndarray.index.dim.NDIndexBooleans; +import ai.djl.ndarray.index.dim.NDIndexElement; +import ai.djl.ndarray.index.dim.NDIndexFixed; +import ai.djl.ndarray.index.dim.NDIndexNone; +import ai.djl.ndarray.index.dim.NDIndexPick; +import ai.djl.ndarray.index.dim.NDIndexSlice; import ai.djl.ndarray.types.DataType; import java.util.ArrayList; @@ -185,7 +191,6 @@ public List getIndices() { * @param indices the indices to add similar to {@link #NDIndex(String, Object...)} * @param args arguments to replace the variable "{}" in the indices string. Can be an integer, * long, boolean {@link NDArray}, or integer {@link NDArray}. - * @return the updated {@link NDIndex} * @see #NDIndex(String, Object...) */ public final void addIndices(String indices, Object... args) { @@ -330,7 +335,7 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) { throw new IllegalArgumentException("Invalid argument index: " + indexItem); } // None - if (indexItem.equals("None")) { + if ("None".equals(indexItem)) { indices.add(new NDIndexNone()); return argIndex; } diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index 378c2d114ee..4f016d57559 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -53,6 +53,12 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { /** {@inheritDoc} */ @Override public NDArray get(NDArray array, NDIndex index) { + if (index.getRank() == 0) { + if (array.getShape().isScalar()) { + return array.duplicate(); + } + index.addAllDim(); + } return JniUtils.indexAdv(manager.from(array), index); } diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index 8bae7bb772a..2b56892a070 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -172,10 +172,6 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendFi API_BEGIN() auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); index_ptr->emplace_back((int) jidx); - using namespace std; - cout << "DEBUG" << endl; - cout << index_ptr->at(0) << endl; - cout << index_ptr->size() << endl; API_END() } diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index e879293f54f..6e4368760fe 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -81,53 +81,50 @@ public void testGet() { NDArray original = manager.create(new float[] {1f, 2f, 3f, 4f}, new Shape(2, 2)); Assert.assertEquals(original.get(new NDIndex()), original); - System.out.println("Engine is: " + manager.getEngine()); - - // NDArray getAt = original.get(0); - // NDArray expected = manager.create(new float[] {1f, 2f}); - // Assert.assertEquals(getAt, expected); - - // Assert.assertEquals(original.get("0,:"), expected); - // Assert.assertEquals(original.get("0,*"), expected); - // - // NDArray getSlice = original.get("1:"); - // expected = manager.create(new float[] {3f, 4f}, new Shape(1, 2)); - // Assert.assertEquals(getSlice, expected); - // - // NDArray getStepSlice = original.get("1::2"); - // Assert.assertEquals(getStepSlice, expected); - // - // original = manager.arange(120).reshape(2, 3, 4, 5); - // NDArray getEllipsis = original.get("0,2, ... "); - // expected = manager.arange(40, 60).reshape(4, 5); - // Assert.assertEquals(getEllipsis, expected); - // - // getEllipsis = original.get("...,0:2,2"); - // expected = - // manager.create(new int[] {2, 7, 22, 27, 42, 47, 62, 67, 82, 87, - // 102, 107}) - // .reshape(2, 3, 2); - // Assert.assertEquals(getEllipsis, expected); - // - // getEllipsis = original.get("1,...,2,3:5:2"); - // expected = manager.create(new int[] {73, 93, 113}).reshape(3, 1); - // Assert.assertEquals(getEllipsis, expected); - // - // getEllipsis = original.get("..."); - // Assert.assertEquals(getEllipsis, original); - // - // // get from boolean array - // original = manager.arange(10).reshape(2, 5); - // NDArray bool = manager.create(new boolean[] {true, false}); - // expected = manager.arange(5).reshape(1, 5); - // Assert.assertEquals(original.get(bool), expected); - // - // // get from int array - // original = manager.arange(1, 7f).reshape(-1, 3); - // NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); - // NDArray actual = original.get(index); - // expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); - // Assert.assertEquals(actual, expected); + NDArray getAt = original.get(0); + NDArray expected = manager.create(new float[] {1f, 2f}); + Assert.assertEquals(getAt, expected); + + Assert.assertEquals(original.get("0,:"), expected); + Assert.assertEquals(original.get("0,*"), expected); + + NDArray getSlice = original.get("1:"); + expected = manager.create(new float[] {3f, 4f}, new Shape(1, 2)); + Assert.assertEquals(getSlice, expected); + + NDArray getStepSlice = original.get("1::2"); + Assert.assertEquals(getStepSlice, expected); + + original = manager.arange(120).reshape(2, 3, 4, 5); + NDArray getEllipsis = original.get("0,2, ... "); + expected = manager.arange(40, 60).reshape(4, 5); + Assert.assertEquals(getEllipsis, expected); + + getEllipsis = original.get("...,0:2,2"); + expected = + manager.create(new int[] {2, 7, 22, 27, 42, 47, 62, 67, 82, 87, 102, 107}) + .reshape(2, 3, 2); + Assert.assertEquals(getEllipsis, expected); + + getEllipsis = original.get("1,...,2,3:5:2"); + expected = manager.create(new int[] {73, 93, 113}).reshape(3, 1); + Assert.assertEquals(getEllipsis, expected); + + getEllipsis = original.get("..."); + Assert.assertEquals(getEllipsis, original); + + // get from boolean array + original = manager.arange(10).reshape(2, 5); + NDArray bool = manager.create(new boolean[] {true, false}); + expected = manager.arange(5).reshape(1, 5); + Assert.assertEquals(original.get(bool), expected); + + // get from int array + original = manager.arange(1, 7f).reshape(-1, 3); + NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + NDArray actual = original.get(index); + expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); + Assert.assertEquals(actual, expected); } } From 10c465488dc84a3299efdaaebb02ab1e1b1af2f4 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 17 Jun 2022 10:17:31 -0700 Subject: [PATCH 06/15] Update api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java Co-authored-by: Frank Liu --- api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java index 1cfb32dcda6..9b958779220 100644 --- a/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java +++ b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java @@ -1,5 +1,5 @@ /* - * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance * with the License. A copy of the License is located at From 44023afafff2c41a15656e6a487fe0a0ecf0b0f7 Mon Sep 17 00:00:00 2001 From: Kexin Date: Fri, 17 Jun 2022 20:00:28 -0700 Subject: [PATCH 07/15] bug fixed --- api/src/main/java/ai/djl/ndarray/index/NDIndex.java | 4 +++- .../main/java/ai/djl/pytorch/engine/PtNDManager.java | 2 +- .../src/main/java/ai/djl/pytorch/jni/JniUtils.java | 10 +++++++++- .../main/java/ai/djl/pytorch/jni/PyTorchLibrary.java | 1 - 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index d604d0a941d..fb6e0ebbcf2 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -191,9 +191,10 @@ public List getIndices() { * @param indices the indices to add similar to {@link #NDIndex(String, Object...)} * @param args arguments to replace the variable "{}" in the indices string. Can be an integer, * long, boolean {@link NDArray}, or integer {@link NDArray}. + * @return the updated {@link NDIndex} * @see #NDIndex(String, Object...) */ - public final void addIndices(String indices, Object... args) { + public final NDIndex addIndices(String indices, Object... args) { String[] indexItems = indices.split(","); rank += indexItems.length; int argIndex = 0; @@ -215,6 +216,7 @@ public final void addIndices(String indices, Object... args) { if (argIndex != args.length) { throw new IllegalArgumentException("Incorrect number of index arguments"); } + return this; } /** diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 6a12af2f1df..4b725986bab 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -48,7 +48,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public PtNDArray from(NDArray array) { - if (array == null || array instanceof PtNDArray) { + if (array == null || array instanceof PtNDArray && array.getManager() == this) { return (PtNDArray) array; } return create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 433094b3660..6125ef8301f 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -15,7 +15,13 @@ import ai.djl.Device; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDIndex; -import ai.djl.ndarray.index.dim.*; +import ai.djl.ndarray.index.dim.NDIndexAll; +import ai.djl.ndarray.index.dim.NDIndexBooleans; +import ai.djl.ndarray.index.dim.NDIndexElement; +import ai.djl.ndarray.index.dim.NDIndexFixed; +import ai.djl.ndarray.index.dim.NDIndexNone; +import ai.djl.ndarray.index.dim.NDIndexPick; +import ai.djl.ndarray.index.dim.NDIndexSlice; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; @@ -37,6 +43,8 @@ import java.nio.file.Path; import java.util.Arrays; import java.util.HashSet; +import java.util.List; +import java.util.ListIterator; import java.util.Set; /** diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 6ba9b1838bf..bf65b2c93ee 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -12,7 +12,6 @@ */ package ai.djl.pytorch.jni; - import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; From fc5e5815a298128756f94c28594f0ba255526346 Mon Sep 17 00:00:00 2001 From: Kexin Date: Fri, 17 Jun 2022 22:00:18 -0700 Subject: [PATCH 08/15] bug fixed --- .../main/java/ai/djl/pytorch/jni/PyTorchLibrary.java | 2 +- .../native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index bf65b2c93ee..4406c2074ce 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -608,7 +608,7 @@ native void sgdUpdate( native void torchIndexAppendNoneEllipsis(long torchIndexHandle, boolean is_ellipsis); native void torchIndexAppendSlice( - long torchIndexHandle, long min, long max, long step, int null_slice_bin); + long torchIndexHandle, long min, long max, long step, int null_slice_binary); native void torchIndexAppendFixed(long torchIndexHandle, long idx); diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index 2b56892a070..d9520f524cc 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -152,16 +152,16 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendNo } JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendSlice(JNIEnv* env, jobject jthis, - jlong jtorch_index_handle, jlong jmin, jlong jmax, jlong jstep, jint jnull_slice_bin) { + jlong jtorch_index_handle, jlong jmin, jlong jmax, jlong jstep, jint jnull_slice_binary) { API_BEGIN() auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); - if (jnull_slice_bin == 0) { + if (jnull_slice_binary == 0) { index_ptr->emplace_back(torch::indexing::Slice(jmin, jmax, jstep)); - } else if (jnull_slice_bin == 1) { + } else if (jnull_slice_binary == 1) { index_ptr->emplace_back(torch::indexing::Slice(jmin, torch::indexing::None, jstep)); - } else if (jnull_slice_bin == 2) { + } else if (jnull_slice_binary == 2) { index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, jmax, jstep)); - } else if (jnull_slice_bin == 3) { + } else if (jnull_slice_binary == 3) { index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, torch::indexing::None, jstep)); } API_END() From 35ac95891d1906ec2211cd9f5f3f25169c8fc188 Mon Sep 17 00:00:00 2001 From: Kexin Date: Fri, 17 Jun 2022 23:06:28 -0700 Subject: [PATCH 09/15] Torch index type check: long, byte or boolean; restore testPick behaviour; The previous commit: PtNDManager.from() bug fixed. --- api/build.gradle | 1 + .../ai/djl/ndarray/index/dim/NDIndexTake.java | 45 +++++++++++++++++++ .../java/ai/djl/pytorch/jni/JniUtils.java | 17 ++++++- 3 files changed, 61 insertions(+), 2 deletions(-) create mode 100644 api/src/main/java/ai/djl/ndarray/index/dim/NDIndexTake.java diff --git a/api/build.gradle b/api/build.gradle index 5f16a668ed9..6e29aa79389 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -9,6 +9,7 @@ dependencies { } testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" testRuntimeOnly project(":engines:pytorch:pytorch-model-zoo") + testRuntimeOnly project(":engines:pytorch:pytorch-jni") } javadoc { diff --git a/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexTake.java b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexTake.java new file mode 100644 index 00000000000..5ce76fe3adc --- /dev/null +++ b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexTake.java @@ -0,0 +1,45 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.ndarray.index.dim; + +import ai.djl.ndarray.NDArray; + +/** An {@link NDIndexElement} that gets elements by index in the specified axis. */ +public class NDIndexTake implements NDIndexElement { + + private NDArray index; + + /** + * Constructs a pick. + * + * @param index the index to pick + */ + public NDIndexTake(NDArray index) { + this.index = index; + } + + /** {@inheritDoc} */ + @Override + public int getRank() { + return 1; + } + + /** + * Returns the index to pick. + * + * @return the index to pick + */ + public NDArray getIndex() { + return index; + } +} diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 6125ef8301f..26040b286c1 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -22,6 +22,8 @@ import ai.djl.ndarray.index.dim.NDIndexNone; import ai.djl.ndarray.index.dim.NDIndexPick; import ai.djl.ndarray.index.dim.NDIndexSlice; +import ai.djl.ndarray.index.dim.NDIndexTake; +import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.types.DataType; import ai.djl.ndarray.types.Shape; import ai.djl.ndarray.types.SparseFormat; @@ -381,9 +383,20 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) { } else if (elem instanceof NDIndexBooleans) { PtNDArray index_arr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, index_arr.getHandle()); - } else if (elem instanceof NDIndexPick) { - PtNDArray index_arr = (PtNDArray) ((NDIndexPick) elem).getIndex(); + } else if (elem instanceof NDIndexTake) { + PtNDArray index_arr = (PtNDArray) ((NDIndexTake) elem).getIndex(); + if (index_arr.getDataType() != DataType.INT64) { + index_arr = index_arr.toType(DataType.INT64, true); + } PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, index_arr.getHandle()); + } else if (elem instanceof NDIndexPick) { + //noinspection OptionalGetWithoutIsPresent + NDIndexFullPick fullPick = + NDIndexFullPick.fromIndex(index, ndArray.getShape()).get(); + return pick( + ndArray, + ndArray.getManager().from(fullPick.getIndices()), + fullPick.getAxis()); } } if (indices.size() == index.getEllipsisIndex()) { From b26864951bb868ba30ecd08b89f356de5e1f34c6 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Thu, 23 Jun 2022 12:29:42 -0700 Subject: [PATCH 10/15] Restore the get(NDArray, NDIndex) --- .../ai/djl/ndarray/index/NDArrayIndexer.java | 29 ++++++++++++++- .../java/ai/djl/ndarray/index/NDIndex.java | 4 ++- .../ai/djl/mxnet/engine/MxNDArrayIndexer.java | 35 ------------------- .../java/ai/djl/pytorch/jni/JniUtils.java | 1 - .../tensorflow/engine/TfNDArrayIndexer.java | 7 ---- 5 files changed, 31 insertions(+), 45 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java index d71a9b9acd7..47994b22cda 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDArrayIndexer.java @@ -20,6 +20,7 @@ import ai.djl.ndarray.index.full.NDIndexFullSlice; import java.util.List; +import java.util.Optional; /** A helper class for {@link NDArray} implementations for operations with an {@link NDIndex}. */ public abstract class NDArrayIndexer { @@ -49,7 +50,33 @@ public abstract class NDArrayIndexer { * @param index the index to get * @return the subarray */ - public abstract NDArray get(NDArray array, NDIndex index); + public NDArray get(NDArray array, NDIndex index) { + if (index.getRank() == 0 && array.getShape().isScalar()) { + return array.duplicate(); + } + + // use booleanMask for NDIndexBooleans case + List indices = index.getIndices(); + if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) { + if (indices.size() != 1) { + throw new IllegalArgumentException( + "get() currently doesn't support more that one boolean NDArray"); + } + return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex()); + } + + Optional fullPick = NDIndexFullPick.fromIndex(index, array.getShape()); + if (fullPick.isPresent()) { + return get(array, fullPick.get()); + } + + Optional fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape()); + if (fullSlice.isPresent()) { + return get(array, fullSlice.get()); + } + throw new UnsupportedOperationException( + "get() currently supports all, fixed, and slices indices"); + } /** * Sets the values of the array at the fullSlice with an array. diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index fb6e0ebbcf2..cc81caf766b 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -106,6 +106,8 @@ public NDIndex() { * * // Uses ellipsis to select all the dimensions except for last axis where we only get a subsection. * assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4)); + * + * // TODO: Add doc for the new indexings * * * @param indices a comma separated list of indices corresponding to either subsections, @@ -336,7 +338,7 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) { if (!m.matches()) { throw new IllegalArgumentException("Invalid argument index: " + indexItem); } - // None + // "None" case if ("None".equals(indexItem)) { indices.add(new NDIndexNone()); return argIndex; diff --git a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java index 4465a9526ab..5ac0274844e 100644 --- a/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java +++ b/engines/mxnet/mxnet-engine/src/main/java/ai/djl/mxnet/engine/MxNDArrayIndexer.java @@ -15,15 +15,10 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDList; import ai.djl.ndarray.index.NDArrayIndexer; -import ai.djl.ndarray.index.NDIndex; -import ai.djl.ndarray.index.dim.NDIndexBooleans; -import ai.djl.ndarray.index.dim.NDIndexElement; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; import ai.djl.ndarray.types.Shape; -import java.util.List; -import java.util.Optional; import java.util.Stack; /** The {@link NDArrayIndexer} used by the {@link MxNDArray}. */ @@ -66,36 +61,6 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { return result; } - /** {@inheritDoc} */ - @Override - public NDArray get(NDArray array, NDIndex index) { - if (index.getRank() == 0 && array.getShape().isScalar()) { - return array.duplicate(); - } - - // use booleanMask for NDIndexBooleans case - List indices = index.getIndices(); - if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) { - if (indices.size() != 1) { - throw new IllegalArgumentException( - "get() currently doesn't support more that one boolean NDArray"); - } - return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex()); - } - - Optional fullPick = NDIndexFullPick.fromIndex(index, array.getShape()); - if (fullPick.isPresent()) { - return get(array, fullPick.get()); - } - - Optional fullSlice = NDIndexFullSlice.fromIndex(index, array.getShape()); - if (fullSlice.isPresent()) { - return get(array, fullSlice.get()); - } - throw new UnsupportedOperationException( - "get() currently supports only all, fixed, and slices indices in MXNet engine"); - } - /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 26040b286c1..c4d3f76f3b0 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -436,7 +436,6 @@ public static PtNDArray take(PtNDArray ndArray, PtNDArray index) { if (index.getDataType() != DataType.INT64) { index = index.toType(DataType.INT64, true); } - System.out.println(PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); return new PtNDArray( ndArray.getManager(), PyTorchLibrary.LIB.torchTake(ndArray.getHandle(), index.getHandle())); diff --git a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java index 3562a50dc7b..73b3a4c5d54 100644 --- a/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java +++ b/engines/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfNDArrayIndexer.java @@ -14,7 +14,6 @@ import ai.djl.ndarray.NDArray; import ai.djl.ndarray.index.NDArrayIndexer; -import ai.djl.ndarray.index.NDIndex; import ai.djl.ndarray.index.full.NDIndexFullPick; import ai.djl.ndarray.index.full.NDIndexFullSlice; @@ -57,12 +56,6 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { } } - /** {@inheritDoc} */ - @Override - public NDArray get(NDArray array, NDIndex index) { - throw new UnsupportedOperationException("Not implemented"); - } - /** {@inheritDoc} */ @Override public void set(NDArray array, NDIndexFullSlice fullSlice, NDArray value) { From 110f400610a1f4b3e0475f7aa207cb2722078b70 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Thu, 23 Jun 2022 17:46:03 -0700 Subject: [PATCH 11/15] change at::indexing to torch::indexing; testRuntimeOnly project(":engines:pytorch:pytorch-jni --- api/build.gradle | 1 - .../ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 12 ++++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/api/build.gradle b/api/build.gradle index 6e29aa79389..5f16a668ed9 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -9,7 +9,6 @@ dependencies { } testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" testRuntimeOnly project(":engines:pytorch:pytorch-model-zoo") - testRuntimeOnly project(":engines:pytorch:pytorch-jni") } javadoc { diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index d9520f524cc..2d933a58aca 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -123,7 +123,7 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndex(JNIEnv JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexInit(JNIEnv* env, jobject jthis, jint jsize) { API_BEGIN() - std::vector *index_ptr = new std::vector; + std::vector *index_ptr = new std::vector; index_ptr->reserve(jsize); return reinterpret_cast(index_ptr); API_END_RETURN() @@ -133,7 +133,7 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexReturn( jlong jhandle, jlong jtorch_index_handle) { API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); torch::Tensor* ret_ptr = new torch::Tensor(tensor_ptr->index(*index_ptr)); return reinterpret_cast(ret_ptr); API_END_RETURN() @@ -142,7 +142,7 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexReturn( JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendNoneEllipsis(JNIEnv* env, jobject jthis, jlong jtorch_index_handle, jboolean jis_ellipsis) { API_BEGIN() - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); if (jis_ellipsis) { index_ptr->emplace_back(torch::indexing::Ellipsis); } else { @@ -154,7 +154,7 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendNo JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendSlice(JNIEnv* env, jobject jthis, jlong jtorch_index_handle, jlong jmin, jlong jmax, jlong jstep, jint jnull_slice_binary) { API_BEGIN() - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); if (jnull_slice_binary == 0) { index_ptr->emplace_back(torch::indexing::Slice(jmin, jmax, jstep)); } else if (jnull_slice_binary == 1) { @@ -170,7 +170,7 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendSl JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendFixed(JNIEnv* env, jobject jthis, jlong jtorch_index_handle, jlong jidx) { API_BEGIN() - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); index_ptr->emplace_back((int) jidx); API_END() } @@ -178,7 +178,7 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendFi JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendArray(JNIEnv* env, jobject jthis, jlong jtorch_index_handle, jlong jarray) { API_BEGIN() - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); auto* array_ptr = reinterpret_cast(jarray); index_ptr->emplace_back(*array_ptr); API_END() From 2e64ac2195aafce98464b258cbb2c1773b147052 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Thu, 23 Jun 2022 18:00:57 -0700 Subject: [PATCH 12/15] Add :engines:pytorch:pytorch-jni --- api/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/api/build.gradle b/api/build.gradle index 5f16a668ed9..6e29aa79389 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -9,6 +9,7 @@ dependencies { } testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" testRuntimeOnly project(":engines:pytorch:pytorch-model-zoo") + testRuntimeOnly project(":engines:pytorch:pytorch-jni") } javadoc { From 920ff8db506d4109d7721e8e39cb289551036453 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Fri, 24 Jun 2022 17:41:13 -0700 Subject: [PATCH 13/15] testIndexationUsesSpecificManager add manager checking into PtNDArrayIndexer get(NDArray, NDIndex) --- .../main/java/ai/djl/ndarray/NDManager.java | 2 +- engines/dlr/dlr-engine/build.gradle | 1 + .../paddlepaddle-model-zoo/build.gradle | 1 + .../djl/pytorch/engine/PtNDArrayIndexer.java | 13 ++++++++-- .../ai/djl/pytorch/engine/PtNDManager.java | 2 +- .../java/ai/djl/pytorch/jni/JniUtils.java | 24 +++++++++++-------- .../ai/djl/pytorch/jni/PyTorchLibrary.java | 4 ++-- extensions/audio/build.gradle | 1 + extensions/opencv/build.gradle | 1 + extensions/tablesaw/build.gradle | 1 + 10 files changed, 34 insertions(+), 16 deletions(-) diff --git a/api/src/main/java/ai/djl/ndarray/NDManager.java b/api/src/main/java/ai/djl/ndarray/NDManager.java index cdf76a6595e..f9bac2fb1e2 100644 --- a/api/src/main/java/ai/djl/ndarray/NDManager.java +++ b/api/src/main/java/ai/djl/ndarray/NDManager.java @@ -165,7 +165,7 @@ static NDManager subManagerOf(NDResource resource) { ByteBuffer allocateDirect(int capacity); /** - * Creates a new {@code NDArray} if the input {@link NDArray} is from external engine. + * Creates a new {@code NDArray} if the input {@link NDArray} is from an external engine. * * @param array the input {@code NDArray} * @return a new {@code NDArray} if the input {@code NDArray} is from external engine diff --git a/engines/dlr/dlr-engine/build.gradle b/engines/dlr/dlr-engine/build.gradle index 42401f9a302..45fe1852cdb 100644 --- a/engines/dlr/dlr-engine/build.gradle +++ b/engines/dlr/dlr-engine/build.gradle @@ -9,6 +9,7 @@ dependencies { } testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" testRuntimeOnly project(":engines:pytorch:pytorch-engine") + testRuntimeOnly project(":engines:pytorch:pytorch-jni") } compileJava.dependsOn(processResources) diff --git a/engines/paddlepaddle/paddlepaddle-model-zoo/build.gradle b/engines/paddlepaddle/paddlepaddle-model-zoo/build.gradle index c23f27184aa..a04fa364030 100644 --- a/engines/paddlepaddle/paddlepaddle-model-zoo/build.gradle +++ b/engines/paddlepaddle/paddlepaddle-model-zoo/build.gradle @@ -10,6 +10,7 @@ dependencies { testImplementation "org.slf4j:slf4j-simple:${slf4j_version}" testImplementation(project(":testing")) testRuntimeOnly project(":engines:pytorch:pytorch-engine") + testRuntimeOnly project(":engines:pytorch:pytorch-jni") } task syncS3(type: Exec) { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index 4f016d57559..fe59a0acff2 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -55,11 +55,20 @@ public NDArray get(NDArray array, NDIndexFullSlice fullSlice) { public NDArray get(NDArray array, NDIndex index) { if (index.getRank() == 0) { if (array.getShape().isScalar()) { - return array.duplicate(); + return array.getManager() == manager + ? array.duplicate() + : manager.create( + array.toByteBuffer(), array.getShape(), array.getDataType()); } index.addAllDim(); } - return JniUtils.indexAdv(manager.from(array), index); + if (array == null || array instanceof PtNDArray && array.getManager() == manager) { + return JniUtils.indexAdv((PtNDArray) array, index); + } else { + PtNDArray arrayNew = + manager.create(array.toByteBuffer(), array.getShape(), array.getDataType()); + return JniUtils.indexAdv(arrayNew, index); + } } /** {@inheritDoc} */ diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java index 4b725986bab..6a12af2f1df 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDManager.java @@ -48,7 +48,7 @@ public ByteBuffer allocateDirect(int capacity) { /** {@inheritDoc} */ @Override public PtNDArray from(NDArray array) { - if (array == null || array instanceof PtNDArray && array.getManager() == this) { + if (array == null || array instanceof PtNDArray) { return (PtNDArray) array; } return create(array.toByteBuffer(), array.getShape(), array.getDataType()); diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index c4d3f76f3b0..0c25d2563f0 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -349,7 +349,11 @@ public static PtNDArray index( ndArray.getHandle(), minIndices, maxIndices, stepIndices)); } + @SuppressWarnings("OptionalGetWithoutIsPresent") public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) { + if (ndArray == null) { + return ndArray; + } List indices = index.getIndices(); long torchIndexHandle = PyTorchLibrary.LIB.torchIndexInit(indices.size()); ListIterator it = indices.listIterator(); @@ -365,8 +369,8 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) { Long min = ((NDIndexSlice) elem).getMin(); Long max = ((NDIndexSlice) elem).getMax(); Long step = ((NDIndexSlice) elem).getStep(); - int null_slice_bin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); - // null_slice_bin encodes whether (min, max) is null: + int nullSliceBin = (min == null ? 1 : 0) * 2 + (max == null ? 1 : 0); + // nullSliceBin encodes whether the slice (min, max) is null: // is_null == 1, ! is_null == 0; // 0b11 == 3, 0b10 = 2, ... PyTorchLibrary.LIB.torchIndexAppendSlice( @@ -374,23 +378,23 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) { min == null ? 0 : min, max == null ? 0 : max, step == null ? 1 : step, - null_slice_bin); + nullSliceBin); } else if (elem instanceof NDIndexAll) { PyTorchLibrary.LIB.torchIndexAppendSlice(torchIndexHandle, 0, 0, 1, 3); } else if (elem instanceof NDIndexFixed) { PyTorchLibrary.LIB.torchIndexAppendFixed( torchIndexHandle, ((NDIndexFixed) elem).getIndex()); } else if (elem instanceof NDIndexBooleans) { - PtNDArray index_arr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); - PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, index_arr.getHandle()); + PtNDArray indexArr = (PtNDArray) ((NDIndexBooleans) elem).getIndex(); + PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle()); } else if (elem instanceof NDIndexTake) { - PtNDArray index_arr = (PtNDArray) ((NDIndexTake) elem).getIndex(); - if (index_arr.getDataType() != DataType.INT64) { - index_arr = index_arr.toType(DataType.INT64, true); + PtNDArray indexArr = (PtNDArray) ((NDIndexTake) elem).getIndex(); + if (indexArr.getDataType() != DataType.INT64) { + indexArr = indexArr.toType(DataType.INT64, true); } - PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, index_arr.getHandle()); + PyTorchLibrary.LIB.torchIndexAppendArray(torchIndexHandle, indexArr.getHandle()); } else if (elem instanceof NDIndexPick) { - //noinspection OptionalGetWithoutIsPresent + // Backward compatible NDIndexFullPick fullPick = NDIndexFullPick.fromIndex(index, ndArray.getShape()).get(); return pick( diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java index 4406c2074ce..11d2a95a586 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java @@ -605,10 +605,10 @@ native void sgdUpdate( native long torchIndexReturn(long handle, long torchIndexHandle); - native void torchIndexAppendNoneEllipsis(long torchIndexHandle, boolean is_ellipsis); + native void torchIndexAppendNoneEllipsis(long torchIndexHandle, boolean isEllipsis); native void torchIndexAppendSlice( - long torchIndexHandle, long min, long max, long step, int null_slice_binary); + long torchIndexHandle, long min, long max, long step, int nullSliceBinary); native void torchIndexAppendFixed(long torchIndexHandle, long idx); diff --git a/extensions/audio/build.gradle b/extensions/audio/build.gradle index 8b6d60530c3..266cb9dd447 100644 --- a/extensions/audio/build.gradle +++ b/extensions/audio/build.gradle @@ -31,6 +31,7 @@ dependencies { testImplementation project(":testing") testRuntimeOnly "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}" testRuntimeOnly project(":engines:pytorch:pytorch-engine") + testRuntimeOnly project(":engines:pytorch:pytorch-jni") } publishing { diff --git a/extensions/opencv/build.gradle b/extensions/opencv/build.gradle index 7288cccbdb7..05dddad24a3 100644 --- a/extensions/opencv/build.gradle +++ b/extensions/opencv/build.gradle @@ -11,6 +11,7 @@ dependencies { testRuntimeOnly "org.apache.logging.log4j:log4j-slf4j-impl:${log4j_slf4j_version}" testRuntimeOnly project(":engines:pytorch:pytorch-model-zoo") + testRuntimeOnly project(":engines:pytorch:pytorch-jni") } publishing { diff --git a/extensions/tablesaw/build.gradle b/extensions/tablesaw/build.gradle index c3ed876da4a..2e71879f9ef 100644 --- a/extensions/tablesaw/build.gradle +++ b/extensions/tablesaw/build.gradle @@ -15,6 +15,7 @@ dependencies { // testRuntimeOnly "tech.tablesaw:tablesaw-html:${tablesaw_version}" // testRuntimeOnly "tech.tablesaw:tablesaw-json:${tablesaw_version}" testRuntimeOnly project(":engines:pytorch:pytorch-engine") + testRuntimeOnly project(":engines:pytorch:pytorch-jni") } publishing { From db763c9fcd96800233aacd99e8f86bd0f50b7314 Mon Sep 17 00:00:00 2001 From: KexinFeng Date: Sat, 25 Jun 2022 12:05:14 -0700 Subject: [PATCH 14/15] Add NDIndeTest fix NDArray.get(index) --- api/src/main/java/ai/djl/ndarray/NDArray.java | 6 +---- .../java/ai/djl/ndarray/index/NDIndex.java | 20 ++++++++++------ .../{NDIndexNone.java => NDIndexNull.java} | 2 +- .../djl/pytorch/engine/PtNDArrayIndexer.java | 1 + .../java/ai/djl/pytorch/jni/JniUtils.java | 4 ++-- .../tests/ndarray/NDIndexTest.java | 23 ++++++++++++++++--- 6 files changed, 38 insertions(+), 18 deletions(-) rename api/src/main/java/ai/djl/ndarray/index/dim/{NDIndexNone.java => NDIndexNull.java} (93%) diff --git a/api/src/main/java/ai/djl/ndarray/NDArray.java b/api/src/main/java/ai/djl/ndarray/NDArray.java index da5439be7e1..77162907f5a 100644 --- a/api/src/main/java/ai/djl/ndarray/NDArray.java +++ b/api/src/main/java/ai/djl/ndarray/NDArray.java @@ -530,11 +530,7 @@ default NDArray get(NDManager manager, NDIndex index) { * @return the partial {@code NDArray} */ default NDArray get(NDArray index) { - if (index.getDataType() == DataType.BOOLEAN) { - return get(new NDIndex().addBooleanIndex(index)); - } else { - return take(index); - } + return get(new NDIndex("{}", index)); } /** diff --git a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java index cc81caf766b..8b8b4014697 100644 --- a/api/src/main/java/ai/djl/ndarray/index/NDIndex.java +++ b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java @@ -17,9 +17,10 @@ import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.dim.NDIndexElement; import ai.djl.ndarray.index.dim.NDIndexFixed; -import ai.djl.ndarray.index.dim.NDIndexNone; +import ai.djl.ndarray.index.dim.NDIndexNull; import ai.djl.ndarray.index.dim.NDIndexPick; import ai.djl.ndarray.index.dim.NDIndexSlice; +import ai.djl.ndarray.index.dim.NDIndexTake; import ai.djl.ndarray.types.DataType; import java.util.ArrayList; @@ -51,7 +52,7 @@ public class NDIndex { /* Android regex requires escape } char as well */ private static final Pattern ITEM_PATTERN = Pattern.compile( - "(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})|None"); + "(\\*)|((-?\\d+|\\{\\})?:(-?\\d+|\\{\\})?(:(-?\\d+|\\{\\}))?)|(-?\\d+|\\{\\})|null"); private int rank; private List indices; @@ -107,7 +108,9 @@ public NDIndex() { * // Uses ellipsis to select all the dimensions except for last axis where we only get a subsection. * assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4)); * - * // TODO: Add doc for the new indexings + * // Uses null to add an extra axis to the output array + * assertEquals(a.get(new NDIndex(":2, null, 0, :2")).getShape(), new Shape(2, 1, 2)); + * * * * @param indices a comma separated list of indices corresponding to either subsections, @@ -338,9 +341,9 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) { if (!m.matches()) { throw new IllegalArgumentException("Invalid argument index: " + indexItem); } - // "None" case - if ("None".equals(indexItem)) { - indices.add(new NDIndexNone()); + // "null" case + if ("null".equals(indexItem)) { + indices.add(new NDIndexNull()); return argIndex; } // "*" case @@ -366,9 +369,12 @@ private int addIndexItem(String indexItem, int argIndex, Object[] args) { indices.add(new NDIndexBooleans(array)); return argIndex + 1; } else if (array.getDataType().isInteger()) { - indices.add(new NDIndexPick(array)); + indices.add(new NDIndexTake(array)); return argIndex + 1; } + } else if (arg == null) { + indices.add(new NDIndexNull()); + return argIndex + 1; } throw new IllegalArgumentException("Unknown argument: " + arg); } else { diff --git a/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNull.java similarity index 93% rename from api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java rename to api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNull.java index 9b958779220..d561824755c 100644 --- a/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNone.java +++ b/api/src/main/java/ai/djl/ndarray/index/dim/NDIndexNull.java @@ -13,7 +13,7 @@ package ai.djl.ndarray.index.dim; /** An {@code NDIndexElement} to return all values in a particular dimension. */ -public class NDIndexNone implements NDIndexElement { +public class NDIndexNull implements NDIndexElement { /** {@inheritDoc} */ @Override diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java index fe59a0acff2..07daec093b3 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtNDArrayIndexer.java @@ -62,6 +62,7 @@ public NDArray get(NDArray array, NDIndex index) { } index.addAllDim(); } + if (array == null || array instanceof PtNDArray && array.getManager() == manager) { return JniUtils.indexAdv((PtNDArray) array, index); } else { diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java index 0c25d2563f0..8e97adde6ef 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/JniUtils.java @@ -19,7 +19,7 @@ import ai.djl.ndarray.index.dim.NDIndexBooleans; import ai.djl.ndarray.index.dim.NDIndexElement; import ai.djl.ndarray.index.dim.NDIndexFixed; -import ai.djl.ndarray.index.dim.NDIndexNone; +import ai.djl.ndarray.index.dim.NDIndexNull; import ai.djl.ndarray.index.dim.NDIndexPick; import ai.djl.ndarray.index.dim.NDIndexSlice; import ai.djl.ndarray.index.dim.NDIndexTake; @@ -363,7 +363,7 @@ public static PtNDArray indexAdv(PtNDArray ndArray, NDIndex index) { } NDIndexElement elem = it.next(); - if (elem instanceof NDIndexNone) { + if (elem instanceof NDIndexNull) { PyTorchLibrary.LIB.torchIndexAppendNoneEllipsis(torchIndexHandle, false); } else if (elem instanceof NDIndexSlice) { Long min = ((NDIndexSlice) elem).getMin(); diff --git a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java index 6e4368760fe..7d9f79fc8b8 100644 --- a/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java +++ b/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java @@ -120,10 +120,27 @@ public void testGet() { Assert.assertEquals(original.get(bool), expected); // get from int array - original = manager.arange(1, 7f).reshape(-1, 3); - NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2)); + original = manager.arange(1, 7f).reshape(-1, 2); + NDArray index = manager.create(new long[] {0, 0, 1, 2}, new Shape(2, 2)); NDArray actual = original.get(index); - expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2)); + expected = manager.create(new float[] {1, 2, 1, 2, 3, 4, 5, 6}, new Shape(2, 2, 2)); + Assert.assertEquals(actual, expected); + + // indexing with boolean, broadcast int array and slice + original = manager.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3); + NDArray bool1 = manager.create(new boolean[] {true, false, true}); + NDArray index1 = manager.create(new long[] {2, 2}, new Shape(1, 2)); + NDArray index2 = manager.create(new long[] {0, 1}, new Shape(1, 2)); + actual = original.get(":{}, {}, {}, {}", 2, index1, bool1, index2); + expected = manager.create(new int[] {18, 25, 45, 52}, new Shape(2, 1, 2)); + Assert.assertEquals(actual, expected); + + // indexing with null, broadcast int array and slice + original = manager.arange(3 * 3 * 3).reshape(3, 3, 3); + index1 = manager.create(new long[] {0, 1}, new Shape(2)); + index2 = manager.create(new long[] {0, 0, 2, 1}, new Shape(2, 2)); + actual = original.get(":{}, {}, {}, {}", 2, index1, index2, null); + expected = manager.create(new int[] {0, 3, 2, 4, 9, 12, 11, 13}, new Shape(2, 2, 2, 1)); Assert.assertEquals(actual, expected); } } From 3f0c4e80753a0a458813b1bd320dd949d9f181c3 Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Mon, 27 Jun 2022 10:46:15 -0700 Subject: [PATCH 15/15] Format c++ code Change-Id: I5a7287719a8deedbfefa4181dc79e72d78410d49 --- ...i_djl_pytorch_jni_PyTorchLibrary_tensor.cc | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc index 2d933a58aca..48494010618 100644 --- a/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc +++ b/engines/pytorch/pytorch-native/src/main/native/ai_djl_pytorch_jni_PyTorchLibrary_tensor.cc @@ -123,30 +123,30 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndex(JNIEnv JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexInit(JNIEnv* env, jobject jthis, jint jsize) { API_BEGIN() - std::vector *index_ptr = new std::vector; + std::vector* index_ptr = new std::vector; index_ptr->reserve(jsize); return reinterpret_cast(index_ptr); API_END_RETURN() } -JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexReturn(JNIEnv* env, jobject jthis, - jlong jhandle, jlong jtorch_index_handle) { +JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexReturn( + JNIEnv* env, jobject jthis, jlong jhandle, jlong jtorch_index_handle) { API_BEGIN() const auto* tensor_ptr = reinterpret_cast(jhandle); - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast*>(jtorch_index_handle); torch::Tensor* ret_ptr = new torch::Tensor(tensor_ptr->index(*index_ptr)); return reinterpret_cast(ret_ptr); API_END_RETURN() } -JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendNoneEllipsis(JNIEnv* env, jobject jthis, - jlong jtorch_index_handle, jboolean jis_ellipsis) { +JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendNoneEllipsis( + JNIEnv* env, jobject jthis, jlong jtorch_index_handle, jboolean jis_ellipsis) { API_BEGIN() - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast*>(jtorch_index_handle); if (jis_ellipsis) { - index_ptr->emplace_back(torch::indexing::Ellipsis); + index_ptr->emplace_back(torch::indexing::Ellipsis); } else { - index_ptr->emplace_back(torch::indexing::None); + index_ptr->emplace_back(torch::indexing::None); } API_END() } @@ -154,34 +154,34 @@ JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendNo JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendSlice(JNIEnv* env, jobject jthis, jlong jtorch_index_handle, jlong jmin, jlong jmax, jlong jstep, jint jnull_slice_binary) { API_BEGIN() - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast*>(jtorch_index_handle); if (jnull_slice_binary == 0) { - index_ptr->emplace_back(torch::indexing::Slice(jmin, jmax, jstep)); + index_ptr->emplace_back(torch::indexing::Slice(jmin, jmax, jstep)); } else if (jnull_slice_binary == 1) { - index_ptr->emplace_back(torch::indexing::Slice(jmin, torch::indexing::None, jstep)); + index_ptr->emplace_back(torch::indexing::Slice(jmin, torch::indexing::None, jstep)); } else if (jnull_slice_binary == 2) { - index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, jmax, jstep)); + index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, jmax, jstep)); } else if (jnull_slice_binary == 3) { - index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, torch::indexing::None, jstep)); + index_ptr->emplace_back(torch::indexing::Slice(torch::indexing::None, torch::indexing::None, jstep)); } API_END() } -JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendFixed(JNIEnv* env, jobject jthis, - jlong jtorch_index_handle, jlong jidx) { +JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendFixed( + JNIEnv* env, jobject jthis, jlong jtorch_index_handle, jlong jidx) { API_BEGIN() - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); + auto* index_ptr = reinterpret_cast*>(jtorch_index_handle); index_ptr->emplace_back((int) jidx); API_END() } -JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendArray(JNIEnv* env, jobject jthis, - jlong jtorch_index_handle, jlong jarray) { -API_BEGIN() - auto* index_ptr = reinterpret_cast *>(jtorch_index_handle); +JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexAppendArray( + JNIEnv* env, jobject jthis, jlong jtorch_index_handle, jlong jarray) { + API_BEGIN() + auto* index_ptr = reinterpret_cast*>(jtorch_index_handle); auto* array_ptr = reinterpret_cast(jarray); index_ptr->emplace_back(*array_ptr); -API_END() + API_END() } JNIEXPORT void JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchIndexPut(JNIEnv* env, jobject jthis, jlong jhandle,