-
Notifications
You must be signed in to change notification settings - Fork 11k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
vulkan: implement initial support for IQ2 and IQ3 quantizations (#11360)
* vulkan: initial support for IQ3_S * vulkan: initial support for IQ3_XXS * vulkan: initial support for IQ2_XXS * vulkan: initial support for IQ2_XS * vulkan: optimize Q3_K by removing branches * vulkan: implement dequantize variants for coopmat2 * vulkan: initial support for IQ2_S * vulkan: vertically realign code * port failing dequant callbacks from mul_mm * Fix array length mismatches * vulkan: avoid using workgroup size before it is referenced * tests: increase timeout for Vulkan llvmpipe backend --------- Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
- gg-ci-fix-arm-b4760-f343850
- b4856
- b4855
- b4854
- b4853
- b4851
- b4849
- b4848
- b4847
- b4846
- b4837
- b4836
- b4835
- b4834
- b4833
- b4832
- b4831
- b4830
- b4829
- b4827
- b4826
- b4824
- b4823
- b4821
- b4820
- b4819
- b4818
- b4806
- b4805
- b4804
- b4803
- b4801
- b4800
- b4799
- b4798
- b4797
- b4796
- b4793
- b4792
- b4790
- b4789
- b4788
- b4786
- b4785
- b4784
- b4783
- b4778
- b4777
- b4776
- b4775
- b4774
- b4773
- b4771
- b4770
- b4769
- b4768
- b4767
- b4765
- b4764
- b4763
- b4762
- b4761
- b4760
- b4759
- b4756
- b4755
- b4754
- b4753
- b4751
- b4749
- b4747
- b4746
- b4745
- b4743
- b4742
- b4739
- b4738
- b4735
- b4734
- b4733
- b4732
- b4731
- b4730
- b4728
- b4727
- b4724
- b4722
- b4721
- b4720
- b4719
- b4718
- b4717
- b4716
- b4714
- b4713
- b4712
- b4710
- b4708
- b4707
- b4706
- b4705
- b4704
- b4702
- b4699
- b4698
- b4696
- b4695
- b4694
- b4692
- b4689
- b4688
- b4686
- b4683
- b4682
- b4681
- b4679
- b4678
- b4677
- b4676
- b4675
- b4671
- b4667
- b4666
- b4663
- b4662
- b4661
- b4660
- b4659
- b4658
- b4657
- b4651
- b4649
- b4648
- b4647
- b4646
- b4644
- b4643
- b4642
- b4641
- b4640
- b4639
- b4637
- b4636
- b4634
- b4633
- b4631
- b4628
- b4623
- b4621
- b4620
- b4619
- b4618
- b4617
- b4616
- b4615
- b4614
- b4613
- b4611
- b4610
- b4609
- b4608
- b4607
- b4606
- b4605
- b4604
- b4603
- b4601
- b4600
- b4599
- b4598
- b4595
- b4594
- b4589
- b4588
1 parent
e51c47b
commit 66ee4f2
Showing
19 changed files
with
1,616 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#version 450 | ||
|
||
#include "dequant_head.comp" | ||
|
||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer A {block_iq2_s data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; | ||
|
||
// Dequantizes IQ2_S blocks from data_a into data_b.
//
// Work split: a 256-invocation workgroup covers 32 blocks. Eight consecutive
// invocations share one block, and each invocation expands one 32-value
// sub-block (four groups of 8 values, governed by two 4-bit scales).
void main() {
    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;

    // Every invocation calls this before the bounds check below.
    // NOTE(review): presumably it fills the shared grid tables and synchronizes
    // the workgroup, which would be why it precedes the early return - confirm
    // against its definition.
    init_iq_shmem(gl_WorkGroupSize);

    // Each block produces 256 outputs, so p.nel / 256 blocks are processed.
    if (ib >= p.nel / 256) {
        return;
    }

    const uint ib32 = gl_LocalInvocationID.x % 8;
    const uint b_idx = 256 * ib + 32 * ib32;

    const float d = float(data_a[ib].d);
    // One scale byte per sub-block: the low nibble applies to groups 0-1,
    // the high nibble to groups 2-3 (selected below through db[l / 2]).
    const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4);
    const vec2 db = d * (0.5 + scale) * 0.25;

    // qh holds the two high grid-index bits for each of the four groups.
    const uint qh = data_a[ib].qh[ib32];
    [[unroll]] for (uint l = 0; l < 4; ++l) {
        // The sign bytes are read from the same qs array, QUANT_K / 8 entries in.
        const uint sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];

        // 10-bit grid index: 8 low bits from qs, bits 8-9 from qh.
        uint gidx = data_a[ib].qs[4 * ib32 + l];
        gidx |= (qh << (8 - 2 * l)) & 0x300;

        // Use the full 10-bit index. The previous `& 511` discarded bit 9, so
        // indices 512..1023 aliased onto the lower half of the table.
        // NOTE(review): assumes iq2s_grid holds 1024 entries, as in the ggml
        // reference tables - confirm in the shared type definitions.
        const uvec2 grid = iq2s_grid[gidx];
        const u8vec4 grid0 = unpack8(grid.x);
        const u8vec4 grid1 = unpack8(grid.y);

        const float dl = db[l / 2];
        const uint o = b_idx + 8 * l;
        // Bit j of the sign byte negates value j of the group.
        [[unroll]] for (uint j = 0; j < 4; ++j) {
            data_b[o + j]     = D_TYPE(dl * grid0[j] * ((sign & (1u << j)) != 0 ? -1.0 : 1.0));
            data_b[o + 4 + j] = D_TYPE(dl * grid1[j] * ((sign & (16u << j)) != 0 ? -1.0 : 1.0));
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#version 450 | ||
|
||
#include "dequant_head.comp" | ||
|
||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; | ||
|
||
// Expands IQ2_XS blocks from data_a into data_b.
//
// Eight invocations cooperate on one block; each one writes a single
// 32-value sub-block made of four 8-value groups sharing two 4-bit scales.
void main() {
    const uint lane = gl_LocalInvocationID.x;
    const uint blk = gl_WorkGroupID.x * 32 + lane / 8;

    // Runs on every invocation, including those that exit right after.
    init_iq_shmem(gl_WorkGroupSize);

    // 256 outputs per block: anything past p.nel / 256 has no work.
    if (blk >= p.nel / 256) {
        return;
    }

    const uint sub = lane % 8;
    const uint out_base = 256 * blk + 32 * sub;

    const float d = float(data_a[blk].d);
    // Low nibble -> groups 0 and 1, high nibble -> groups 2 and 3.
    const vec2 nibbles = vec2(data_a[blk].scales[sub] & 0xf, data_a[blk].scales[sub] >> 4);
    const vec2 db = d * (0.5 + nibbles) * 0.25;

    [[unroll]] for (uint g = 0; g < 4; ++g) {
        // Each 16-bit entry packs a 9-bit grid index (low) and 7 sign bits (high).
        const uint16_t packed = data_a[blk].qs[4 * sub + g];
        const uint sign7 = packed >> 9;
        // Bit 7 becomes the parity of the 7 stored sign bits; only bits 0-7
        // are examined below, so the higher bits of the shifted count are ignored.
        const uint sign8 = sign7 | (bitCount(sign7) << 7);

        const uvec2 grid = iq2xs_grid[packed & 511];
        const u8vec4 lo = unpack8(grid.x);
        const u8vec4 hi = unpack8(grid.y);

        const float dl = db[g / 2];
        const uint o = out_base + 8 * g;
        // Bit j of sign8 flips the sign of value j in this group.
        [[unroll]] for (uint j = 0; j < 4; ++j) {
            data_b[o + j]     = D_TYPE(dl * lo[j] * ((sign8 & (1u << j)) != 0 ? -1.0 : 1.0));
            data_b[o + 4 + j] = D_TYPE(dl * hi[j] * ((sign8 & (16u << j)) != 0 ? -1.0 : 1.0));
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#version 450 | ||
|
||
#include "dequant_head.comp" | ||
|
||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; | ||
|
||
// Expands IQ2_XXS blocks from data_a into data_b.
//
// One invocation handles one 32-value scale block; eight invocations cover a
// full block. The 8 bytes belonging to a scale block hold four grid indices
// followed by a 32-bit word with 4x7 sign bits and a 4-bit scale on top.
void main() {
    const uint lane = gl_LocalInvocationID.x;
    const uint blk = gl_WorkGroupID.x * 32 + lane / 8;

    // Runs on every invocation, including those that exit right after.
    init_iq_shmem(gl_WorkGroupSize);

    // 256 outputs per block: anything past p.nel / 256 has no work.
    if (blk >= p.nel / 256) {
        return;
    }

    const uint sub = lane % 8;
    const uint out_base = 256 * blk + 32 * sub;
    const uint q_base = 8 * sub;

    const float d = float(data_a[blk].d);
    // Reassemble the sign/scale word from bytes 4..7 of this scale block.
    const uint signscale = pack32(u8vec4(
        data_a[blk].qs[q_base + 4],
        data_a[blk].qs[q_base + 5],
        data_a[blk].qs[q_base + 6],
        data_a[blk].qs[q_base + 7]
    ));
    // The top four bits of the word are the scale.
    const float db = d * (0.5 + (signscale >> 28)) * 0.25;

    [[unroll]] for (uint g = 0; g < 4; ++g) {
        const uint sign7 = bitfieldExtract(signscale, 7 * int(g), 7);
        // Bit 7 becomes the parity of the 7 stored sign bits; only bits 0-7
        // are examined below.
        const uint sign8 = sign7 | (bitCount(sign7) << 7);

        const uvec2 grid = iq2xxs_grid[data_a[blk].qs[q_base + g]];
        const u8vec4 lo = unpack8(grid.x);
        const u8vec4 hi = unpack8(grid.y);

        const uint o = out_base + 8 * g;
        // Bit j of sign8 flips the sign of value j in this group.
        [[unroll]] for (uint j = 0; j < 4; ++j) {
            data_b[o + j]     = D_TYPE(db * lo[j] * ((sign8 & (1u << j)) != 0 ? -1.0 : 1.0));
            data_b[o + 4 + j] = D_TYPE(db * hi[j] * ((sign8 & (16u << j)) != 0 ? -1.0 : 1.0));
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#version 450 | ||
|
||
#include "dequant_head.comp" | ||
|
||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer A {block_iq3_s data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; | ||
|
||
// Dequantizes IQ3_S blocks from data_a into data_b.
//
// Each invocation handles one scale nibble, i.e. one 32-value sub-block.
// A block carries 4 scale bytes (8 nibbles) for its 256 output values, and
// eight consecutive invocations cover one block.
void main() {
    const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;

    // Every invocation calls this before the bounds check below.
    // NOTE(review): presumably it fills the shared grid tables and synchronizes
    // the workgroup, which would be why it precedes the early return - confirm
    // against its definition.
    init_iq_shmem(gl_WorkGroupSize);

    // Each block produces 256 outputs, so p.nel / 256 blocks are processed.
    if (ib >= p.nel / 256) {
        return;
    }

    const uint is = gl_LocalInvocationID.x % 8;
    const uint b_idx = 256 * ib + 32 * is;

    const float d = float(data_a[ib].d);
    // Two 4-bit scales share a byte: is % 2 picks the nibble, so the byte is
    // is / 2. Indexing the byte by `is` read the wrong scale and ran past the
    // 4 scale bytes for is >= 4.
    const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf));

    // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
    // Bit l of qh supplies bit 8 of the l-th grid index.
    uint qh = data_a[ib].qh[is];
    [[unroll]] for (uint l = 0; l < 8; ++l) {
        uint qs = data_a[ib].qs[8 * is + l];
        uint gidx = qs | ((qh << (8 - l)) & 256);
        // Each sub-block owns 4 sign bytes (l / 2 spans 0..3), so consecutive
        // sub-blocks are 4 bytes apart. The former stride of 8 skipped sign
        // bytes and indexed beyond the 32 bytes eight sub-blocks consume.
        // Even l uses the low nibble of the byte, odd l the high nibble.
        uint8_t signs = data_a[ib].signs[4 * is + l / 2] >> (4 * (l & 1));
        u8vec4 grid = unpack8(iq3s_grid[gidx]);
        data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0));
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#version 450 | ||
|
||
#include "dequant_head.comp" | ||
|
||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];}; | ||
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; | ||
|
||
// Expands IQ3_XXS blocks from data_a into data_b.
//
// One invocation handles one 32-value scale block; eight invocations make up
// a superblock of 256 values.
void main() {
    const uint lane = gl_LocalInvocationID.x;
    const uint blk = gl_WorkGroupID.x * 32 + lane / 8;

    // Runs on every invocation, including those that exit right after.
    init_iq_shmem(gl_WorkGroupSize);

    // 256 outputs per block: anything past p.nel / 256 has no work.
    if (blk >= p.nel / 256) {
        return;
    }

    const uint sub = lane % 8;
    const uint out_base = 256 * blk + 32 * sub;
    // The sign/scale words are read from qs starting QUANT_K / 4 entries in,
    // four bytes per scale block.
    const uint word_base = QUANT_K / 4 + 4 * sub;

    const float d = float(data_a[blk].d);
    const uint signscale = pack32(u8vec4(
        data_a[blk].qs[word_base + 0],
        data_a[blk].qs[word_base + 1],
        data_a[blk].qs[word_base + 2],
        data_a[blk].qs[word_base + 3]
    ));
    // The top four bits of the word are the scale.
    const float db = d * (0.5 + (signscale >> 28)) * 0.5;

    [[unroll]] for (uint g = 0; g < 4; ++g) {
        const uint sign7 = bitfieldExtract(signscale, 7 * int(g), 7);
        // Restore the parity bit at position 7; only bits 0-7 are examined below.
        const uint sign8 = sign7 | (bitCount(sign7) << 7);

        // Two grid lookups per group, four values each.
        const uint q_idx = 8 * sub + 2 * g;
        const u8vec4 lo = unpack8(iq3xxs_grid[data_a[blk].qs[q_idx]]);
        const u8vec4 hi = unpack8(iq3xxs_grid[data_a[blk].qs[q_idx + 1]]);

        const uint o = out_base + 8 * g;
        // Bit j of sign8 flips the sign of value j in this group.
        [[unroll]] for (uint j = 0; j < 4; ++j) {
            data_b[o + j]     = D_TYPE(db * lo[j] * ((sign8 & (1u << j)) != 0 ? -1.0 : 1.0));
            data_b[o + 4 + j] = D_TYPE(db * hi[j] * ((sign8 & (16u << j)) != 0 ? -1.0 : 1.0));
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters