
Make byte[] vector comparisons faster! (if possible) #12621

Closed
benwtrent opened this issue Oct 4, 2023 · 11 comments

Comments

@benwtrent
Member

Description

While testing and digging around, I noticed that our float comparisons are much faster than our byte comparisons on my MacBook (M1), and roughly the same speed on a GCP Intel Sapphire Rapids CPU.

This seems counter-intuitive to me: I would expect Panama to be able to do more byte operations per cycle than float operations. My guess is that the intrinsics are weird, or that Panama Vector just doesn't support or detect the required operations?

Here are two benchmark results using @rmuir's helpful vectorbench project:

MacBook (Apple Silicon [128-bit NEON], JDK 21):

FloatDotProductBenchmark.dotProductNew     768  thrpt    5  21.781 ± 0.254  ops/us
FloatDotProductBenchmark.dotProductNew    1024  thrpt    5  15.091 ± 0.217  ops/us
BinaryDotProductBenchmark.dotProductNew     768  thrpt    5  8.041 ± 0.108  ops/us
BinaryDotProductBenchmark.dotProductNew    1024  thrpt    5  6.085 ± 0.133  ops/us

GCP (Intel Sapphire Rapids [AVX-512], JDK 21):

FloatDotProductBenchmark.dotProductNew     768  thrpt    5  20.169 ± 0.385  ops/us
FloatDotProductBenchmark.dotProductNew    1024  thrpt    5  18.334 ± 0.180  ops/us
BinaryDotProductBenchmark.dotProductNew     768  thrpt    5  19.686 ± 0.050  ops/us
BinaryDotProductBenchmark.dotProductNew    1024  thrpt    5  14.934 ± 0.014  ops/us
CPU flags:
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid cldemote movdiri movdir64b fsrm md_clear serialize arch_capabilities
@rmuir
Member

rmuir commented Oct 4, 2023

the type conversions are what make it slow. for the float case it is the equivalent of:

float x = something;
float y = something;
float z = something;
// no conversions
float result = x + y * z;

for the binary case it is the equivalent of:

byte a = something;
byte b = something;
byte c = something;
// multiply b * c, widening to short to avoid overflow
short z1 = (short) (b * c);
// add a + z1, widening to int to avoid overflow
int z = (int) a + (int) z1;
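
In Panama Vector API terms, that widening chain looks roughly like this (a minimal sketch assuming 256-bit vectors are available; class and method names here are made up, and Lucene's real code differs):

import jdk.incubator.vector.*;

// Sketch only (requires --add-modules jdk.incubator.vector): a byte dot
// product must widen byte -> short before multiplying and short -> int
// before accumulating, burning instructions on conversions.
class ByteDotProductSketch {
  static int dotProduct(byte[] a, byte[] b) {
    IntVector acc = IntVector.zero(IntVector.SPECIES_256);
    int i = 0;
    int bound = ByteVector.SPECIES_64.loopBound(a.length);
    for (; i < bound; i += ByteVector.SPECIES_64.length()) {
      ByteVector va = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
      ByteVector vb = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
      // byte -> short so the multiply cannot overflow
      Vector<Short> va16 = va.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
      Vector<Short> vb16 = vb.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
      // short -> int so the accumulation cannot overflow
      Vector<Integer> prod32 =
          va16.mul(vb16).convertShape(VectorOperators.S2I, IntVector.SPECIES_256, 0);
      acc = acc.add(prod32);
    }
    int sum = acc.reduceLanes(VectorOperators.ADD);
    for (; i < a.length; i++) {
      sum += a[i] * b[i]; // scalar tail
    }
    return sum;
  }
}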

You can understand the limitations at the hardware level better by reading https://en.wikichip.org/wiki/x86/avx512_vnni
VPDPBUSD is currently not supported by openjdk either: https://bugs.openjdk.org/browse/JDK-8215891

@rmuir
Member

rmuir commented Oct 4, 2023

Also, their suggested 3-instruction replacement for VPDPBUSD is:

Likewise, for 8-bit values, three instructions are needed - VPMADDUBSW which is used to multiply two 8-bit pairs and add them together, followed by a VPMADDWD with the value 1 in order to simply up-convert the 16-bit values to 32-bit values, followed by the VPADDD instruction which adds the result to an accumulator.

I can tell you this is also not what is happening. We have no ability to write AVX-512-specific code, and we currently have to support ARM, machines with only 256-bit AVX, etc.

@rmuir
Member

rmuir commented Oct 4, 2023

As far as ARM goes, the fact that it has only 128-bit SIMD is the limiting factor.

For e.g. 256-bit AVX, we use a 64-bit vector of 8 byte values -> a 128-bit vector of 8 short values -> a 256-bit vector of 8 int values.

For ARM/NEON with only 128-bit SIMD, we can't do this, as we don't have 256-bit vectors. So instead we use a 64-bit vector of 8 byte values -> a 128-bit vector of 8 short values -> two 128-bit vectors of 4 int values each. It requires splitting the vector in half, but it is all we can do (see the sketch below).
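
Concretely, the split looks roughly like this (a sketch with assumed shapes, not Lucene's exact code):

import jdk.incubator.vector.*;

// Sketch of the NEON-constrained path: the 16-bit products have to be
// split into two 128-bit int vectors (parts 0 and 1) before accumulating.
class ByteDotProduct128Sketch {
  static int dotProduct(byte[] a, byte[] b) {
    IntVector accLo = IntVector.zero(IntVector.SPECIES_128);
    IntVector accHi = IntVector.zero(IntVector.SPECIES_128);
    int i = 0;
    int bound = ByteVector.SPECIES_64.loopBound(a.length);
    for (; i < bound; i += ByteVector.SPECIES_64.length()) {
      // 64 bits of bytes -> 128 bits of shorts (8 lanes)
      Vector<Short> va16 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i)
          .convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
      Vector<Short> vb16 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i)
          .convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
      Vector<Short> prod16 = va16.mul(vb16);
      // no 256-bit shape exists, so the 8 shorts become 2 x 4 ints
      accLo = accLo.add(prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
      accHi = accHi.add(prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1));
    }
    int sum = accLo.add(accHi).reduceLanes(VectorOperators.ADD);
    for (; i < a.length; i++) {
      sum += a[i] * b[i]; // scalar tail
    }
    return sum;
  }
}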

If you want it to be faster, get an ARM chip with SVE SIMD, which has bigger vectors than NEON.

@rmuir
Member

rmuir commented Oct 4, 2023

My recommendation: stop messing around with byte and start thinking about the new 16-bit half-float support that is present in Java 21. Unfortunately, the half-float vectorization support is not yet in the openjdk master branch even for experimentation. But even my 3-year-old, crappy $200 mobile phone can do a 16-bit dot product in hardware; support is more widespread, and it would avoid these issues while still saving memory/space.

@uschindler
Contributor

Actually it is worse: Java 20 introduced conversions between short/float, but we got neither a native float16 datatype nor vector support. In short: completely unusable. 🤮
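
To make that concrete: all Java 20 added are the scalar conversions Float.floatToFloat16 / Float.float16ToFloat, and there is no float16 vector species, so the best you can write today is a scalar loop like this sketch:

// Sketch: half-float dot product using only what Java 20/21 provides.
// The inputs are float16 bit patterns stored in shorts; every element
// must be converted to float one at a time, with no vectorization.
static float dotProductFp16(short[] a, short[] b) {
  float acc = 0f;
  for (int i = 0; i < a.length; i++) {
    acc += Float.float16ToFloat(a[i]) * Float.float16ToFloat(b[i]);
  }
  return acc;
}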

@uschindler
Contributor

See openjdk/jdk#9422 (Java 20)

@rmuir
Member

rmuir commented Oct 4, 2023

Actually it is worse: Java 20 introduced conversions between short/float, but we got neither a native float16 datatype nor vector support. In short: completely unusable.

We should at least fix the field we have in the sandbox to use it and start playing with the possibilities/performance? https://github.com/apache/lucene/blob/main/lucene/sandbox/src/java/org/apache/lucene/sandbox/document/HalfFloatPoint.java

@rmuir
Member

rmuir commented Oct 8, 2023

@benwtrent I looked into this more and eked out a bit more: #12632

@benwtrent
Member Author

Thank you @rmuir && @ChrisHegarty for digging into this!

The current Panama Vector API makes doing this kind of thing frustrating. Thank y'all for wrestling with it to make us faster.

@rmuir
Member

rmuir commented Oct 9, 2023

@benwtrent I think a big source of confusion is that while the data might be byte, the related functions return 4-byte int and 4-byte float, so from a vector api perspective they are not going to be more efficient than the float variants. E.g. there's no performance benefit from the data being "smaller", as the very first thing we have to do is convert it up...
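
A quick way to see this (an illustrative snippet, assuming 256-bit hardware): the int and float species have the same lane count, so the extra byte lanes buy nothing once everything is converted up.

import jdk.incubator.vector.*;

class LaneCounts {
  public static void main(String[] args) {
    System.out.println(ByteVector.SPECIES_256.length());  // 32 byte lanes...
    System.out.println(IntVector.SPECIES_256.length());   // ...but only 8 int lanes after widening
    System.out.println(FloatVector.SPECIES_256.length()); // the same 8 lanes float gets
  }
}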

@rmuir
Member

rmuir commented Oct 14, 2023

From my analysis, the code being generated is correct. I recommend exploring half-float instead for better performance and space tradeoffs.
