Skip to content

Commit

Permalink
Add LayoutPropagation pass
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Jan 31, 2025
1 parent 62917fa commit 159def7
Show file tree
Hide file tree
Showing 17 changed files with 1,025 additions and 16 deletions.
16 changes: 16 additions & 0 deletions lib/Dialect/TensorExt/IR/TensorExtAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,20 @@ def SIMDPacking_Attr : TensorExt_Attr<"SIMDPacking", "simd_packing",
let assemblyFormat = "`<` struct(params) `>`";
}

def TensorExt_LayoutAttr : TensorExt_Attr<"Layout", "layout"> {
  let summary = "Attribute denoting the layout of a tensor in a set of ciphertexts";
  let description = [{
    This attribute contains an affine map that describes the layout of a tensor
    in a set of ciphertexts. The affine map is a function that maps tensor indices
    to ciphertext indices (possibly with a ciphertext-selecting index).

    This attribute exists primarily to provide a "dialect attribute" which is
    required to annotate the arguments of a `func.func` op.
  }];

  // The affine map from tensor index space to ciphertext index space.
  let parameters = (ins "AffineMap":$layout);
  let assemblyFormat = "`<` struct(params) `>`";
}


#endif // LIB_DIALECT_TENSOREXT_IR_TENSOREXTATTRIBUTES_TD_
6 changes: 6 additions & 0 deletions lib/Dialect/TensorExt/IR/TensorExtDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def TensorExt_Dialect : Dialect {
"tensor::TensorDialect",
];

let extraClassDeclaration = [{
constexpr const static ::llvm::StringLiteral
kLayoutAttrName = "tensor_ext.layout";
}];


let useDefaultAttributePrinterParser = 1;
}

Expand Down
46 changes: 31 additions & 15 deletions lib/Dialect/TensorExt/IR/TensorExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,44 @@ LogicalResult RotateOp::verify() {
return success();
}

// Checks that `layout` has one affine dimension per dimension of the shaped
// type `type`. On mismatch, emits an error diagnostic on `op` that includes
// the rank and the offending map, and returns failure.
LogicalResult verifyLayoutMatchesType(const AffineMap &layout, Type type,
                                      Operation *op) {
  int64_t rank = cast<ShapedType>(type).getRank();
  if (rank != layout.getNumDims()) {
    // Render the affine map to a string so it can be embedded in the
    // diagnostic message.
    std::string layoutStr;
    llvm::raw_string_ostream os(layoutStr);
    layout.print(os);

    return op->emitOpError()
           << "requires tensor rank to match the layout map's dimension count"
              " but found rank "
           << rank << " and map " << os.str();
  }

  return success();
}

LogicalResult ConvertLayoutOp::verify() {
  // The source layout must agree with the rank of the input tensor, and the
  // target layout must agree with the rank of the result tensor.
  if (failed(verifyLayoutMatchesType(getFromLayout().getValue(),
                                     getTensor().getType(), *this))) {
    return failure();
  }

  return verifyLayoutMatchesType(getToLayout().getValue(),
                                 getResult().getType(), *this);
}

LogicalResult AssignLayoutOp::verify() {
  // The assigned layout must have one dimension per dimension of the
  // plaintext tensor operand.
  const AffineMap &layout = getLayout().getValue();
  return verifyLayoutMatchesType(layout, getTensor().getType(), *this);
}

} // namespace tensor_ext
} // namespace heir
} // namespace mlir
21 changes: 21 additions & 0 deletions lib/Dialect/TensorExt/IR/TensorExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,31 @@ def TensorExt_ConvertLayoutOp : TensorExt_Op<"convert_layout", [Pure, AllTypesMa

This op is inserted by layout selection passes.
}];

let assemblyFormat = "operands attr-dict `:` type($output)";
let arguments = (ins AnyRankedTensor:$tensor, Builtin_AffineMapAttr:$from_layout, Builtin_AffineMapAttr:$to_layout);
let results = (outs AnyRankedTensor:$output);
let hasVerifier = 1;
}

def TensorExt_AssignLayoutOp : TensorExt_Op<"assign_layout", [Pure, AllTypesMatch<["tensor", "output"]>]> {
  let summary = "Assign a layout to a plaintext tensor.";
  let description = [{
    This op allows the ingestion of a plaintext tensor into the layout system.
    For example, ops like `linalg.reduce` require a tensor input to represent
    initial values. These will generally be created by an `arith.constant` or
    `tensor.empty` op, which does not have secret results. Lowerings will
    convert this to a packed plaintext, so that the subsequent ops can be
    lowered as ciphertext-plaintext ops.

    This op is inserted by layout selection passes.
  }];

  let assemblyFormat = "operands attr-dict `:` type($output)";
  // `layout` maps the tensor's index space to its packed ciphertext layout.
  let arguments = (ins AnyRankedTensor:$tensor, Builtin_AffineMapAttr:$layout);
  let results = (outs AnyRankedTensor:$output);
  let hasVerifier = 1;
}


#endif // LIB_DIALECT_TENSOREXT_IR_TENSOREXTOPS_TD_
12 changes: 11 additions & 1 deletion lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
#include "lib/Dialect/TensorExt/Transforms/RotateAndReduce.h"
#include "lib/Pipelines/PipelineRegistration.h"
#include "lib/Transforms/ApplyFolders/ApplyFolders.h"
#include "lib/Transforms/DropUnitDims/DropUnitDims.h"
#include "lib/Transforms/ForwardStoreToLoad/ForwardStoreToLoad.h"
#include "lib/Transforms/FullLoopUnroll/FullLoopUnroll.h"
#include "lib/Transforms/LayoutPropagation/LayoutPropagation.h"
#include "lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.h"
#include "lib/Transforms/OperationBalancer/OperationBalancer.h"
#include "lib/Transforms/OptimizeRelinearization/OptimizeRelinearization.h"
Expand Down Expand Up @@ -82,8 +85,15 @@ void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm) {
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

// Apply linalg kernels
// Linalg canonicalization
// TODO(#1191): enable dropping unit dims to convert matmul to matvec/vecmat
// pm.addPass(createDropUnitDims());
pm.addPass(createLinalgCanonicalizations());

// Layout assignment and lowering
// TODO(#1191): enable layout propagation after implementing the rest
// of the layout lowering pipeline.
// pm.addPass(createLayoutPropagation());
pm.addPass(heir::linalg::createLinalgToTensorExt());

// Vectorize and optimize rotations
Expand Down
2 changes: 2 additions & 0 deletions lib/Pipelines/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,10 @@ cc_library(
"@heir//lib/Dialect/TensorExt/Transforms:InsertRotate",
"@heir//lib/Dialect/TensorExt/Transforms:RotateAndReduce",
"@heir//lib/Transforms/ApplyFolders",
"@heir//lib/Transforms/DropUnitDims",
"@heir//lib/Transforms/ForwardStoreToLoad",
"@heir//lib/Transforms/FullLoopUnroll",
"@heir//lib/Transforms/LayoutPropagation",
"@heir//lib/Transforms/LinalgCanonicalizations",
"@heir//lib/Transforms/MemrefToArith:ExpandCopy",
"@heir//lib/Transforms/MemrefToArith:MemrefToArithRegistration",
Expand Down
32 changes: 32 additions & 0 deletions lib/Transforms/LayoutPropagation/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# BUILD file for the LayoutPropagation transform: declares the pass library
# and generates pass registration boilerplate from the .td definition.
load("@heir//lib/Transforms:transforms.bzl", "add_heir_transforms")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

# The pass implementation. Depends on SecretnessAnalysis to distinguish
# secret from plaintext values while propagating layouts.
cc_library(
name = "LayoutPropagation",
srcs = ["LayoutPropagation.cpp"],
hdrs = ["LayoutPropagation.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Analysis/SecretnessAnalysis",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)

# Runs the tablegen pass generator over LayoutPropagation.td, producing the
# declaration/registration headers exposed as ":pass_inc_gen" above.
add_heir_transforms(
generated_target_name = "pass_inc_gen",
pass_name = "LayoutPropagation",
td_file = "LayoutPropagation.td",
)
Loading

0 comments on commit 159def7

Please sign in to comment.