diff --git a/lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp b/lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp index 69bf14300..16bcde710 100644 --- a/lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp +++ b/lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp @@ -5,16 +5,14 @@ #include "lib/Dialect/TensorExt/IR/TensorExtOps.h" #include "lib/Utils/ADT/FrozenVector.h" #include "lib/Utils/Graph/Graph.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #define DEBUG_TYPE "layout-conversion-to-shift-network" @@ -216,7 +214,9 @@ Value rotateGroup(TypedValue tensor, LogicalResult convertLayoutOp(ConvertLayoutOp op, VosVosErkinShiftNetworks &shiftNetworks, int64_t ciphertextSize) { + LLVM_DEBUG(llvm::dbgs() << "Converting layout op: " << op << "\n"); IRRewriter rewriter(op.getContext()); + ImplicitLocOpBuilder b(op.getLoc(), rewriter); // Convert the input and output layouts to an explicit permutation. AffineMap inputLayout = op.getFromLayout().getValue(); @@ -245,9 +245,13 @@ LogicalResult convertLayoutOp(ConvertLayoutOp op, // of the tensors, and mapping fromLayout.eval(index) to // toLayout.eval(index). ArrayRef dims = op.getTensor().getType().getShape(); + // Initial permutation starts as the identity permutation. + // FIXME: start with an empty partial mapping, then extend it to a permutation + // in some way? SmallVector permutation = identity(ciphertextSize); + LLVM_DEBUG(llvm::dbgs() << "Constructing permutation...\n"); // Looking for something like llvm::product_iterator, but found nothing. // Iterating manually and using mod arithmetic to get the per-axis indices. SmallVector indices; @@ -256,9 +260,10 @@ LogicalResult convertLayoutOp(ConvertLayoutOp op, ++index) { // Unflatten the index into dimensional components int dimIndex = dims.size() - 1; + int indexCopy = index; for (int64_t dim : llvm::reverse(dims)) { - indices[dimIndex] = index % dim; - index /= dim; + indices[dimIndex] = indexCopy % dim; + indexCopy /= dim; --dimIndex; } diff --git a/tests/Dialect/TensorExt/Transforms/layout_conversion_to_shift_network.mlir b/tests/Dialect/TensorExt/Transforms/layout_conversion_to_shift_network.mlir new file mode 100644 index 000000000..0ce150637 --- /dev/null +++ b/tests/Dialect/TensorExt/Transforms/layout_conversion_to_shift_network.mlir @@ -0,0 +1,8 @@ +// RUN: heir-opt --layout-conversion-to-shift-network %s | FileCheck %s + +#map1 = affine_map<(d0) -> (d0 + 4 mod 64)> +#map2 = affine_map<(d0) -> (3 * d0 mod 64)> +func.func @test_convert_layout(%0: tensor<64xi32>) -> tensor<64xi32> { + %1 = tensor_ext.convert_layout %0 {from_layout = #map1, to_layout = #map2} : tensor<64xi32> + return %1 : tensor<64xi32> +} diff --git a/tools/BUILD b/tools/BUILD index e5bc74863..8df27a51a 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -84,6 +84,7 @@ cc_binary( "@heir//lib/Dialect/TensorExt/Transforms", "@heir//lib/Dialect/TensorExt/Transforms:CollapseInsertionChains", "@heir//lib/Dialect/TensorExt/Transforms:InsertRotate", + "@heir//lib/Dialect/TensorExt/Transforms:LayoutConversionToShiftNetwork", "@heir//lib/Dialect/TensorExt/Transforms:RotateAndReduce", "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Dialect/TfheRustBool/IR:Dialect", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 50c55c0ea..9099e92b1 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -46,6 +46,7 @@ #include "lib/Dialect/Secret/Transforms/Passes.h" #include "lib/Dialect/TOSA/Conversions/TosaToSecretArith/TosaToSecretArith.h" #include "lib/Dialect/TensorExt/IR/TensorExtDialect.h" +#include "lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.h" #include "lib/Dialect/TensorExt/Transforms/Passes.h" #include "lib/Dialect/TfheRust/IR/TfheRustDialect.h" #include "lib/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h"