diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp index 61c325c8690a..a11935114eba 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp @@ -787,6 +787,19 @@ LogicalResult setScatterLoweringConfig(IREE::GPU::TargetAttr target, } } + int64_t numBatch = scatter.getBatchRank(); + // Currently bufferization will fail if the only dimensions distributed to + // workgroups are the batch dims because the workgroup-level slice will fold + // away and cause a mismatch. + // TODO(qedawkins): Support this case. + if (llvm::all_of_zip(llvm::drop_begin(workgroupTileSizes, numBatch), + llvm::drop_begin(loopBounds, numBatch), + [](int64_t tileSize, int64_t bound) { + return tileSize == bound || tileSize == 0; + })) { + return failure(); + } + // Attach the MMA schedule as an attribute to the entry point export function // for later access in the pipeline. MLIRContext *context = scatter.getContext(); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir index 3d137e5e79d6..6f94069f2c6d 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir @@ -371,11 +371,10 @@ func.func @only_scattered_dim(%arg0: tensor<48xf32>, } // CHECK-LABEL: func.func @only_scattered_dim -// CHECK-SAME: #iree_codegen.translation_info