-
Notifications
You must be signed in to change notification settings - Fork 685
/
Copy pathNDIndexFullPick.java
100 lines (92 loc) · 3.41 KB
/
NDIndexFullPick.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray.index.full;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.index.dim.NDIndexAll;
import ai.djl.ndarray.index.dim.NDIndexElement;
import ai.djl.ndarray.index.dim.NDIndexPick;
import ai.djl.ndarray.index.dim.NDIndexTake;
import ai.djl.ndarray.types.Shape;
import java.util.Optional;
/**
 * A simplified representation of a pick-based {@link NDIndex}.
 *
 * <p>An {@link NDIndex} can be represented as a full pick when it is made up only of {@link
 * NDIndexAll} elements plus exactly one {@link NDIndexPick} or {@link NDIndexTake} element. The
 * pick axis is the number of {@link NDIndexAll} elements that precede that pick element.
 */
public final class NDIndexFullPick {

    // Both fields are assigned once in the constructor and never reassigned.
    private final NDArray indices;
    private final int axis;

    /**
     * Constructs a new {@link NDIndexFullPick}.
     *
     * @param indices the indices to pick
     * @param axis the axis to pick at
     */
    private NDIndexFullPick(NDArray indices, int axis) {
        this.indices = indices;
        this.axis = axis;
    }

    /**
     * Returns (if possible) the {@link NDIndexFullPick} representation of an {@link NDIndex}.
     *
     * <p>Index elements are scanned left to right. Each {@link NDIndexAll} seen before the pick
     * advances the pick axis by one. The first {@link NDIndexPick} or {@link NDIndexTake} element
     * becomes the pick, taken at the current axis.
     *
     * @param index the index to represent
     * @param target the shape of the array to index (not currently consulted by this method)
     * @return the full pick representation, or {@link Optional#empty()} if the scan reaches an
     *     element that is not an {@link NDIndexAll}, {@link NDIndexPick} or {@link NDIndexTake},
     *     or if the index contains no pick element at all
     * @throws UnsupportedOperationException if a second pick/take element is encountered during
     *     the scan, or if an {@link NDIndexTake} element's indexing array is not rank-1
     */
    public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
        int axis = 0;
        NDIndexFullPick fullPick = null;
        for (NDIndexElement el : index.getIndices()) {
            if (el instanceof NDIndexAll) {
                axis++;
            } else if (el instanceof NDIndexPick) {
                requireNoPriorPick(fullPick);
                NDArray indexElem = ((NDIndexPick) el).getIndex();
                fullPick = new NDIndexFullPick(indexElem, axis);
            } else if (el instanceof NDIndexTake) {
                // The duplicate check runs before the rank check, so a second pick is reported first.
                requireNoPriorPick(fullPick);
                NDArray indexElem = ((NDIndexTake) el).getIndex();
                if (!indexElem.getShape().isRankOne()) {
                    throw new UnsupportedOperationException(
                            "Only rank-1 indexing array is supported for pick");
                }
                fullPick = new NDIndexFullPick(indexElem, axis);
            } else {
                // Invalid dim for fullPick
                return Optional.empty();
            }
        }
        return Optional.ofNullable(fullPick);
    }

    /**
     * Rejects a second pick/take element within the same index.
     *
     * @param existing the pick found so far, or {@code null} if none has been found yet
     * @throws UnsupportedOperationException if {@code existing} is non-null
     */
    private static void requireNoPriorPick(NDIndexFullPick existing) {
        if (existing != null) {
            // Don't support multiple picks
            throw new UnsupportedOperationException(
                    "Only one pick per get is currently supported");
        }
    }

    /**
     * Returns the indices to pick.
     *
     * @return the indices to pick
     */
    public NDArray getIndices() {
        return indices;
    }

    /**
     * Returns the axis to pick.
     *
     * @return the axis to pick
     */
    public int getAxis() {
        return axis;
    }
}