vulkan: optimize and reenable split_k #10637

Merged · 1 commit · Dec 3, 2024
ggml/src/ggml-vulkan/ggml-vulkan.cpp — 51 changes: 40 additions & 11 deletions
@@ -165,6 +165,7 @@ struct vk_device_struct {
     vk_queue transfer_queue;
     bool single_queue;
     uint32_t subgroup_size;
+    uint32_t shader_core_count;
     bool uma;
 
     size_t idx;
@@ -1498,7 +1499,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
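
Editorial note on the {256 * 4, 1, 1} change: if I read ggml_vk_create_pipeline correctly, this argument is the pipeline's wg_denoms — how many output elements one workgroup covers when the dispatch size is computed — not the shader's local size, which stays at 256. Since each invocation of the rewritten reduce shader (see the .comp diff below) now handles four floats, one 256-thread workgroup covers 256 × 4 = 1024 elements; reducing, say, ne = 4096 elements dispatches ceil(4096 / 1024) = 4 workgroups instead of the previous 16.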
@@ -1610,23 +1611,36 @@ static vk_device ggml_vk_get_device(size_t idx) {
     const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
     bool maintenance4_support = false;
+    bool sm_builtins = false;
 
     // Check if maintenance4 is supported
     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
             maintenance4_support = true;
+        } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
+            sm_builtins = true;
         }
     }
 
     vk::PhysicalDeviceProperties2 props2;
     vk::PhysicalDeviceMaintenance3Properties props3;
     vk::PhysicalDeviceMaintenance4Properties props4;
     vk::PhysicalDeviceSubgroupProperties subgroup_props;
+    vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
     props2.pNext = &props3;
     props3.pNext = &subgroup_props;
 
+    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&subgroup_props;
+
     if (maintenance4_support) {
-        subgroup_props.pNext = &props4;
+        last_struct->pNext = (VkBaseOutStructure *)&props4;
+        last_struct = (VkBaseOutStructure *)&props4;
+    }
+    if (sm_builtins) {
+        last_struct->pNext = (VkBaseOutStructure *)&sm_props;
+        last_struct = (VkBaseOutStructure *)&sm_props;
     }
 
     device->physical_device.getProperties2(&props2);
     device->properties = props2.properties;
 
@@ -1643,6 +1657,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->vendor_id = device->properties.vendorID;
     device->subgroup_size = subgroup_props.subgroupSize;
     device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+    if (sm_builtins) {
+        device->shader_core_count = sm_props.shaderSMCount;
+    } else {
+        device->shader_core_count = 0;
+    }
 
     bool fp16_storage = false;
     bool fp16_compute = false;
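
Background on the pNext pattern above, for readers new to Vulkan: getProperties2 walks a chain of extension structs linked through pNext and fills each one, and casting through VkBaseOutStructure lets code append optional structs at the tail without hard-coding their order. A minimal standalone sketch of the same idiom against the C API (my own illustration, not code from this PR — chain_append and query_props are made-up names, and the maintenance4 struct assumes Vulkan 1.3 headers):

    #include <vulkan/vulkan.h>

    // Append one extension struct to the tail of a pNext chain and
    // return the new tail.
    static VkBaseOutStructure * chain_append(VkBaseOutStructure * tail, void * s) {
        tail->pNext = (VkBaseOutStructure *) s;
        return (VkBaseOutStructure *) s;
    }

    void query_props(VkPhysicalDevice phys, bool has_m4, bool has_sm) {
        VkPhysicalDeviceProperties2 props2 = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 };
        VkPhysicalDeviceMaintenance4Properties m4 = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_PROPERTIES };
        VkPhysicalDeviceShaderSMBuiltinsPropertiesNV sm = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_SM_BUILTINS_PROPERTIES_NV };

        VkBaseOutStructure * tail = (VkBaseOutStructure *) &props2;
        if (has_m4) tail = chain_append(tail, &m4);  // only chain what the driver advertises
        if (has_sm) tail = chain_append(tail, &sm);

        vkGetPhysicalDeviceProperties2(phys, &props2);  // driver fills every struct in the chain
    }

The tail pointer matters because props4 should only be chained when VK_KHR_maintenance4 is advertised and sm_props only on drivers exposing VK_NV_shader_sm_builtins; chaining a struct for an unsupported extension is invalid API usage.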
@@ -2732,15 +2751,25 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
     dst->device->device.resetFences({ dst->device->fence });
 }
 
-static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
+static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
     VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
-    // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
-    //     return 4;
-    // }
 
-    return 1;
+    uint32_t split_k = 1;
+    if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
+        // If k is 'large' and the SMs will fill less than halfway, use split_k.
+        uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
+        uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
+        if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
+            split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
+            // Clamp to 2 or 4
+            split_k = std::min(split_k, 4u);
+            if (split_k == 3) {
+                split_k = 2;
+            }
+        }
+    }
 
-    GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
+    return split_k;
 }
 
 static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
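
To make the heuristic concrete, a worked example with assumed numbers (the device and shapes are illustrative, not from the PR): suppose shader_core_count = 46 and the chosen pipeline uses 128 × 128 tiles. A 512 × 512 matmul with k = 4096 yields m_tiles × n_tiles = 4 × 4 = 16 workgroups, fewer than 46 / 2 = 23, and k ≥ 2048, so split_k = min(46 / 16, 4) = 2: the matmul runs as two passes over half of k each, 32 workgroups run concurrently instead of 16, and a reduction pass sums the two partial results. split_k = 3 is rounded down to 2, keeping the factor at a power of two.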
@@ -2964,10 +2993,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
-
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
 
+    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -2993,7 +3022,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
-        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
+        const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
         if (
             (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
             (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
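
Two details worth noting in this hunk: the pipeline is now selected before split_k is guessed, because the new heuristic needs the pipeline's wg_denoms tile size; and the dry-run scratch sizing drops the hard-coded × 4 in favor of × split_k. Since split_k is clamped to at most 4, keeping × 4 would over-allocate whenever the heuristic picks 2 — e.g. for a 512 × 512 f32 result with ne12 = ne13 = 1, d_sz = 512 × 512 × 4 bytes = 1 MiB, so the scratch buffer for split_k = 2 shrinks from 4 MiB to 2 MiB.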
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp — 31 changes: 25 additions & 6 deletions
@@ -5,25 +5,44 @@
 layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {float data_a[];};
+layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
 layout (binding = 1) writeonly buffer D {float data_d[];};
+layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
 
 layout (push_constant) uniform parameter {
     uint ne;
     uint k_num;
 } p;
 
 void main() {
-    const uint idx = gl_GlobalInvocationID.x;
+    // Each invocation handles four consecutive components
+    const uint idx = gl_GlobalInvocationID.x * 4;
 
     if (idx >= p.ne) {
         return;
     }
 
-    float result = 0.0f;
+    // Check if all four components are in bounds and aligned,
+    // then use vector loads
+    if (idx + 3 < p.ne && (p.ne % 4) == 0) {
+        vec4 result = vec4(0.0f);
 
-    [[unroll]] for (uint i = 0; i < p.k_num; i++) {
-        result += data_a[i * p.ne + idx];
-    }
+        [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+            result += data_a4[(i * p.ne + idx) / 4];
+        }
 
-    data_d[idx] = result;
+        data_d4[idx / 4] = result;
+    } else {
+        [[unroll]] for (uint j = 0; j < 4; ++j) {
+            if (idx + j < p.ne) {
+                float result = 0.0f;
+
+                [[unroll]] for (uint i = 0; i < p.k_num; i++) {
+                    result += data_a[i * p.ne + idx + j];
+                }
+
+                data_d[idx + j] = result;
+            }
+        }
+    }
 }
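
What the shader computes, in plain terms: data_a holds k_num partial result matrices of ne floats each, stored back to back, and each output element is the sum of its k_num partials. A CPU reference of that contract (my own sketch for intuition and testing, not code from the PR):

    #include <cstdint>
    #include <vector>

    // Reference for the split_k reduction: data_a holds k_num partial
    // result buffers of ne floats each, back to back; the output is
    // their elementwise sum, matching data_a[i * p.ne + idx] above.
    std::vector<float> split_k_reduce_ref(const std::vector<float>& data_a,
                                          uint32_t ne, uint32_t k_num) {
        std::vector<float> data_d(ne, 0.0f);
        for (uint32_t i = 0; i < k_num; i++) {
            for (uint32_t idx = 0; idx < ne; idx++) {
                data_d[idx] += data_a[(size_t)i * ne + idx];
            }
        }
        return data_d;
    }

The vectorized fast path is exact because both terms of (i * p.ne + idx) are multiples of 4 in that branch — idx is gl_GlobalInvocationID.x * 4 and the branch requires p.ne % 4 == 0 — so the division by 4 loses nothing, and data_a4/data_d4, which alias the same bindings as data_a/data_d at vec4 width, address the same bytes. The scalar loop handles the remaining case where ne is not a multiple of 4.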