
Bulk Batch creation #1869
Merged — 12 commits merged into master on Sep 1, 2022
Conversation

@patins1 (Contributor) commented Aug 4, 2022

This pull request optimizes Batch creation for ArrayDataset when using StackBatchifier.
My DJL application now learns the data in 58550ms rather than 79843ms, i.e. 36% faster!
For applications using indices-based rather than range-based indexing, my DJL application runs in 66341ms, so still 20% faster. This holds for PyTorch; I get a test failure for MxNet, so this solution may be restricted to PyTorch.
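The idea behind the two indexing paths can be sketched in plain Python (illustrative only, not DJL code; the helper names `batch_by_range` and `batch_per_item` are assumptions). A range-based batch is one contiguous slice, while an indices-based batch requires one lookup per item before stacking, which is why the bulk range path is cheaper:

```python
# Conceptual sketch (not DJL code): a dataset as a flat list of feature rows.
dataset = [[i, i + 1] for i in range(10)]  # 10 items, 2 features each

def batch_per_item(indices):
    """Indices-based path: one lookup per requested item, then stack."""
    return [dataset[i] for i in indices]

def batch_by_range(start, stop):
    """Range-based path: a single contiguous slice, no per-item calls."""
    return dataset[start:stop]

# Both paths produce the same batch for a contiguous index range.
assert batch_by_range(2, 5) == batch_per_item([2, 3, 4])
```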

@codecov-commenter commented Aug 6, 2022

Codecov Report

Merging #1869 (76364ff) into master (bb5073f) will decrease coverage by 1.98%.
The diff coverage is 68.26%.

@@             Coverage Diff              @@
##             master    #1869      +/-   ##
============================================
- Coverage     72.08%   70.10%   -1.99%     
- Complexity     5126     5867     +741     
============================================
  Files           473      576     +103     
  Lines         21970    26024    +4054     
  Branches       2351     2810     +459     
============================================
+ Hits          15838    18245    +2407     
- Misses         4925     6406    +1481     
- Partials       1207     1373     +166     
Impacted Files Coverage Δ
api/src/main/java/ai/djl/modality/cv/Image.java 69.23% <ø> (-4.11%) ⬇️
...rc/main/java/ai/djl/modality/cv/MultiBoxPrior.java 76.00% <ø> (ø)
...rc/main/java/ai/djl/modality/cv/output/Joints.java 71.42% <ø> (ø)
.../main/java/ai/djl/modality/cv/output/Landmark.java 100.00% <ø> (ø)
...main/java/ai/djl/modality/cv/output/Rectangle.java 72.41% <0.00%> (ø)
...i/djl/modality/cv/translator/BigGANTranslator.java 21.42% <0.00%> (-5.24%) ⬇️
...odality/cv/translator/BigGANTranslatorFactory.java 33.33% <0.00%> (+8.33%) ⬆️
...nslator/InstanceSegmentationTranslatorFactory.java 14.28% <0.00%> (-3.90%) ⬇️
.../cv/translator/SemanticSegmentationTranslator.java 0.00% <0.00%> (ø)
.../cv/translator/StyleTransferTranslatorFactory.java 40.00% <ø> (ø)
... and 479 more


@patins1 (Contributor, Author) commented Aug 7, 2022

The problem with MxNet is fixed by using the "take" operation instead of "pick".
TrainMnist now executes 4 times faster for MxNet.
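For readers unfamiliar with the two operations, here is a minimal plain-Python sketch of the semantic difference between MXNet-style "take" and "pick" (illustrative helpers on lists, not the MXNet or DJL API):

```python
matrix = [[10, 11, 12],
          [20, 21, 22],
          [30, 31, 32]]

def take(data, indices):
    """take-style: select whole rows along axis 0; one row per index."""
    return [data[i] for i in indices]

def pick(data, indices):
    """pick-style: select one element per row; indices[i] picks from row i."""
    return [row[j] for row, j in zip(data, indices)]

assert take(matrix, [2, 0]) == [[30, 31, 32], [10, 11, 12]]
assert pick(matrix, [0, 2, 1]) == [10, 22, 31]
```

Batch creation needs take-style row selection, which is why pick caused trouble here.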

@lanking520 (Contributor) commented:

Thanks for your contribution. Since all engines are fixed, we will take a look. So far LGTM.

@lanking520 (Contributor) commented:

@KexinFeng can you help check the NDIndex part of the changes?

@KexinFeng KexinFeng self-requested a review August 12, 2022 21:44
@KexinFeng (Contributor) commented:

@patins1 It looks like BulkDataIterable is not covered by any unit test. You mentioned that using BulkDataIterable is more efficient. Could you add a unit test that covers this class you added?

@KexinFeng (Contributor) commented Aug 25, 2022:

> The problem with MxNet is fixed by using the "take" operation instead of "pick".
> TrainMnist now executes 4 times faster for MxNet.

Could you integrate it with the existing take function in the PR?

I tried to use take for the MXNet engine too, see here, and the PR tests all passed. But it caused the issue. In your version, setting fullPick.setIndexTake(true); will trigger this problem again.

@KexinFeng (Contributor) left a review comment:

If you could add a unit test that covers BulkDataIterable and shows the necessity of using take on the MXNet engine, it'd be great.

params.add("mode", "wrap");
return manager.invoke("pick", new NDList(array, fullPick.getIndices()), params)
return manager.invoke(
Inline review comment on the diff above:

We have already implemented "take" for the MXNet engine in a PR. Could you integrate this feature with that?

Another problem is that we have been consistently distinguishing between "take" and "pick". See api/src/main/java/ai/djl/ndarray/index/dim. Here, the usage of "take" embedded in fullPick is confusing. Could you separate "take" from fullPick?

Update:
I looked into this PR and found that currently on the MXNet engine, if using the getter, the ndarray indices are still defined to be "pick" indices. This is cumbersome, but only for the purpose of backward compatibility. However, a take API already exists, as mentioned above. In your implementation, is it possible to utilize that API instead of creating a mixture of "take" and "pick" here?

@@ -72,6 +73,7 @@ public static Optional<NDIndexFullPick> fromIndex(NDIndex index, Shape target) {
"Only rank-1 indexing array is supported for pick");
}
fullPick = new NDIndexFullPick(indexElem, axis);
fullPick.setIndexTake(true);
Inline review comment:

Here, I tried to use take for the MXNet engine too, see #1802, and the PR tests all passed. But it caused #1800. In your version, setting fullPick.setIndexTake(true); will trigger this problem again.

@patins1 (Contributor, Author) replied:

NDArray.take(NDArray) returns a shape that is not useful for the functionality of this pull request:

> Returns a partial NDArray pointed by index according to linear indexing, and the output is of the same shape as index.

Given an ndArray of shape e.g. (a,b,c,d) and an index of shape (e), I initially assumed that ndArray.take(index) would return an array of shape (e,b,c,d). Therefore I restyled my initial implementation, using NDIndexFullTake. If anybody has a better idea, let me know; at least the test cases go green now.
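The shape mismatch described above can be sketched in plain Python (illustrative helpers, not the DJL API): a linear-indexing take flattens the array and returns an output shaped like the index, whereas batch creation needs an axis-0 gather that keeps the trailing dimensions:

```python
# A nested list standing in for an ndarray of shape (3, 2, 2).
nested = [[[1, 2], [3, 4]],
          [[5, 6], [7, 8]],
          [[9, 10], [11, 12]]]

def take_linear(data, index):
    """NDArray.take-style: linear indexing; output shaped like index."""
    flat = [v for block in data for row in block for v in row]
    return [flat[i] for i in index]

def gather_axis0(data, index):
    """Axis-0 gather: output shape (len(index),) + trailing dims."""
    return [data[i] for i in index]

assert take_linear(nested, [0, 5]) == [1, 6]            # scalars, shape (2,)
assert gather_axis0(nested, [2, 0]) == [nested[2], nested[0]]  # shape (2, 2, 2)
```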

@KexinFeng KexinFeng self-assigned this Aug 26, 2022
@KexinFeng KexinFeng force-pushed the IndicesBasedSubDataset branch from 78ae336 to 75d4249 Compare August 27, 2022 01:35
@KexinFeng (Contributor) commented:

@patins1 Thanks for dealing with this issue in such a timely manner! But it looks like the unit test that covers the new file api/src/main/java/ai/djl/training/dataset/BulkDataIterable.java is still missing. You mentioned you have tested it and that it shows an efficiency increase. Could you add it to the proper unit test file?

Also you mentioned:

> The problem with MxNet is fixed by using the "take" operation instead of "pick".
> TrainMnist now executes 4 times faster for MxNet.

It'd be better to have a unit test for this too, which will prevent future edits from breaking this fix.

Thanks!

@patins1 (Contributor, Author) commented Aug 28, 2022:

Tests added

@KexinFeng KexinFeng self-requested a review August 28, 2022 07:55
@KexinFeng (Contributor) commented Aug 28, 2022:

@siddvenk Hi Siddarth, in this PR we have changed the definition of get(NDArray index) from pick to take, and have added a warning. This will affect the result you mentioned in #1800. See the test in commit "add index test and clean code" 8ffac00. Pick can still be used, though, via addPickDim().
This change makes the behavior consistent with numpy and the PyTorch engine; it is also more efficient, as shown here.
Should we update the relevant part of the Dive into Deep Learning for Java book mentioned in #1800?

@KexinFeng KexinFeng self-requested a review August 30, 2022 18:22
@zachgk (Contributor) left a review comment:

There is one small problem I noticed, but otherwise it looks good. I think we should just merge it for now and fix it in a later PR.

* @param indices indices of the requested data items
* @return a {@link Record} that contains the data and label of the requested data items
*/
public Record getByIndices(NDManager manager, long... indices) {
Inline review comment:

One convention we have in DJL to differentiate NDArrays is that a Record should correspond to only a single index while a Batch covers multiple indices. So this method should probably return a Batch rather than a Record.
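The naming convention described above can be sketched in plain Python (illustrative dataclasses only; these mirror the discussion but are not the actual DJL `Record`/`Batch` types, and the accessor names are assumptions):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Record:                 # one data/label pair: a single index
    data: List[float]
    label: float

@dataclass
class Batch:                  # many data/label pairs: multiple indices
    data: List[List[float]]
    labels: List[float]

def get_by_index(features, labels, i) -> Record:
    """Single index -> Record."""
    return Record(features[i], labels[i])

def get_by_indices(features, labels, indices) -> Batch:
    """Multiple indices -> Batch, per the convention above."""
    return Batch([features[i] for i in indices],
                 [labels[i] for i in indices])
```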

@zachgk zachgk merged commit e9ae8d3 into deepjavalibrary:master Sep 1, 2022
zachgk added a commit to zachgk/djl that referenced this pull request Sep 6, 2022
zachgk added a commit that referenced this pull request Sep 6, 2022
@KexinFeng KexinFeng mentioned this pull request Sep 16, 2022
5 participants