diff --git a/README.md b/README.md index e56354942..42957f844 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,12 @@ Reminder of supported compilation flags: See [nimblas](https://github.com/unicredit/nimblas) for further configuration. - `-d:cuda`: Build with Cuda support - `-d:cudnn`: Build with CuDNN support, implies `cuda`. +- `-d:avx512`: Build with AVX512 support by supplying the + `-mavx512dq` flag to gcc / clang. Without this flag the + resulting binary does not use AVX512 even on CPUs that support + it. Passing this flag, however, makes the binary incompatible with + CPUs that do *not* support it. See the comments in #505 for a + discussion (from `v0.7.9`). - You might want to tune library paths in [nim.cfg](nim.cfg) after installation for OpenBLAS, MKL and Cuda compilation. The current defaults should work on Mac and Linux. diff --git a/src/arraymancer/laser/primitives/matrix_multiplication/gemm.nim b/src/arraymancer/laser/primitives/matrix_multiplication/gemm.nim index c73f50bce..cbc556632 100644 --- a/src/arraymancer/laser/primitives/matrix_multiplication/gemm.nim +++ b/src/arraymancer/laser/primitives/matrix_multiplication/gemm.nim @@ -228,25 +228,43 @@ proc gemm_strided*[T: SomeNumber and not(uint32|uint64|uint|int)]( const ukernel = cpu_features.x86_ukernel(T, false) apply(ukernel) - when defined(i386) or defined(amd64): - when T is float32: - if hasAvx512f(): dispatch(x86_AVX512) - elif hasFma3(): dispatch(x86_AVX_FMA) - elif hasAvx(): dispatch(x86_AVX) - elif hasSse(): dispatch(x86_SSE) - elif T is float64: - if hasAvx512f(): dispatch(x86_AVX512) - elif hasFma3(): dispatch(x86_AVX_FMA) - elif hasAvx(): dispatch(x86_AVX) - elif hasSse2(): dispatch(x86_SSE2) - elif T is int32: - if hasAvx512f(): dispatch(x86_AVX512) - elif hasAvx2(): dispatch(x86_AVX2) - elif hasSse41(): dispatch(x86_SSE4_1) - elif hasSse2(): dispatch(x86_SSE2) - elif T is int64: - if hasAvx512f(): dispatch(x86_AVX512) - elif hasSse2(): dispatch(x86_SSE2) + # for clarity split 
AVX512 compilation fully from regular + when defined(avx512): + when defined(i386) or defined(amd64): + when T is float32: + if hasAvx512f(): dispatch(x86_AVX512) + elif hasFma3(): dispatch(x86_AVX_FMA) + elif hasAvx(): dispatch(x86_AVX) + elif hasSse(): dispatch(x86_SSE) + elif T is float64: + if hasAvx512f(): dispatch(x86_AVX512) + elif hasFma3(): dispatch(x86_AVX_FMA) + elif hasAvx(): dispatch(x86_AVX) + elif hasSse2(): dispatch(x86_SSE2) + elif T is int32: + if hasAvx512f(): dispatch(x86_AVX512) + elif hasAvx2(): dispatch(x86_AVX2) + elif hasSse41(): dispatch(x86_SSE4_1) + elif hasSse2(): dispatch(x86_SSE2) + elif T is int64: + if hasAvx512f(): dispatch(x86_AVX512) + elif hasSse2(): dispatch(x86_SSE2) + else: + when defined(i386) or defined(amd64): + when T is float32: + if hasFma3(): dispatch(x86_AVX_FMA) + elif hasAvx(): dispatch(x86_AVX) + elif hasSse(): dispatch(x86_SSE) + elif T is float64: + if hasFma3(): dispatch(x86_AVX_FMA) + elif hasAvx(): dispatch(x86_AVX) + elif hasSse2(): dispatch(x86_SSE2) + elif T is int32: + if hasAvx2(): dispatch(x86_AVX2) + elif hasSse41(): dispatch(x86_SSE4_1) + elif hasSse2(): dispatch(x86_SSE2) + elif T is int64: + if hasSse2(): dispatch(x86_SSE2) dispatch(x86_Generic) proc gemm_strided*[T: uint32|uint64|uint|int]( diff --git a/src/arraymancer/laser/primitives/matrix_multiplication/gemm_ukernel_avx512.nim b/src/arraymancer/laser/primitives/matrix_multiplication/gemm_ukernel_avx512.nim index e4a137f57..dcea73cdd 100644 --- a/src/arraymancer/laser/primitives/matrix_multiplication/gemm_ukernel_avx512.nim +++ b/src/arraymancer/laser/primitives/matrix_multiplication/gemm_ukernel_avx512.nim @@ -9,6 +9,12 @@ import x86only() +## For the C codegen of AVX512 instructions to be valid, we need the following flag: +when defined(avx512) and (defined(gcc) or defined(clang)): + {.passC: "-mavx512dq".} +## See: https://stackoverflow.com/a/63711952 +## for a script to find the required compilation flags for specific SIMD functions. 
+ ukernel_generator( x86_AVX512, typ = float32,