Skip to content

Commit

Permalink
vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and flash a…
Browse files Browse the repository at this point in the history
…ttention
  • Loading branch information
jeffbolznv committed Dec 4, 2024
1 parent 59f4db1 commit 76fff32
Show file tree
Hide file tree
Showing 6 changed files with 1,674 additions and 104 deletions.
759 changes: 676 additions & 83 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
find_package (Threads REQUIRED)
find_package(Vulkan COMPONENTS glslc REQUIRED)

set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads)
target_link_libraries(vulkan-shaders-gen PRIVATE Vulkan::Vulkan)
305 changes: 305 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@

#include "types.comp"

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
block_q4_0_packed16 block;
};

float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1];
qs >>= shift;
qs &= 0xF;
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
block_q4_1 block;
};

float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(qs) * d + m;
return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
block_q5_0 block;
};

float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;

const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
const uint qh = ((uint_qh >> idx) << 4) & 0x10;

const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;

float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
block_q5_1 block;
};

float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const float16_t m = bl.block.m;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;

const uint uint_qh = bl.block.qh;
const uint qh = ((uint_qh >> idx) << 4) & 0x10;

const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;

float16_t ret = float16_t(qs | qh) * d + m;
return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
block_q8_0_packed16 block;
};

float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx;

// Load 16b and select the byte for this element
int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
float16_t ret = float16_t(qs) * d;
return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
block_q2_K block;
};

float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const f16vec2 d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31
const uint scalesi = iqs / 16; // 0..15
const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6

uint32_t qs = bl.block.qs[qsi];
const uint scales = bl.block.scales[scalesi];
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4);
return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
block_q3_K block;
};

float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint n = iqs / 128; // 0,1
const uint qsi = n * 32 + (iqs % 32); // 0..63
const uint hmi = (iqs % 32); // 0..31
const uint j = (iqs % 128) / 8; // 0..15
const uint is = iqs / 16; // 0..15
const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128

uint32_t scaleidx0 = (is < 8) ? is : (is-8);
uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
uint32_t scaleidx1 = is + 8 - (is/4)*4;
uint32_t scaleidx1shift = (is/4)*2;

const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));

const float16_t dl = bl.block.d * float16_t(us - 32);

float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));

return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
block_q4_K block;
};

float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint n = iqs / 64; // 0,1,2,3
const uint b = (iqs % 64) / 32; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
const uint qsi = n * 32 + (iqs % 32); // 0..127

const f16vec2 loadd = bl.block.d;

uint32_t sc;
uint32_t mbyte;

uint32_t scidx0 = (is < 4) ? is : (is + 4);
uint32_t scidx1 = (is < 4) ? is : (is - 4);
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
uint32_t mbidx0 = is + 4;
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;

sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));

const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);

uint32_t dmask = 0xF << (b * 4);

float16_t ret = d * float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) - m;

return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
block_q5_K block;
};

float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint n = iqs / 64; // 0,1,2,3
const uint b = (iqs % 64) / 32; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
const uint qsi = n * 32 + (iqs % 32); // 0..127
const uint qhi = (iqs % 32); // 0..31

const uint8_t hm = uint8_t(1 << (iqs / 32));

const f16vec2 loadd = bl.block.d;

uint32_t sc;
uint32_t mbyte;

uint32_t scidx0 = (is < 4) ? is : (is + 4);
uint32_t scidx1 = (is < 4) ? is : (is - 4);
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
uint32_t mbidx0 = is + 4;
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;

sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));

const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);

uint32_t dmask = 0xF << (b * 4);

float16_t ret = d * (float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi ] & hm) != 0 ? 16 : 0)) - m;

return ret;
}

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
block_q6_K block;
};

float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const uint idx = coordInBlock[1];
const uint iqs = idx;

const uint n = iqs / 128; // 0,1
const uint b = (iqs % 128) / 64; // 0,1
const uint is_b = (iqs % 32) / 16; // 0,1
const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
const uint is = 8 * n + qhshift + is_b; // 0..15
const uint qsi = n * 64 + (iqs % 64); // 0..127
const uint qhi = n * 32 + (iqs % 32); // 0..63

const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);

float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi ] >> qhshift) & 3) << 4)) - 32);

return ret;
}

#if defined(DATA_A_IQ4_NL)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
block_iq4_nl block;
};

float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint iqs = idx & 0xF;
const uint shift = (idx & 0x10) >> 2;
uint32_t qs = bl.block.qs[iqs];
qs >>= shift;
qs &= 0xF;
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
return ret;
}
#endif

#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncQ4_1
#elif defined(DATA_A_Q5_0)
#define dequantFuncA dequantFuncQ5_0
#elif defined(DATA_A_Q5_1)
#define dequantFuncA dequantFuncQ5_1
#elif defined(DATA_A_Q8_0)
#define dequantFuncA dequantFuncQ8_0
#elif defined(DATA_A_Q2_K)
#define dequantFuncA dequantFuncQ2_K
#elif defined(DATA_A_Q3_K)
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#endif
Loading

0 comments on commit 76fff32

Please sign in to comment.