Skip to content

Commit

Permalink
start debugging and thinking hard
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Feb 1, 2025
1 parent de67600 commit 03af728
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 12 deletions.
29 changes: 17 additions & 12 deletions lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
#include "lib/Utils/ADT/FrozenVector.h"
#include "lib/Utils/Graph/Graph.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

#define DEBUG_TYPE "layout-conversion-to-shift-network"

Expand Down Expand Up @@ -216,7 +214,9 @@ Value rotateGroup(TypedValue<RankedTensorType> tensor,
LogicalResult convertLayoutOp(ConvertLayoutOp op,
VosVosErkinShiftNetworks &shiftNetworks,
int64_t ciphertextSize) {
LLVM_DEBUG(llvm::dbgs() << "Converting layout op: " << op << "\n");
IRRewriter rewriter(op.getContext());
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

// Convert the input and output layouts to an explicit permutation.
AffineMap inputLayout = op.getFromLayout().getValue();
Expand Down Expand Up @@ -245,9 +245,13 @@ LogicalResult convertLayoutOp(ConvertLayoutOp op,
// of the tensors, and mapping fromLayout.eval(index) to
// toLayout.eval(index).
ArrayRef<int64_t> dims = op.getTensor().getType().getShape();

// Initial permutation starts as the identity permutation.
// FIXME: start with an empty partial mapping, then extend it to a permutation
// in some way?
SmallVector<int64_t> permutation = identity(ciphertextSize);

LLVM_DEBUG(llvm::dbgs() << "Constructing permutation...\n");
// Looking for something like llvm::product_iterator, but found nothing.
// Iterating manually and using mod arithmetic to get the per-axis indices.
SmallVector<int64_t, 4> indices;
Expand All @@ -256,9 +260,10 @@ LogicalResult convertLayoutOp(ConvertLayoutOp op,
++index) {
// Unflatten the index into dimensional components
int dimIndex = dims.size() - 1;
int indexCopy = index;
for (int64_t dim : llvm::reverse(dims)) {
indices[dimIndex] = index % dim;
index /= dim;
indices[dimIndex] = indexCopy % dim;
indexCopy /= dim;
--dimIndex;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: heir-opt --layout-conversion-to-shift-network %s | FileCheck %s

#map1 = affine_map<(d0) -> (d0 + 4 mod 64)>
#map2 = affine_map<(d0) -> (3 * d0 mod 64)>
func.func @test_convert_layout(%0: tensor<64xi32>) -> tensor<64xi32> {
%1 = tensor_ext.convert_layout %0 {from_layout = #map1, to_layout = #map2} : tensor<64xi32>
return %1 : tensor<64xi32>
}
1 change: 1 addition & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ cc_binary(
"@heir//lib/Dialect/TensorExt/Transforms",
"@heir//lib/Dialect/TensorExt/Transforms:CollapseInsertionChains",
"@heir//lib/Dialect/TensorExt/Transforms:InsertRotate",
"@heir//lib/Dialect/TensorExt/Transforms:LayoutConversionToShiftNetwork",
"@heir//lib/Dialect/TensorExt/Transforms:RotateAndReduce",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Dialect/TfheRustBool/IR:Dialect",
Expand Down
1 change: 1 addition & 0 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "lib/Dialect/Secret/Transforms/Passes.h"
#include "lib/Dialect/TOSA/Conversions/TosaToSecretArith/TosaToSecretArith.h"
#include "lib/Dialect/TensorExt/IR/TensorExtDialect.h"
#include "lib/Dialect/TensorExt/Transforms/LayoutConversionToShiftNetwork.h"
#include "lib/Dialect/TensorExt/Transforms/Passes.h"
#include "lib/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "lib/Dialect/TfheRustBool/IR/TfheRustBoolDialect.h"
Expand Down

0 comments on commit 03af728

Please sign in to comment.