Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pytorch] Advanced indexing that supports all indexing features on PyTorch #1719

Merged
merged 15 commits into from
Jun 28, 2022

Conversation

KexinFeng
Copy link
Contributor

@KexinFeng KexinFeng commented Jun 16, 2022

Description

Full support of pytorch indexing.

This PR enables all indexing features in pytorch. The indexing behaviour is guaranteed to be consistent, especially when mixed indices are passed.

Demo example

See the following code for demo

// get from integer array (higher rank included) or float array
original = manager.arange(1, 7f).reshape(-1, 2);
NDArray index = manager.create(new long[] {0, 0, 1, 2}, new Shape(2, 2));
NDArray indexFloat = manager.create(new float[] {0, 0, 1, 2}, new Shape(2, 2));
NDArray actual = original.get(index);
NDArray actual2 = original.get(indexFloat);
expected = manager.create(new float[] {1, 2, 1, 2, 3, 4, 5, 6}, new Shape(2, 2, 2));
Assert.assertEquals(actual, expected);
Assert.assertEquals(actual2, expected);
// indexing with boolean, slice, and integer array (higher rank included) or float array
original = manager.arange(3 * 3 * 3 * 3).reshape(3, 3, 3, 3);
NDArray bool1 = manager.create(new boolean[] {true, false, true});
NDArray index1 = manager.create(new long[] {2, 2}, new Shape(1, 2));
NDArray index2 = manager.create(new float[] {0, 1}, new Shape(1, 2));
actual = original.get(":{}, {}, {}, {}", 2, index1, bool1, index2);
expected = manager.create(new int[] {18, 25, 45, 52}, new Shape(2, 1, 2));
Assert.assertEquals(actual, expected);
// indexing with null, slice and integer array (higher rank included) or float array
original = manager.arange(3 * 3 * 3).reshape(3, 3, 3);
index1 = manager.create(new float[] {0, 1}, new Shape(2));
index2 = manager.create(new long[] {0, 0, 2, 1}, new Shape(2, 2));
actual = original.get(":{}, {}, {}, {}", 2, index1, index2, null);
expected = manager.create(new int[] {0, 3, 2, 4, 9, 12, 11, 13}, new Shape(2, 2, 2, 1));
Assert.assertEquals(actual, expected);

The pytorch doc:

C indexing API
source

The relevant preceeding PRs:

Support of take from pytorch #1627
Add support of take on MXNet engine #1649

@KexinFeng KexinFeng changed the title Advanced indexing getter on pytorch Advanced indexing that supports all indexing features on PyTorch Jun 16, 2022
@KexinFeng KexinFeng force-pushed the torch_index branch 2 times, most recently from 07a4c27 to 82d282a Compare June 20, 2022 03:44
@KexinFeng KexinFeng marked this pull request as ready for review June 21, 2022 23:22
@KexinFeng KexinFeng requested a review from zachgk as a code owner June 21, 2022 23:22
@KexinFeng KexinFeng force-pushed the torch_index branch 2 times, most recently from 33a78d1 to 4a0196a Compare June 24, 2022 01:21
@codecov-commenter
Copy link

codecov-commenter commented Jun 25, 2022

Codecov Report

Merging #1719 (3f0c4e8) into master (bb5073f) will decrease coverage by 1.46%.
The diff coverage is 70.41%.

@@             Coverage Diff              @@
##             master    #1719      +/-   ##
============================================
- Coverage     72.08%   70.62%   -1.47%     
- Complexity     5126     5562     +436     
============================================
  Files           473      527      +54     
  Lines         21970    24691    +2721     
  Branches       2351     2680     +329     
============================================
+ Hits          15838    17438    +1600     
- Misses         4925     5922     +997     
- Partials       1207     1331     +124     
Impacted Files Coverage Δ
api/src/main/java/ai/djl/modality/cv/Image.java 69.23% <ø> (-4.11%) ⬇️
...rc/main/java/ai/djl/modality/cv/MultiBoxPrior.java 76.00% <ø> (ø)
...rc/main/java/ai/djl/modality/cv/output/Joints.java 71.42% <ø> (ø)
.../main/java/ai/djl/modality/cv/output/Landmark.java 100.00% <ø> (ø)
...main/java/ai/djl/modality/cv/output/Rectangle.java 72.41% <0.00%> (ø)
...i/djl/modality/cv/translator/BigGANTranslator.java 21.42% <ø> (-5.24%) ⬇️
...odality/cv/translator/BigGANTranslatorFactory.java 33.33% <0.00%> (+8.33%) ⬆️
...nslator/InstanceSegmentationTranslatorFactory.java 14.28% <0.00%> (-3.90%) ⬇️
.../cv/translator/StyleTransferTranslatorFactory.java 40.00% <ø> (ø)
.../ai/djl/modality/cv/translator/YoloTranslator.java 8.33% <0.00%> (-0.50%) ⬇️
... and 410 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a53acb2...3f0c4e8. Read the comment docs.

Change-Id: I5a7287719a8deedbfefa4181dc79e72d78410d49
@KexinFeng KexinFeng merged commit e69f23f into deepjavalibrary:master Jun 28, 2022
@KexinFeng KexinFeng changed the title Advanced indexing that supports all indexing features on PyTorch [pytorch] Advanced indexing that supports all indexing features on PyTorch Jun 30, 2022
KexinFeng added a commit that referenced this pull request Jul 6, 2022
…1755)

Add tensor setter feature with advanced indexing on PyTorch engine. See PR #1719
@KexinFeng KexinFeng deleted the torch_index branch August 25, 2022 00:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants