Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support of take on MXNet engine #1649

Merged
merged 14 commits into from
Jun 1, 2022
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,8 @@ default NDArray get(long... indices) {
NDArray gather(NDArray index, int axis);

/**
* Returns a partial {@code NDArray} pointed by the indexed array, according to linear indexing.
* Returns a partial {@code NDArray} pointed by the indexed array according to linear indexing,
* and the of output is of the same shape as index.
*
* @param index picks the elements of an NDArray to the same position as index
* @return the partial {@code NDArray} of the same shape as index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public NDArray get(NDArray array, NDIndex index) {
if (!indices.isEmpty() && indices.get(0) instanceof NDIndexBooleans) {
if (indices.size() != 1) {
throw new IllegalArgumentException(
"get() currently didn't support more that one boolean NDArray");
"get() currently doesn't support more that one boolean NDArray");
}
return array.booleanMask(((NDIndexBooleans) indices.get(0)).getIndex());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ public NDArray gather(NDArray index, int axis) {
/** {@inheritDoc} */
@Override
public NDArray take(NDArray index) {
throw new UnsupportedOperationException("Not implemented yet.");
MxOpParams params = new MxOpParams();
params.add("mode", "wrap");
return manager.invoke("take", new NDList(this.flatten(), index), params).singletonOrThrow();
}

/** {@inheritDoc} */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ public PtNDArray booleanMask(NDArray index, int axis) {
// Result is flattened since shape is undetermined
return JniUtils.booleanMask(this, manager.from(index));
} else if (indexShape.equals(getShape().slice(axis))) {
// index will be broadcasted by default
// index will be broadcast by default
try (PtNDArray flattedResult = JniUtils.booleanMask(this, manager.from(index))) {
// Shape recovery
Shape remainder = getShape().slice(0, axis);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ public void testPick() {
@Test
public void testGather() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(20f).reshape(-1, 4);
NDArray original = manager.arange(20f).reshape(-1, 4);
NDArray index = manager.create(new long[] {0, 0, 2, 1, 1, 2}, new Shape(3, 2));
NDArray actual = arr.gather(index, 1);
NDArray actual = original.gather(index, 1);
NDArray expected = manager.create(new float[] {0, 0, 6, 5, 9, 10}, new Shape(3, 2));
Assert.assertEquals(actual, expected);
}
Expand All @@ -66,9 +66,9 @@ public void testGather() {
@Test
public void testTake() {
try (NDManager manager = NDManager.newBaseManager()) {
NDArray arr = manager.arange(1, 7f).reshape(-1, 3);
NDArray original = manager.arange(1, 7f).reshape(-1, 3);
NDArray index = manager.create(new long[] {0, 4, 1, 2}, new Shape(2, 2));
NDArray actual = arr.take(index);
NDArray actual = original.take(index);
NDArray expected = manager.create(new float[] {1, 5, 2, 3}, new Shape(2, 2));
Assert.assertEquals(actual, expected);
}
Expand Down