Bulk Batch creation #1869
Codecov Report
@@ Coverage Diff @@
## master #1869 +/- ##
============================================
- Coverage 72.08% 70.10% -1.99%
- Complexity 5126 5867 +741
============================================
Files 473 576 +103
Lines 21970 26024 +4054
Branches 2351 2810 +459
============================================
+ Hits 15838 18245 +2407
- Misses 4925 6406 +1481
- Partials 1207 1373 +166
The problem with MXNet is fixed by using the "take" operation instead of "pick".
Thanks for your contribution. Since all engines are fixed, we will take a look. So far LGTM.
@KexinFeng can you help check the NDIndex part of the changes?
@patins1 It looks like |
Could you integrate it with the existing I tried to use |
If you could add a unit test that covers BulkDataIterable and that shows the necessity of using take on the MXNet engine, it'd be great.
params.add("mode", "wrap");
return manager.invoke("pick", new NDList(array, fullPick.getIndices()), params)
return manager.invoke(
We have already implemented "take" for the MXNet engine in PR. Could you integrate this feature with that?
Another problem is that we have consistently distinguished between "take" and "pick". See api/src/main/java/ai/djl/ndarray/index/dim. Here, the usage of "take" embedded in fullPick is confusing. Could you separate "take" from fullPick?
Update:
I looked into this PR and found that currently, on the MXNet engine, when the getter is used, the NDArray indices are still defined to be "pick" indices. This is cumbersome, but it exists only for backward compatibility. However, a take API already exists, as mentioned above. In your implementation, is it possible to use that API instead of mixing "take" and "pick" here?
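For readers following the "take" versus "pick" distinction discussed in this comment, here is a minimal sketch of the two semantics on plain Java arrays. It does not use the DJL API; the class and method names (`PickVsTakeSketch`, `pick`, `take`) are hypothetical, and the wrap-around behavior is an assumption modeled on the `"mode", "wrap"` parameter visible in the diff.

```java
import java.util.Arrays;

// Illustrative only: plain-array stand-ins for the two indexing semantics.
final class PickVsTakeSketch {

    // "pick" along axis 1: one element per row, out[i] = data[i][index[i]].
    // Indices are wrapped with floorMod (assumed equivalent of "wrap" mode).
    static float[] pick(float[][] data, int[] index) {
        float[] out = new float[data.length];
        for (int i = 0; i < data.length; i++) {
            int col = Math.floorMod(index[i], data[i].length);
            out[i] = data[i][col];
        }
        return out;
    }

    // "take" along axis 0: one whole row per index, out[i] = data[index[i]].
    static float[][] take(float[][] data, int[] index) {
        float[][] out = new float[index.length][];
        for (int i = 0; i < index.length; i++) {
            out[i] = data[index[i]].clone();
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] data = {{0f, 1f, 2f}, {10f, 11f, 12f}, {20f, 21f, 22f}};
        int[] index = {2, 0, 1};
        // The same index array selects single elements under pick ...
        System.out.println(Arrays.toString(pick(data, index)));
        // ... but whole rows under take.
        System.out.println(Arrays.deepToString(take(data, index)));
    }
}
```

With the same index array, `pick` yields one scalar per row while `take` yields one row per index, which is why mixing the two behind a single flag can be confusing.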
api/src/main/java/ai/djl/training/dataset/BulkDataIterable.java
@@ -72,6 +73,7 @@ public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
"Only rank-1 indexing array is supported for pick");
}
fullPick = new NDIndexFullPick(indexElem, axis);
fullPick.setIndexTake(true);
NDArray.take(NDArray) returns a shape that is not useful for the functionality of this pull request:
Returns a partial NDArray pointed by index according to linear indexing, and the output is of the same shape as index.
Given ndArray of shape e.g. (a,b,c,d) and index of shape (e), I initially assumed that ndArray.take(index) would return an array of shape (e,b,c,d). Therefore I restyled my initial implementation, using NDIndexFullTake. If anybody has a better idea, let me know; at least the test cases go green now.
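The shape mismatch described here can be sketched with plain Java arrays, using a 2-D array of shape (a,b) as a stand-in for (a,b,c,d). This is not DJL code: `TakeShapeSketch`, `takeLinear`, and `takeRows` are hypothetical names, and `takeLinear` only mirrors the quoted documentation ("linear indexing, output of the same shape as index").

```java
import java.util.Arrays;

// Illustrative only: contrasts linear-index take with a row-wise (axis-0) take.
final class TakeShapeSketch {

    // Linear indexing: index addresses the flattened array.
    // Output shape equals the index shape, here (e).
    static float[] takeLinear(float[][] data, int[] index) {
        int cols = data[0].length;
        float[] out = new float[index.length];
        for (int i = 0; i < index.length; i++) {
            out[i] = data[index[i] / cols][index[i] % cols];
        }
        return out;
    }

    // Row-wise take along axis 0: output shape is (e, b),
    // the shape a batch of selected data items needs.
    static float[][] takeRows(float[][] data, int[] index) {
        float[][] out = new float[index.length][];
        for (int i = 0; i < index.length; i++) {
            out[i] = data[index[i]].clone();
        }
        return out;
    }

    public static void main(String[] args) {
        // Shape (a, b) = (4, 3); index of shape (e) = (2).
        float[][] data = {{0f, 1f, 2f}, {10f, 11f, 12f}, {20f, 21f, 22f}, {30f, 31f, 32f}};
        int[] index = {3, 1};
        System.out.println(Arrays.toString(takeLinear(data, index)));   // shape (2)
        System.out.println(Arrays.deepToString(takeRows(data, index))); // shape (2, 3)
    }
}
```

Under linear indexing, index 3 refers to the fourth element of the flattened array rather than the fourth data item, so only the row-wise variant produces the (e,b,c,d)-style result the comment was after.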
78ae336
to
75d4249
Compare
@patins1 Thanks for dealing with this issue in such a timely manner! But it looks like the unit test that covers the new file api/src/main/java/ai/djl/training/dataset/BulkDataIterable.java is still missing. You mentioned that you have tested it and that it showed an efficiency increase. Could you add it to the proper unit test file? Also you mentioned
It'd be better to have a unit test for this too, which will prevent future edits from breaking this fix. Thanks!
Tests added
@siddvenk Hi Siddarth, in this PR, we have changed the definition of get(NDArray index) from pick to take, and have given the warning. This will affect the result you mentioned in #1800. See the test in commit "add index test and clean code" 8ffac00. Pick can still be used though by addPickDim().
There is one small problem I noticed, but otherwise it looks good. I think we should merge it for now and fix the problem in a later PR.
* @param indices indices of the requested data items
* @return a {@link Record} that contains the data and label of the requested data items
*/
public Record getByIndices(NDManager manager, long... indices) {
One convention we have in DJL to differentiate NDArrays is that a Record should be only a single index while a Batch is multiple indices. So, this method should probably be returning a Batch rather than a Record
Fixes minor issue with deepjavalibrary#1869
This pull request optimizes Batch creation for ArrayDataset when using StackBatchifier.
My DJL application now learns the data in 58550 ms rather than 79843 ms, so it is 36% faster!
For applications using indices-based rather than range-based indexing, my DJL application runs in 66341 ms, which is still 20% faster. This holds for PyTorch; I get a test failure for MXNet, so this solution may be restricted to PyTorch.
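As a quick check of the quoted percentages, assuming "faster" means the ratio of old to new wall-clock time:

```latex
\frac{79843}{58550} \approx 1.364 \quad (\text{about } 36\%\ \text{faster}),
\qquad
\frac{79843}{66341} \approx 1.204 \quad (\text{about } 20\%\ \text{faster})
```

Expressed instead as a reduction in runtime, the same measurements correspond to roughly $1 - 58550/79843 \approx 26.7\%$ and $1 - 66341/79843 \approx 16.9\%$.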