
Speed up Reduce operators for consecutive reduced axes #7206

Merged: 33 commits merged into microsoft:master from the red branch on May 5, 2021

Conversation

@xadupre (Member) commented Apr 1, 2021

Description:
This change only improves the Reduce operators when the reduced axes are contiguous: with len(shape) == 4, any single axis is covered, and so are axes=(0, 1), (1, 2) or (2, 3); axes=(0, 2) is not covered by this change and the former implementation prevails. When the axes are contiguous, the shape can be compressed into three cases (K = axis not reduced, R = reduced axis):

  • KR - reduction on the last dimensions
  • RK - reduction on the first dimensions
  • KRK - reduction on the middle dimensions.

For these three configurations, the reduction can be optimized with vector operations, as sketched below.
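As an illustration only, here is a minimal Python sketch of the classification and shape compression (the helper `compress_reduce` is hypothetical and not the C++ code added by this PR):

```python
import numpy as np

def compress_reduce(shape, axes):
    """Classify a reduction with contiguous axes as KR, RK or KRK
    and return the compressed shape (illustration only)."""
    axes = sorted(axes)
    # Only contiguous reduced axes are handled by the fast path.
    if any(b - a != 1 for a, b in zip(axes, axes[1:])):
        return None  # e.g. axes=(0, 2) on a 4-D tensor: former implementation prevails
    k_left = int(np.prod(shape[:axes[0]], dtype=np.int64))        # kept dims before R
    r = int(np.prod([shape[a] for a in axes], dtype=np.int64))    # reduced dims
    k_right = int(np.prod(shape[axes[-1] + 1:], dtype=np.int64))  # kept dims after R
    if k_left == 1:
        return "RK", (r, k_right)
    if k_right == 1:
        return "KR", (k_left, r)
    return "KRK", (k_left, r, k_right)

# Examples matching the description:
# compress_reduce((8, 24, 48, 8), (1, 2)) -> ("KRK", (8, 1152, 8))
# compress_reduce((8, 24, 48, 8), (3,))   -> ("KR",  (9216, 8))
# compress_reduce((8, 24, 48, 8), (0,))   -> ("RK",  (8, 9216))
```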

One example uses configuration KRK. The graph below shows the ratio between an implementation and the numpy implementation (numpy time / implementation time), i.e. the speed-up compared to numpy.
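For reference, the ratio can be measured along the following lines (a sketch only, assuming the onnx and onnxruntime Python packages; the shape, axes and repeat count are placeholders, not the exact benchmark script behind the tables below):

```python
import time
import numpy as np
from onnx import TensorProto, helper
import onnxruntime as ort

# Placeholder configuration; the tables below sweep many (shape, axes) pairs.
shape, axes = (8, 24, 48, 128), (1, 2)

# ReduceSum with axes as an attribute (opset 11; axes became an input in opset 13).
node = helper.make_node("ReduceSum", ["X"], ["Y"], axes=list(axes), keepdims=0)
graph = helper.make_graph(
    [node], "reduce_bench",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, list(shape))],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])

sess = ort.InferenceSession(model.SerializeToString(),
                            providers=["CPUExecutionProvider"])
x = np.random.rand(*shape).astype(np.float32)

def avg_time(fn, repeat=100):
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    return (time.perf_counter() - start) / repeat

t_np = avg_time(lambda: x.sum(axis=axes))
t_ort = avg_time(lambda: sess.run(None, {"X": x}))
print("speed-up vs numpy:", t_np / t_ort)
```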

Motivation and Context
ReduceSum is much slower than TensorFlow on CPU for some configurations. This change makes it almost as fast, or faster.

Current status on Speed Up

Case KRK is faster on 2 and 4 cores; cases KR and RK are neither faster nor slower with 4 cores. The current implementation relies on Eigen and is parallelized; the new one may not be.
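To illustrate why the KRK case benefits from vector operations (a toy numpy sketch, not the PR's C++ kernel): once the shape is compressed to (K1, R, K2), every output row is the sum of R contiguous rows of length K2, so each accumulation step is a single vector add.

```python
import numpy as np

def reduce_sum_krk(x, k1, r, k2):
    """Toy KRK reduction: accumulate R contiguous rows of length K2 per output row."""
    v = x.reshape(k1, r, k2)
    out = np.zeros((k1, k2), dtype=x.dtype)
    for i in range(k1):
        for j in range(r):
            out[i] += v[i, j]  # one vectorized add over K2 contiguous elements
    return out

x = np.random.rand(8, 24, 48, 16).astype(np.float32)
assert np.allclose(reduce_sum_krk(x, 8, 24 * 48, 16), x.sum(axis=(1, 2)), rtol=1e-4)
```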

On Intel(R) Core(TM) i7-6500U CPU @ 2.50GHz 2.60 GHz, 2 cores.

(graph: speed-up over numpy for configuration KRK)

Speed-up ratios (numpy time / onnxruntime time); N is the size of the last dimension.

fct axes N shape ReduceMax ReduceMean ReduceSum
ort (0,) 8 (8, 24, 48, 8) 24.12 3.72 2.39
ort (0,) 16 (8, 24, 48, 16) 13.31 2.02 1.72
ort (0,) 32 (8, 24, 48, 32) 7.98 3.14 1.73
ort (0,) 64 (8, 24, 48, 64) 8.93 2.05 1.45
ort (0,) 100 (8, 24, 48, 100) 7.46 1.96 1.70
ort (0,) 128 (8, 24, 48, 128) 10.06 2.29 2.34
ort (0,) 200 (8, 24, 48, 200) 8.76 1.76 1.97
ort (0,) 256 (8, 24, 48, 256) 6.54 1.89 1.41
ort (0,) 400 (8, 24, 48, 400) 6.04 2.21 1.50
ort (0,) 512 (8, 24, 48, 512) 6.47 1.65 1.52
ort (0,) 1024 (8, 24, 48, 1024) 6.20 1.64 1.46
ort (1, 2) 8 (8, 24, 48, 8) 4.24 6.39 5.10
ort (1, 2) 16 (8, 24, 48, 16) 2.62 4.64 3.32
ort (1, 2) 32 (8, 24, 48, 32) 3.61 4.11 3.46
ort (1, 2) 64 (8, 24, 48, 64) 1.83 6.99 5.13
ort (1, 2) 100 (8, 24, 48, 100) 2.17 6.05 4.55
ort (1, 2) 128 (8, 24, 48, 128) 3.01 9.59 6.13
ort (1, 2) 200 (8, 24, 48, 200) 2.53 4.48 4.20
ort (1, 2) 256 (8, 24, 48, 256) 2.73 8.00 6.63
ort (1, 2) 400 (8, 24, 48, 400) 2.49 4.98 3.96
ort (1, 2) 512 (8, 24, 48, 512) 2.21 10.03 9.18
ort (1, 2) 1024 (8, 24, 48, 1024) 3.55 23.55 18.84
ort (1,) 8 (8, 1152, 8) 4.50 3.16 4.84
ort (1,) 16 (8, 1152, 16) 2.94 3.67 2.75
ort (1,) 32 (8, 1152, 32) 3.43 3.93 3.60
ort (1,) 64 (8, 1152, 64) 3.37 4.57 5.27
ort (1,) 100 (8, 1152, 100) 2.26 5.24 3.74
ort (1,) 128 (8, 1152, 128) 2.89 6.85 6.93
ort (1,) 200 (8, 1152, 200) 2.48 4.31 3.96
ort (1,) 256 (8, 1152, 256) 2.65 6.74 6.89
ort (1,) 400 (8, 1152, 400) 2.71 4.18 3.52
ort (1,) 512 (8, 1152, 512) 2.81 9.59 8.99
ort (1,) 1024 (8, 1152, 1024) 2.91 25.48 23.39
ort (2, 3) 8 (2, 8, 12, 24, 2, 8) 4.99 5.66 4.52
ort (2, 3) 16 (2, 8, 12, 24, 2, 16) 2.51 2.29 3.36
ort (2, 3) 32 (2, 8, 12, 24, 2, 32) 2.46 4.04 3.09
ort (2, 3) 64 (2, 8, 12, 24, 2, 64) 2.09 4.77 3.43
ort (2, 3) 100 (2, 8, 12, 24, 2, 100) 2.64 4.40 3.95
ort (2, 3) 128 (2, 8, 12, 24, 2, 128) 3.27 6.27 3.78
ort (2, 3) 200 (2, 8, 12, 24, 2, 200) 2.74 3.83 4.10
ort (2, 3) 256 (2, 8, 12, 24, 2, 256) 2.78 5.95 6.13
ort (2, 3) 400 (2, 8, 12, 24, 2, 400) 2.40 4.60 3.72
ort (2, 3) 512 (2, 8, 12, 24, 2, 512) 2.18 5.91 5.66
ort (2, 3) 1024 (2, 8, 12, 24, 2, 1024) 2.20 7.25 6.20
ort (3,) 8 (8, 24, 48, 8) 5.59 2.02 2.55
ort (3,) 16 (8, 24, 48, 16) 2.73 1.49 1.00
ort (3,) 32 (8, 24, 48, 32) 3.52 1.85 1.81
ort (3,) 64 (8, 24, 48, 64) 5.12 2.46 1.77
ort (3,) 100 (8, 24, 48, 100) 4.13 7.53 3.10
ort (3,) 128 (8, 24, 48, 128) 4.43 3.21 4.15
ort (3,) 200 (8, 24, 48, 200) 4.18 3.85 3.46
ort (3,) 256 (8, 24, 48, 256) 4.08 4.03 3.60
ort (3,) 400 (8, 24, 48, 400) 3.82 3.39 3.56
ort (3,) 512 (8, 24, 48, 512) 3.99 4.45 3.60
ort (3,) 1024 (8, 24, 48, 1024) 3.46 4.74 4.04

On Intel(R) Core(TM) i7-8650U CPU @ 1.90GHz, 4 cores.

fct axes N shape ReduceMax ReduceMean ReduceSum
ort (0,) 8 (8, 24, 48, 8) 9.72 2.71 1.11
ort (0,) 16 (8, 24, 48, 16) 3.39 1.34 1.89
ort (0,) 32 (8, 24, 48, 32) 2.98 1.10 1.61
ort (0,) 64 (8, 24, 48, 64) 3.09 1.18 1.55
ort (0,) 100 (8, 24, 48, 100) 3.27 0.95 1.07
ort (0,) 128 (8, 24, 48, 128) 2.46 1.32 1.66
ort (0,) 200 (8, 24, 48, 200) 2.39 0.85 1.09
ort (0,) 256 (8, 24, 48, 256) 1.85 1.08 1.43
ort (0,) 400 (8, 24, 48, 400) 2.22 1.01 1.21
ort (0,) 512 (8, 24, 48, 512) 1.87 1.34 1.40
ort (0,) 1024 (8, 24, 48, 1024) 1.93 1.03 1.25
ort (0,) 2048 (8, 24, 48, 2048) 1.81 0.93 1.16
ort (0,) 2572 (8, 24, 48, 2572) 2.42 0.76 0.94
ort (1, 2) 8 (8, 24, 48, 8) 4.69 2.71 4.90
ort (1, 2) 16 (8, 24, 48, 16) 2.34 2.33 1.46
ort (1, 2) 32 (8, 24, 48, 32) 2.09 2.05 2.59
ort (1, 2) 64 (8, 24, 48, 64) 1.97 1.85 2.35
ort (1, 2) 100 (8, 24, 48, 100) 1.70 2.24 1.76
ort (1, 2) 128 (8, 24, 48, 128) 1.97 2.67 2.95
ort (1, 2) 200 (8, 24, 48, 200) 1.69 1.77 1.72
ort (1, 2) 256 (8, 24, 48, 256) 2.35 2.73 2.66
ort (1, 2) 400 (8, 24, 48, 400) 1.50 1.64 1.71
ort (1, 2) 512 (8, 24, 48, 512) 2.03 3.86 4.64
ort (1, 2) 1024 (8, 24, 48, 1024) 1.22 4.56 4.85
ort (1, 2) 2048 (8, 24, 48, 2048) 2.77 22.89 26.05
ort (1, 2) 2572 (8, 24, 48, 2572) 1.24 1.99 1.95
ort (1,) 8 (8, 1152, 8) 6.04 3.40 3.08
ort (1,) 16 (8, 1152, 16) 3.06 2.45 2.17
ort (1,) 32 (8, 1152, 32) 1.92 1.98 1.93
ort (1,) 64 (8, 1152, 64) 1.69 2.17 2.39
ort (1,) 100 (8, 1152, 100) 1.54 1.90 2.06
ort (1,) 128 (8, 1152, 128) 2.36 2.57 2.81
ort (1,) 200 (8, 1152, 200) 1.60 1.67 1.81
ort (1,) 256 (8, 1152, 256) 2.12 2.55 2.74
ort (1,) 400 (8, 1152, 400) 1.39 1.63 1.66
ort (1,) 512 (8, 1152, 512) 1.95 3.73 3.94
ort (1,) 1024 (8, 1152, 1024) 1.13 5.39 5.40
ort (1,) 2048 (8, 1152, 2048) 2.91 23.02 18.71
ort (1,) 2572 (8, 1152, 2572) 1.30 2.03 2.00
ort (2, 3) 8 (2, 8, 12, 24, 2, 8) 5.09 2.85 2.78
ort (2, 3) 16 (2, 8, 12, 24, 2, 16) 2.20 1.44 2.21
ort (2, 3) 32 (2, 8, 12, 24, 2, 32) 2.58 1.68 1.82
ort (2, 3) 64 (2, 8, 12, 24, 2, 64) 2.11 1.82 1.69
ort (2, 3) 100 (2, 8, 12, 24, 2, 100) 2.32 1.53 1.63
ort (2, 3) 128 (2, 8, 12, 24, 2, 128) 2.13 1.41 1.57
ort (2, 3) 200 (2, 8, 12, 24, 2, 200) 1.57 1.67 1.55
ort (2, 3) 256 (2, 8, 12, 24, 2, 256) 2.59 2.56 2.49
ort (2, 3) 400 (2, 8, 12, 24, 2, 400) 1.22 1.92 1.59
ort (2, 3) 512 (2, 8, 12, 24, 2, 512) 1.50 2.24 2.27
ort (2, 3) 1024 (2, 8, 12, 24, 2, 1024) 1.31 2.33 2.55
ort (2, 3) 2048 (2, 8, 12, 24, 2, 2048) 1.36 2.92 3.32
ort (2, 3) 2572 (2, 8, 12, 24, 2, 2572) 1.14 2.01 1.89
ort (3,) 8 (8, 24, 48, 8) 4.00 1.71 1.66
ort (3,) 16 (8, 24, 48, 16) 1.18 1.08 1.24
ort (3,) 32 (8, 24, 48, 32) 1.81 1.01 1.22
ort (3,) 64 (8, 24, 48, 64) 1.44 1.09 1.32
ort (3,) 100 (8, 24, 48, 100) 1.25 1.19 1.36
ort (3,) 128 (8, 24, 48, 128) 1.56 1.17 1.21
ort (3,) 200 (8, 24, 48, 200) 0.94 1.09 1.09
ort (3,) 256 (8, 24, 48, 256) 1.07 1.17 1.14
ort (3,) 400 (8, 24, 48, 400) 0.91 1.19 1.16
ort (3,) 512 (8, 24, 48, 512) 1.46 1.29 1.28
ort (3,) 1024 (8, 24, 48, 1024) 1.02 1.33 1.29
ort (3,) 2048 (8, 24, 48, 2048) 0.96 1.33 1.30
ort (3,) 2572 (8, 24, 48, 2572) 1.00 1.27 1.28

@xadupre xadupre requested a review from a team as a code owner April 1, 2021 09:43
@fdwr (Contributor) commented Apr 2, 2021

Nice boost. (The DML EP also does this inside DirectML.dll, flattening any adjacent axes that are stride-contiguous.)

@xadupre (Member, Author) commented Apr 6, 2021

> flattening any adjacent axes that are stride contiguous

I did not know that. It seemed a good idea to reduce the cost of moving to the next element to sum. I made a modification to parallelize the KR case (reduction over the last dimensions); it is faster. I'm working on the last one, RK (reduction over the first dimensions), still parallelizing.
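As a toy illustration of why the KR case parallelizes well (the actual change parallelizes the C++ kernel; the helper below is hypothetical): after compressing the shape to (K, R), every output element is an independent row reduction, so rows can be split across threads.

```python
from concurrent.futures import ThreadPoolExecutor
import numpy as np

def reduce_sum_kr_parallel(x, k, r, num_threads=4):
    """Toy parallel KR reduction: each chunk of rows is reduced independently."""
    v = x.reshape(k, r)
    out = np.empty(k, dtype=x.dtype)

    def work(i0, i1):
        out[i0:i1] = v[i0:i1].sum(axis=1)  # independent row sums

    chunk = (k + num_threads - 1) // num_threads
    with ThreadPoolExecutor(num_threads) as ex:
        for i in range(0, k, chunk):
            ex.submit(work, i, min(i + chunk, k))
    return out  # the with-block waits for all submitted chunks
```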

@xadupre xadupre changed the title [WIP] Speed up Reduce operators for consecutive reduced axes Speed up Reduce operators for consecutive reduced axes Apr 7, 2021
@xadupre (Member, Author) commented Apr 20, 2021

> There's a bit of inconsistency with this class now, and probably an excessive contribution to binary size due to its usage.

About the binary size, should I exclude this from the minimal build? For the rest, I'll think of a better design.

@skottmckay (Contributor) commented:

How much growth is there from the change? Easiest way to compare is https://osgwiki.com/wiki/SizeBench with before/after RelWithDebInfo builds.



@xadupre (Member, Author) commented Apr 26, 2021

> There's a bit of inconsistency with this class now, and probably an excessive contribution to binary size due to its usage.

I started to refactor this optimization and the previous code to reduce the binary size. I'll measure it.

@xadupre (Member, Author) commented Apr 27, 2021

> How much growth is there from the change? Easiest way to compare is https://osgwiki.com/wiki/SizeBench with before/after RelWithDebInfo builds.

Here is the current status:

(screenshot: binary size comparison)

@skottmckay (Contributor) commented:

What's the before/after? Is it master vs latest changes providing an overall reduction of ~20KB despite adding the new fast reduce logic?



@skottmckay (Contributor) left a review comment:

:shipit:

@xadupre (Member, Author) commented May 4, 2021

> What's the before/after? Is it master vs latest changes providing an overall reduction of ~20KB despite adding the new fast reduce logic?


Yes, it means the gain would be even bigger without the new logic. I did not exclude it from the minimal build, but I can if you think it is worth doing.

@xadupre xadupre merged commit ade6ed5 into microsoft:master May 5, 2021
@xadupre xadupre deleted the red branch May 7, 2021 15:23
xadupre added a commit that referenced this pull request on May 24, 2021 (…#7719):

* improves ArgMin implementation
* update parallelization cost
* choose former implementation for KRK case, when K=1
* improves unit test