Skip to content

Commit

Permalink
Add test for zero-rank broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
Groverkss committed Nov 11, 2024
1 parent 8848fd8 commit 6bef0ea
Showing 1 changed file with 41 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,47 @@ builtin.module attributes { transform.with_named_sequence } {

// -----

#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [2, 2, 2],
batch_tile = [2, 2, 1],
outer_tile = [2, 1, 1],
thread_tile = [4, 16, 8],
element_tile = [1, 4, 4],
subgroup_strides = [4, 2, 1],
thread_strides = [128, 8, 1]
>

func.func @zero_rank_broadcast(%src: vector<f16>) -> (vector<32x256x64xf16>) {
%bcast = vector.broadcast %src : vector<f16> to vector<32x256x64xf16>
%bcastl = iree_vector_ext.to_layout %bcast to layout(#layout) : vector<32x256x64xf16>
return %bcastl : vector<32x256x64xf16>
}

builtin.module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) {
%top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op
transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func @zero_rank_broadcast
// CHECK-SAME: (%[[SRC:.*]]: vector<f16>)
// CHECK: %[[SRC_SIMT:.*]] = iree_vector_ext.to_simt %[[SRC]] : vector<f16>
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC_SIMT]]
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f16 to vector<1x4x4xf16>
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: vector.insert %[[BCAST]], %{{.*}}
// CHECK: %[[OUT:.*]] = vector.insert %[[BCAST]], %{{.*}}
// CHECK: iree_vector_ext.to_simd %[[OUT]] : vector<2x2x1x2x1x1x1x4x4xf16> -> vector<32x256x64xf16>

// -----

#layout = #iree_vector_ext.nested_layout<
subgroup_tile = [2, 2, 2],
batch_tile = [2, 2, 1],
Expand Down

0 comments on commit 6bef0ea

Please sign in to comment.