Skip to content

Commit

Permalink
Width Packing Mat1 input for Quantized Linear (#6149)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #6149

Width packing mat1 input for Quantized Linear as ASR model provides channel-packed matrix while operator does not support channel-packed yet.

Reviewed By: nathanaelsee, jorgep31415

Differential Revision: D64065606

fbshipit-source-id: 2a7d43d432deef7245d1d45f5c760b0f42627551
  • Loading branch information
Kush Rastogi authored and facebook-github-bot committed Oct 15, 2024
1 parent 517fddb commit e342a92
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,46 +71,63 @@ void add_q_8w_linear_node(
const ValueRef q_mat2_data,
const ValueRef scales_data,
const ValueRef out) {
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
ValueRef mat1_W_packed = mat1;
ValueRef out_W_packed = out;
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
// Ensure mat1 is width packed
mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
// Ensure out is packed correctly
out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
}
ValueRef q_mat2 =
prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked);
ValueRef scales =
prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked);

std::string kernel_name = "q_8w_linear";
kernel_name.reserve(kShaderNameReserve);
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1));
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
add_dtype_suffix(kernel_name, graph.dtype_of(out));
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));

vkapi::ParamsBindList ubos({});
if (graph.is_buffer_storage(out)) {
if (graph.is_buffer_storage(out_W_packed)) {
ubos.append(
{graph.sizes_ubo(out),
graph.strides_ubo(out),
graph.numel_ubo(out),
graph.sizes_ubo(mat1),
{graph.sizes_ubo(out_W_packed),
graph.strides_ubo(out_W_packed),
graph.numel_ubo(out_W_packed),
graph.sizes_ubo(mat1_W_packed),
graph.strides_ubo(mat1),
graph.strides_ubo(q_mat2),
graph.strides_ubo(scales)});
} else {
ubos.append({graph.logical_limits_ubo(out), graph.sizes_ubo(mat1)});
ubos.append(
{graph.logical_limits_ubo(out_W_packed),
graph.sizes_ubo(mat1_W_packed)});
}

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
graph.create_global_wg_size(out_W_packed),
graph.create_local_wg_size(out_W_packed),
// Inputs and Outputs
{{out, vkapi::MemoryAccessType::WRITE},
{{mat1, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
{{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
// Shader params buffers
ubos,
// Specialization Constants
{},
// Resizing Logic
resize_qlinear_node));
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(out) != WHCN::kWidthDim) {
viewFn(graph, {out_W_packed, graph.add_none(), out});
}
}

void weight_int8pack_mm(
Expand Down

0 comments on commit e342a92

Please sign in to comment.