diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp index 164d900a1c71..80647b934cd9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp @@ -305,9 +305,9 @@ struct DistributeBroadcast final : OpDistributionPattern { auto vectorType = VectorType::get(distShape, elementType); VectorValue srcVector = dyn_cast(broadcastOp.getSource()); - // If the srcVector is a scalar (like f32) we proceed with the scalar - // distribution branch. - if (!srcVector) { + // If the srcVector is a scalar (like f32) or a rank-0 vector (like + // vector), we proceed with the scalar distribution branch. + if (!srcVector || !isNonZeroRank(srcVector)) { // The way distribution currently works, there is no partial thread // distribution, so a scalar is available to all threads. Scalar // distribution is simply a broadcast from scalar to the distributed diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp index a8831809e25b..7e927b499077 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp @@ -132,14 +132,16 @@ void DistributionPattern::replaceOpWithDistributedValues( for (auto [opResult, replacement] : llvm::zip_equal(op->getOpResults(), values)) { // If this value is a vector type, it must be converted back to simd. - if (isa(replacement.getType())) { - auto oldResult = cast(opResult); - // Create a toSIMD op to convert the value back to the simd. - rewriter.setInsertionPointAfterValue(oldResult); - Value toSIMD = rewriter.create( - oldResult.getLoc(), oldResult.getType(), replacement); - // Add to replacements. - replacement = toSIMD; + if (auto replacementType = dyn_cast(replacement.getType())) { + if (replacementType.getRank() != 0) { + auto oldResult = cast(opResult); + // Create a toSIMD op to convert the value back to the simd. + rewriter.setInsertionPointAfterValue(oldResult); + Value toSIMD = rewriter.create( + oldResult.getLoc(), oldResult.getType(), replacement); + // Add to replacements. + replacement = toSIMD; + } } replacements.push_back(replacement); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir index 98455c93f3e0..71448ef84066 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir @@ -783,47 +783,6 @@ builtin.module attributes { transform.with_named_sequence } { // ----- -#layout = #iree_vector_ext.nested_layout< - subgroup_tile = [2, 2, 2], - batch_tile = [2, 2, 1], - outer_tile = [2, 1, 1], - thread_tile = [4, 16, 8], - element_tile = [1, 4, 4], - subgroup_strides = [4, 2, 1], - thread_strides = [128, 8, 1] -> - -func.func @zero_rank_broadcast(%src: vector) -> (vector<32x256x64xf16>) { - %bcast = vector.broadcast %src : vector to vector<32x256x64xf16> - %bcastl = iree_vector_ext.to_layout %bcast to layout(#layout) : vector<32x256x64xf16> - return %bcastl : vector<32x256x64xf16> -} - -builtin.module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { - %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op - transform.yield - } -} - -// CHECK-LABEL: func @zero_rank_broadcast -// CHECK-SAME: (%[[SRC:.*]]: vector) -// CHECK: %[[SRC_SIMT:.*]] = iree_vector_ext.to_simt %[[SRC]] : vector -// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC_SIMT]] -// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f16 to vector<1x4x4xf16> -// CHECK: vector.insert %[[BCAST]], %{{.*}} -// CHECK: vector.insert %[[BCAST]], %{{.*}} -// CHECK: vector.insert %[[BCAST]], %{{.*}} -// CHECK: vector.insert %[[BCAST]], %{{.*}} -// CHECK: vector.insert %[[BCAST]], %{{.*}} -// CHECK: vector.insert %[[BCAST]], %{{.*}} -// CHECK: vector.insert %[[BCAST]], %{{.*}} -// CHECK: %[[OUT:.*]] = vector.insert %[[BCAST]], %{{.*}} -// CHECK: iree_vector_ext.to_simd %[[OUT]] : vector<2x2x1x2x1x1x1x4x4xf16> -> vector<32x256x64xf16> - -// ----- - #layout = #iree_vector_ext.nested_layout< subgroup_tile = [2, 2, 2], batch_tile = [2, 2, 1], diff --git a/third_party/llvm-project b/third_party/llvm-project index 889525fa99b2..ac39504813f8 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 889525fa99b251dc962edb516e0108088ba7e44d +Subproject commit ac39504813f8c52f10c0e364485569bff5a5f7a1