From c1cc4ccbc1b58c7bcf85bed1f43cff743d941d8e Mon Sep 17 00:00:00 2001
From: Quinn Dawkins
Date: Fri, 17 Jan 2025 10:22:11 -0500
Subject: [PATCH] [LLVMGPU] Add pass to distribute undistributed copies to
 threads (#19715)

This pass walks a function and distributes any memref copies not present
within an scf.forall distributed to threads/warps/lanes. This pass assumes
that implicit distribution (a la gpu.thread_id) is not present.
---
 .../compiler/Codegen/Common/GPU/BUILD.bazel   |   1 +
 .../Codegen/Common/GPU/CMakeLists.txt         |   1 +
 .../GPU/GPUDistributeCopyUsingForall.cpp      | 171 ++++++++++++++++++
 .../compiler/Codegen/Common/GPU/Passes.td     |   8 +
 .../Codegen/Common/GPU/test/BUILD.bazel       |   1 +
 .../Codegen/Common/GPU/test/CMakeLists.txt    |   1 +
 .../gpu_distribute_copy_using_forall.mlir     |  99 ++++++++++
 .../iree/compiler/Codegen/LLVMGPU/Passes.cpp  |   1 +
 8 files changed, 283 insertions(+)
 create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp
 create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_copy_using_forall.mlir

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index c35b5115d2fe..a2b98a57b601 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -56,6 +56,7 @@ iree_compiler_cc_library(
         "GPUCombineValueBarriers.cpp",
         "GPUCreateFastSlowPath.cpp",
         "GPUDistribute.cpp",
+        "GPUDistributeCopyUsingForall.cpp",
         "GPUDistributeForall.cpp",
         "GPUDistributeScfFor.cpp",
         "GPUDistributeSharedMemoryCopy.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
index 51576eb38295..df7015cde5f5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -54,6 +54,7 @@ iree_cc_library(
       "GPUCombineValueBarriers.cpp"
       "GPUCreateFastSlowPath.cpp"
       "GPUDistribute.cpp"
+      "GPUDistributeCopyUsingForall.cpp"
       "GPUDistributeForall.cpp"
       "GPUDistributeScfFor.cpp"
       "GPUDistributeSharedMemoryCopy.cpp"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp
new file mode 100644
index 000000000000..85099dfc0e9c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp
@@ -0,0 +1,171 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+
+#define DEBUG_TYPE "iree-codegen-gpu-distribute-copy-using-forall"
+
+namespace mlir::iree_compiler {
+
+#define GEN_PASS_DEF_GPUDISTRIBUTECOPYUSINGFORALLPASS
+#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
+
+namespace {
+//====---------------------------------------------------------------------===//
+// Pass to distribute memref copies that are not already contained in a
+// distributed scf.forall onto threads.
+//====---------------------------------------------------------------------===//
+
+// For optimal performance we always want to copy 128 bits at a time.
+static constexpr int kPreferredCopyNumBits = 128;
+
+// Moves the copy into a single-threaded forall.
+static void distributeCopyToSingleThread(RewriterBase &rewriter,
+                                         memref::CopyOp copy) {
+  SmallVector<Attribute> mapping = {gpu::GPUThreadMappingAttr::get(
+      rewriter.getContext(), gpu::MappingId::LinearDim0)};
+  scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
+      copy.getLoc(), ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
+      ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)},
+      ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)},
+      /*outputs=*/ValueRange(), /*mapping=*/rewriter.getArrayAttr(mapping));
+  rewriter.moveOpBefore(copy, newForallOp.getBody(),
+                        newForallOp.getBody()->begin());
+}
+
+/// Distributes a copy with a thread mapping.
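+/// The copy is tiled so that each thread handles a single element along every
+/// outer dimension and kPreferredCopyNumBits worth of contiguous data along
+/// the innermost dimension; when a tile size does not evenly divide the copy
+/// bound, the boundary tile size is clamped with an affine.min.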
+static void distributeCopyToThreads(RewriterBase &rewriter,
+                                    memref::CopyOp copy,
+                                    ArrayRef<OpFoldResult> tileSizes) {
+  int64_t rank = tileSizes.size();
+  assert(rank == copy.getTarget().getType().getRank() &&
+         "tile size and copy rank mismatch");
+  if (rank == 0) {
+    distributeCopyToSingleThread(rewriter, copy);
+    return;
+  }
+
+  Location loc = copy.getLoc();
+  MLIRContext *context = rewriter.getContext();
+
+  SmallVector<OpFoldResult> lowerBounds(rank, rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> upperBounds =
+      memref::getMixedSizes(rewriter, loc, copy.getSource());
+
+  SmallVector<Attribute> mapping;
+  int idx = 0;
+  for (int64_t i = 0, e = rank; i < e; ++i) {
+    unsigned mappingId =
+        static_cast<unsigned>(gpu::MappingId::LinearDim0) + idx++;
+    mapping.push_back(gpu::GPUThreadMappingAttr::get(
+        context, static_cast<gpu::MappingId>(mappingId)));
+  }
+  mapping = llvm::to_vector(llvm::reverse(mapping));
+
+  scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
+      copy.getLoc(), lowerBounds, upperBounds, tileSizes,
+      /*outputs=*/ValueRange(), /*mapping=*/rewriter.getArrayAttr(mapping));
+
+  rewriter.setInsertionPointToStart(newForallOp.getBody());
+
+  AffineExpr d0, d1, d2;
+  bindDims(context, d0, d1, d2);
+  SmallVector<OpFoldResult> sizes;
+  AffineMap minMap =
+      AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d0, d1 - d2}, context);
+  for (auto [ub, tileSize, iterator] : llvm::zip_equal(
+           upperBounds, tileSizes, newForallOp.getInductionVars())) {
+    std::optional<int64_t> staticUb = getConstantIntValue(ub);
+    std::optional<int64_t> staticTileSize = getConstantIntValue(tileSize);
+    if ((staticUb && staticTileSize &&
+         staticUb.value() % staticTileSize.value() == 0) ||
+        (staticTileSize.value_or(0) == 1)) {
+      sizes.push_back(tileSize);
+    } else {
+      sizes.push_back(
+          rewriter
+              .create<affine::AffineMinOp>(
+                  loc, rewriter.getIndexType(), minMap,
+                  ValueRange{
+                      getValueOrCreateConstantIndexOp(rewriter, loc, tileSize),
+                      getValueOrCreateConstantIndexOp(rewriter, loc, ub),
+                      iterator})
+              .getResult());
+    }
+  }
+
+  SmallVector<OpFoldResult> offsets =
+      getAsOpFoldResult(newForallOp.getInductionVars());
+  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+  Value sourceTile = rewriter.create<memref::SubViewOp>(
+      loc, copy.getSource(), offsets, sizes, strides);
+  Value targetTile = rewriter.create<memref::SubViewOp>(
+      loc, copy.getTarget(), offsets, sizes, strides);
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copy, sourceTile, targetTile);
+}
+
+static SmallVector<OpFoldResult> getCopyTileSizes(Builder &b,
+                                                  memref::CopyOp copy) {
+  int64_t rank = copy.getTarget().getType().getRank();
+  if (rank == 0) {
+    return {};
+  }
+
+  SmallVector<OpFoldResult> tileSizes(rank - 1, b.getIndexAttr(1));
+  int64_t elementBitWidth = llvm::cast<MemRefType>(copy.getTarget().getType())
+                                .getElementTypeBitWidth();
+  tileSizes.push_back(b.getIndexAttr(kPreferredCopyNumBits / elementBitWidth));
+  return tileSizes;
+}
+
+} // namespace
+
+namespace {
+struct GPUDistributeCopyUsingForallPass final
+    : impl::GPUDistributeCopyUsingForallPassBase<
+          GPUDistributeCopyUsingForallPass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    auto funcOp = getOperation();
+
+    SmallVector<memref::CopyOp> copies;
+
+    // Walk in PreOrder so that parent operations are visited before children,
+    // thus allowing all operations contained within thread/warp/lane foralls
+    // to be skipped.
+    funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
+      if (auto forallOp = dyn_cast<scf::ForallOp>(op)) {
+        // Skip ops contained within forall ops with thread/warp/lane mappings.
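+        // Copies inside such a forall are considered already distributed and
+        // are left untouched.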
+        if (forallOpHasMappingType<IREE::GPU::LaneIdAttr,
+                                   gpu::GPUWarpMappingAttr,
+                                   gpu::GPUThreadMappingAttr>(forallOp)) {
+          return WalkResult::skip();
+        }
+      }
+      if (auto copy = dyn_cast<memref::CopyOp>(op)) {
+        copies.push_back(copy);
+      }
+      return WalkResult::advance();
+    });
+
+    IRRewriter rewriter(context);
+    for (auto copy : copies) {
+      rewriter.setInsertionPoint(copy);
+      SmallVector<OpFoldResult> tileSizes = getCopyTileSizes(rewriter, copy);
+      distributeCopyToThreads(rewriter, copy, tileSizes);
+    }
+  }
+};
+} // namespace
+} // namespace mlir::iree_compiler
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
index 7b9f96fa3d02..340fa65f3969 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -31,6 +31,14 @@ def GPUCreateFastSlowPathPass :
   let dependentDialects = ["::mlir::scf::SCFDialect"];
 }
 
+def GPUDistributeCopyUsingForallPass :
+    InterfacePass<"iree-codegen-gpu-distribute-copy-using-forall", "mlir::FunctionOpInterface"> {
+  let summary = "Pass to distribute copies to threads.";
+  let dependentDialects = [
+    "::mlir::affine::AffineDialect", "::mlir::gpu::GPUDialect", "::mlir::scf::SCFDialect"
+  ];
+}
+
 def GPUDistributeForallPass :
     InterfacePass<"iree-codegen-gpu-distribute-forall", "mlir::FunctionOpInterface"> {
   let summary = "Pass to distribute scf.forall ops.";
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
index df2e5fee7745..4775dde87bc5 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -23,6 +23,7 @@ iree_lit_test_suite(
             "gpu_check_resource_usage.mlir",
             "gpu_create_fast_slow_path.mlir",
             "gpu_distribute.mlir",
+            "gpu_distribute_copy_using_forall.mlir",
             "gpu_distribute_forall.mlir",
             "gpu_distribute_scf_for.mlir",
             "gpu_distribute_shared_memory.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
index 2fa5ba26d41a..b43d2ef5a56a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -20,6 +20,7 @@ iree_lit_test_suite(
     "gpu_combine_value_barriers.mlir"
     "gpu_create_fast_slow_path.mlir"
     "gpu_distribute.mlir"
+    "gpu_distribute_copy_using_forall.mlir"
     "gpu_distribute_forall.mlir"
     "gpu_distribute_scf_for.mlir"
     "gpu_distribute_shared_memory.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_copy_using_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_copy_using_forall.mlir
new file mode 100644
index 000000000000..5666f1f4af0c
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_copy_using_forall.mlir
@@ -0,0 +1,99 @@
+// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-distribute-copy-using-forall))' %s | FileCheck %s
+
+func.func @static_copy(%src : memref<56x32xf32>, %target : memref<56x32xf32>) {
+  memref.copy %src, %target : memref<56x32xf32> to memref<56x32xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @static_copy
+// CHECK-SAME: (%[[SRC:.+]]: memref<56x32xf32>, %[[TARGET:.+]]: memref<56x32xf32>)
+
+// CHECK: scf.forall (%[[IV0:[A-Za-z0-9]+]], %[[IV1:[A-Za-z0-9]+]]) = (0, 0) to (56, 32) step (1, 4) {
+// CHECK-DAG: %[[SRC_SUBVIEW:.+]] = memref.subview %[[SRC]][%[[IV0]], %[[IV1]]] [1, 4] [1, 1]
+// CHECK-DAG: %[[TARGET_SUBVIEW:.+]] = memref.subview %[[TARGET]][%[[IV0]], %[[IV1]]] [1, 4] [1, 1]
+// CHECK: memref.copy %[[SRC_SUBVIEW]], %[[TARGET_SUBVIEW]]
+// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
+
+// -----
+
+func.func @unaligned_copy(%src : memref<56x31xf32>, %target : memref<56x31xf32>) {
+  memref.copy %src, %target : memref<56x31xf32> to memref<56x31xf32>
+  return
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+// CHECK-LABEL: func.func @unaligned_copy
+// CHECK-SAME: (%[[SRC:.+]]: memref<56x31xf32>, %[[TARGET:.+]]: memref<56x31xf32>)
+
+// CHECK: scf.forall (%[[IV0:[A-Za-z0-9]+]], %[[IV1:[A-Za-z0-9]+]]) = (0, 0) to (56, 31) step (1, 4) {
+// CHECK: %[[MIN:.+]] = affine.min #[[$MAP]](%c4, %c31, %[[IV1]])
+// CHECK-DAG: %[[SRC_SUBVIEW:.+]] = memref.subview %[[SRC]][%[[IV0]], %[[IV1]]] [1, %[[MIN]]]
+// CHECK-DAG: %[[TARGET_SUBVIEW:.+]] = memref.subview %[[TARGET]][%[[IV0]], %[[IV1]]] [1, %[[MIN]]]
+// CHECK: memref.copy %[[SRC_SUBVIEW]], %[[TARGET_SUBVIEW]]
+// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
+
+// -----
+
+func.func @dynamic_copy(%src : memref<?x?xf32>, %target : memref<?x?xf32>) {
+  memref.copy %src, %target : memref<?x?xf32> to memref<?x?xf32>
+  return
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+// CHECK-LABEL: func.func @dynamic_copy
+// CHECK-SAME: (%[[SRC:.+]]: memref<?x?xf32>, %[[TARGET:.+]]: memref<?x?xf32>)
+
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %c0 : memref<?x?xf32>
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[SRC]], %c1 : memref<?x?xf32>
+// CHECK: scf.forall (%[[IV0:[A-Za-z0-9]+]], %[[IV1:[A-Za-z0-9]+]]) = (0, 0) to (%[[D0]], %[[D1]]) step (1, 4) {
+// CHECK: %[[MIN:.+]] = affine.min #[[$MAP]](%c4, %[[D1]], %[[IV1]])
+// CHECK-DAG: %[[SRC_SUBVIEW:.+]] = memref.subview %[[SRC]][%[[IV0]], %[[IV1]]] [1, %[[MIN]]]
+// CHECK-DAG: %[[TARGET_SUBVIEW:.+]] = memref.subview %[[TARGET]][%[[IV0]], %[[IV1]]] [1, %[[MIN]]]
+// CHECK: memref.copy %[[SRC_SUBVIEW]], %[[TARGET_SUBVIEW]]
+// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
+
+// -----
+
+func.func @f16_copy(%src : memref<56x32xf16>, %target : memref<56x32xf16>) {
+  memref.copy %src, %target : memref<56x32xf16> to memref<56x32xf16>
+  return
+}
+
+// CHECK-LABEL: func.func @f16_copy
+// CHECK-SAME: (%[[SRC:.+]]: memref<56x32xf16>, %[[TARGET:.+]]: memref<56x32xf16>)
+
+// CHECK: scf.forall (%[[IV0:[A-Za-z0-9]+]], %[[IV1:[A-Za-z0-9]+]]) = (0, 0) to (56, 32) step (1, 8) {
+// CHECK-DAG: %[[SRC_SUBVIEW:.+]] = memref.subview %[[SRC]][%[[IV0]], %[[IV1]]] [1, 8]
+// CHECK-DAG: %[[TARGET_SUBVIEW:.+]] = memref.subview %[[TARGET]][%[[IV0]], %[[IV1]]] [1, 8]
+// CHECK: memref.copy %[[SRC_SUBVIEW]], %[[TARGET_SUBVIEW]]
+// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]
+
+// -----
+
+func.func @rank_0_copy(%src : memref<f32>, %target : memref<f32>) {
+  memref.copy %src, %target : memref<f32> to memref<f32>
+  return
+}
+
+// CHECK-LABEL: func.func @rank_0_copy
+// CHECK-SAME: (%[[SRC:.+]]: memref<f32>, %[[TARGET:.+]]: memref<f32>)
+
+// CHECK: scf.forall (%{{.*}}) in (1) {
+// CHECK: memref.copy %[[SRC]], %[[TARGET]]
+// CHECK: mapping = [#gpu.thread<linear_dim_0>]
+
+// -----
+
+func.func @already_distributed_copy(%src : memref<56x32xf32>, %target : memref<56x32xf32>) {
+  scf.forall (%arg2) in (1) {
+    memref.copy %src, %target : memref<56x32xf32> to memref<56x32xf32>
+  } {mapping = [#gpu.thread<linear_dim_0>]}
+  return
+}
+
+// CHECK-LABEL: func.func @already_distributed_copy
+// CHECK-SAME: (%[[SRC:.+]]: memref<56x32xf32>, %[[TARGET:.+]]: memref<56x32xf32>)
+
+// CHECK: scf.forall (%{{.*}}) in (1) {
+// CHECK: memref.copy %[[SRC]], %[[TARGET]]
+// CHECK: mapping = [#gpu.thread<linear_dim_0>]
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index c1ea8225bba2..8e112af514af 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -452,6 +452,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
   addGPUBufferizePasses(funcPassManager);
 
   // Step 8. Resolve remaining parallel loops.
+  funcPassManager.addPass(createGPUDistributeCopyUsingForallPass());
   funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass(
       NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false,
                                      /*normalizeForall=*/true}));