From b1db154e6224b7c72270308d4acf4e1a5557012c Mon Sep 17 00:00:00 2001 From: Asra Date: Thu, 30 Nov 2023 15:20:53 +0000 Subject: [PATCH] feat: adds pass to convert secret annotated arguments to secret types and wrap function body in secret.generic Signed-off-by: Asra --- include/Transforms/Secretize/BUILD | 8 +- .../Secretize/{Secretize.h => Passes.h} | 4 +- .../Secretize/{Secretize.td => Passes.td} | 21 ++++- lib/Transforms/Secretize/BUILD | 7 +- lib/Transforms/Secretize/Secretize.cpp | 5 +- lib/Transforms/Secretize/WrapGeneric.cpp | 93 +++++++++++++++++++ tests/secretize/wrap_generic.mlir | 15 +++ tools/heir-opt.cpp | 2 +- 8 files changed, 140 insertions(+), 15 deletions(-) rename include/Transforms/Secretize/{Secretize.h => Passes.h} (75%) rename include/Transforms/Secretize/{Secretize.td => Passes.td} (52%) create mode 100644 lib/Transforms/Secretize/WrapGeneric.cpp create mode 100644 tests/secretize/wrap_generic.mlir diff --git a/include/Transforms/Secretize/BUILD b/include/Transforms/Secretize/BUILD index 2e41683ff1..b8b2dfe22a 100644 --- a/include/Transforms/Secretize/BUILD +++ b/include/Transforms/Secretize/BUILD @@ -8,7 +8,7 @@ package( ) exports_files([ - "Secretize.h", + "Passes.h", ]) gentbl_cc_library( @@ -19,15 +19,15 @@ gentbl_cc_library( "-gen-pass-decls", "-name=Secretize", ], - "Secretize.h.inc", + "Passes.h.inc", ), ( ["-gen-pass-doc"], - "Secretize.md", + "Passes.md", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "Secretize.td", + td_file = "Passes.td", deps = [ "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:PassBaseTdFiles", diff --git a/include/Transforms/Secretize/Secretize.h b/include/Transforms/Secretize/Passes.h similarity index 75% rename from include/Transforms/Secretize/Secretize.h rename to include/Transforms/Secretize/Passes.h index 7098c1fbc5..c5a33b4ca0 100644 --- a/include/Transforms/Secretize/Secretize.h +++ b/include/Transforms/Secretize/Passes.h @@ -7,10 +7,10 @@ namespace mlir { namespace heir { #define GEN_PASS_DECL -#include "include/Transforms/Secretize/Secretize.h.inc" +#include "include/Transforms/Secretize/Passes.h.inc" #define GEN_PASS_REGISTRATION -#include "include/Transforms/Secretize/Secretize.h.inc" +#include "include/Transforms/Secretize/Passes.h.inc" } // namespace heir } // namespace mlir diff --git a/include/Transforms/Secretize/Secretize.td b/include/Transforms/Secretize/Passes.td similarity index 52% rename from include/Transforms/Secretize/Secretize.td rename to include/Transforms/Secretize/Passes.td index 2bdfc6bd90..d87aa94554 100644 --- a/include/Transforms/Secretize/Secretize.td +++ b/include/Transforms/Secretize/Passes.td @@ -1,5 +1,5 @@ -#ifndef INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ -#define INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ +#ifndef INCLUDE_TRANSFORMS_SECRETIZE_PASSES_TD_ +#define INCLUDE_TRANSFORMS_SECRETIZE_PASSES_TD_ include "mlir/Pass/PassBase.td" @@ -22,4 +22,19 @@ def Secretize : Pass<"secretize", "ModuleOp"> { ]; } -#endif // INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ +def WrapGeneric : Pass<"wrap-generic", "ModuleOp"> { + let summary = "Wraps regions using secret args in secret.generic bodies"; + + let description = [{ + This pass wraps regions that use secret arguments in secret.generic bodies. + + Secret arguments are annotated using a `secret.secret` argument attribute. + }]; + + let dependentDialects = [ + "mlir::heir::secret::SecretDialect", + "mlir::func::FuncDialect" + ]; +} + +#endif // INCLUDE_TRANSFORMS_SECRETIZE_PASSES_TD_ diff --git a/lib/Transforms/Secretize/BUILD b/lib/Transforms/Secretize/BUILD index e8c6528999..ebea6e812b 100644 --- a/lib/Transforms/Secretize/BUILD +++ b/lib/Transforms/Secretize/BUILD @@ -5,9 +5,12 @@ package( cc_library( name = "Secretize", - srcs = ["Secretize.cpp"], + srcs = [ + "Secretize.cpp", + "WrapGeneric.cpp", + ], hdrs = [ - "@heir//include/Transforms/Secretize:Secretize.h", + "@heir//include/Transforms/Secretize:Passes.h", ], deps = [ "@heir//include/Transforms/Secretize:pass_inc_gen", diff --git a/lib/Transforms/Secretize/Secretize.cpp b/lib/Transforms/Secretize/Secretize.cpp index 21bcf99b9c..b798625425 100644 --- a/lib/Transforms/Secretize/Secretize.cpp +++ b/lib/Transforms/Secretize/Secretize.cpp @@ -1,6 +1,5 @@ -#include "include/Transforms/Secretize/Secretize.h" - #include "include/Dialect/Secret/IR/SecretDialect.h" +#include "include/Transforms/Secretize/Passes.h" #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project @@ -9,7 +8,7 @@ namespace mlir { namespace heir { #define GEN_PASS_DEF_SECRETIZE -#include "include/Transforms/Secretize/Secretize.h.inc" +#include "include/Transforms/Secretize/Passes.h.inc" struct Secretize : impl::SecretizeBase { using SecretizeBase::SecretizeBase; diff --git a/lib/Transforms/Secretize/WrapGeneric.cpp b/lib/Transforms/Secretize/WrapGeneric.cpp new file mode 100644 index 0000000000..721976b8e0 --- /dev/null +++ b/lib/Transforms/Secretize/WrapGeneric.cpp @@ -0,0 +1,93 @@ +#include + +#include "include/Dialect/Secret/IR/SecretDialect.h" +#include "include/Dialect/Secret/IR/SecretOps.h" +#include "include/Dialect/Secret/IR/SecretTypes.h" +#include "include/Transforms/Secretize/Passes.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_WRAPGENERIC +#include "include/Transforms/Secretize/Passes.h.inc" + +struct WrapWithGeneric : public OpRewritePattern { + WrapWithGeneric(mlir::MLIRContext *context) + : mlir::OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(func::FuncOp op, + PatternRewriter &rewriter) const override { + bool hasSecrets = false; + + SmallVector newInputs; + for (unsigned i = 0; i < op.getNumArguments(); i++) { + if (op.getArgAttr(i, secret::SecretDialect::kArgSecretAttrName) != + nullptr) { + hasSecrets = true; + op.removeArgAttr(i, secret::SecretDialect::kArgSecretAttrName); + + auto newTy = secret::SecretType::get(op.getArgument(i).getType()); + op.getArgument(i).setType(newTy); // Updates the block argument type. + newInputs.push_back(newTy); + } + } + + if (!hasSecrets) { + // Match failure, no secret inputs. + return failure(); + } + + llvm::SmallVector newOutputs; + for (unsigned i = 0; i < op.getNumResults(); i++) { + newOutputs.push_back(secret::SecretType::get(op.getResultTypes()[i])); + } + + op.setFunctionType( + FunctionType::get(getContext(), {newInputs}, {newOutputs})); + + // Create a secret.generic op and pull the original function block in. + Block &opEntryBlock = op.getRegion().front(); + rewriter.setInsertionPointToStart(&opEntryBlock); + auto newGeneric = rewriter.create( + op.getLoc(), op.getArguments(), newOutputs, + [&](OpBuilder &b, Location loc, ValueRange blockArguments) { + // Map the input values to the block arguments. + IRMapping mp; + for (unsigned i = 0; i < blockArguments.size(); ++i) { + mp.map(opEntryBlock.getArgument(i), blockArguments[i]); + } + for (auto &entryOp : opEntryBlock.getOperations()) { + b.clone(entryOp, mp); + } + auto *returnOp = b.getBlock()->getTerminator(); + b.create(loc, returnOp->getOperands()); + returnOp->erase(); + }); + + rewriter.replaceOp( + opEntryBlock.getTerminator(), + rewriter.create(op.getLoc(), newGeneric.getResults())); + + return success(); + } +}; + +struct WrapGeneric : impl::WrapGenericBase { + using WrapGenericBase::WrapGenericBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + + mlir::RewritePatternSet patterns(context); + patterns.add(context); + + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace heir +} // namespace mlir diff --git a/tests/secretize/wrap_generic.mlir b/tests/secretize/wrap_generic.mlir new file mode 100644 index 0000000000..b0bec361a7 --- /dev/null +++ b/tests/secretize/wrap_generic.mlir @@ -0,0 +1,15 @@ +// RUN: heir-opt --wrap-generic %s | FileCheck %s + +// CHECK: module +module { + func.func @main(%value: i32 {secret.secret}, %cond: i1 {secret.secret}) -> (i32) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c7 = arith.constant 7 : i32 + %0 = arith.muli %value, %c7 : i32 + %1 = arith.addi %0, %c1 : i32 + %2 = arith.muli %1, %1 : i32 + %3 = arith.select %cond, %2, %c0 : i32 + func.return %3 : i32 + } +} diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 9090654216..d2519f9ef3 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -16,7 +16,7 @@ #include "include/Dialect/Secret/IR/SecretDialect.h" #include "include/Dialect/Secret/Transforms/Passes.h" #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" -#include "include/Transforms/Secretize/Secretize.h" +#include "include/Transforms/Secretize/Passes.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project