Skip to content

Commit

Permalink
feat: adds pass to convert secret annotated arguments to secret types
Browse files Browse the repository at this point in the history
and wrap function body in secret.generic

Signed-off-by: Asra <asraa@google.com>
  • Loading branch information
asraa committed Dec 1, 2023
1 parent 5b025b1 commit 3c10203
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 15 deletions.
8 changes: 4 additions & 4 deletions include/Transforms/Secretize/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ package(
)

exports_files([
"Secretize.h",
"Passes.h",
])

gentbl_cc_library(
Expand All @@ -19,15 +19,15 @@ gentbl_cc_library(
"-gen-pass-decls",
"-name=Secretize",
],
"Secretize.h.inc",
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"Secretize.md",
"Passes.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Secretize.td",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "include/Transforms/Secretize/Secretize.h.inc"
#include "include/Transforms/Secretize/Passes.h.inc"

#define GEN_PASS_REGISTRATION
#include "include/Transforms/Secretize/Secretize.h.inc"
#include "include/Transforms/Secretize/Passes.h.inc"

} // namespace heir
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_
#define INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_
#ifndef INCLUDE_TRANSFORMS_SECRETIZE_PASSES_TD_
#define INCLUDE_TRANSFORMS_SECRETIZE_PASSES_TD_

include "mlir/Pass/PassBase.td"

Expand All @@ -22,4 +22,19 @@ def Secretize : Pass<"secretize", "ModuleOp"> {
];
}

#endif // INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_
def WrapGeneric : Pass<"wrap-generic", "ModuleOp"> {
let summary = "Wraps regions using secret args in secret.generic bodies";

let description = [{
This pass wraps regions that use secret arguments in secret.generic bodies.

Secret arguments are annotated using a `secret.secret` argument attribute.
}];

let dependentDialects = [
"mlir::heir::secret::SecretDialect",
"mlir::func::FuncDialect"
];
}

#endif // INCLUDE_TRANSFORMS_SECRETIZE_PASSES_TD_
7 changes: 5 additions & 2 deletions lib/Transforms/Secretize/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ package(

cc_library(
name = "Secretize",
srcs = ["Secretize.cpp"],
srcs = [
"Secretize.cpp",
"WrapGeneric.cpp",
],
hdrs = [
"@heir//include/Transforms/Secretize:Secretize.h",
"@heir//include/Transforms/Secretize:Passes.h",
],
deps = [
"@heir//include/Transforms/Secretize:pass_inc_gen",
Expand Down
5 changes: 2 additions & 3 deletions lib/Transforms/Secretize/Secretize.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include "include/Transforms/Secretize/Secretize.h"

#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Transforms/Secretize/Passes.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
Expand All @@ -9,7 +8,7 @@ namespace mlir {
namespace heir {

#define GEN_PASS_DEF_SECRETIZE
#include "include/Transforms/Secretize/Secretize.h.inc"
#include "include/Transforms/Secretize/Passes.h.inc"

struct Secretize : impl::SecretizeBase<Secretize> {
using SecretizeBase::SecretizeBase;
Expand Down
92 changes: 92 additions & 0 deletions lib/Transforms/Secretize/WrapGeneric.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include <iostream>

#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/Secret/IR/SecretTypes.h"
#include "include/Transforms/Secretize/Passes.h"
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DEF_WRAPGENERIC
#include "include/Transforms/Secretize/Passes.h.inc"

struct WrapWithGeneric : public OpRewritePattern<func::FuncOp> {
WrapWithGeneric(mlir::MLIRContext *context)
: mlir::OpRewritePattern<func::FuncOp>(context) {}

LogicalResult matchAndRewrite(func::FuncOp op,
PatternRewriter &rewriter) const override {
bool hasSecrets = false;

SmallVector<Type, 4> newInputs;
for (unsigned i = 0; i < op.getNumArguments(); i++) {
if (op.getArgAttr(i, secret::SecretDialect::kArgSecretAttrName) !=
nullptr) {
hasSecrets = true;
op.removeArgAttr(i, secret::SecretDialect::kArgSecretAttrName);

auto newTy = secret::SecretType::get(op.getArgument(i).getType());
op.getArgument(i).setType(newTy); // Updates the block argument type.
newInputs.push_back(newTy);
}
}

if (!hasSecrets) {
// Match failure, no secret inputs.
return failure();
}

auto newOutputs = llvm::to_vector<6>(llvm::map_range(
op.getResultTypes(),
[](Type t) -> Type { return secret::SecretType::get(t); }));

op.setFunctionType(
FunctionType::get(getContext(), {newInputs}, {newOutputs}));

// Create a secret.generic op and pull the original function block in.
Block &opEntryBlock = op.getRegion().front();
rewriter.setInsertionPointToStart(&opEntryBlock);
auto newGeneric = rewriter.create<secret::GenericOp>(
op.getLoc(), op.getArguments(), newOutputs,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
// Map the input values to the block arguments.
IRMapping mp;
for (unsigned i = 0; i < blockArguments.size(); ++i) {
mp.map(opEntryBlock.getArgument(i), blockArguments[i]);
}
for (auto &entryOp : opEntryBlock.getOperations()) {
b.clone(entryOp, mp);
}
auto *returnOp = b.getBlock()->getTerminator();
b.create<secret::YieldOp>(loc, returnOp->getOperands());
returnOp->erase();
});

rewriter.replaceOp(
opEntryBlock.getTerminator(),
rewriter.create<func::ReturnOp>(op.getLoc(), newGeneric.getResults()));

return success();
}
};

struct WrapGeneric : impl::WrapGenericBase<WrapGeneric> {
using WrapGenericBase::WrapGenericBase;

void runOnOperation() override {
MLIRContext *context = &getContext();

mlir::RewritePatternSet patterns(context);
patterns.add<WrapWithGeneric>(context);

(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

} // namespace heir
} // namespace mlir
36 changes: 36 additions & 0 deletions tests/secretize/wrap_generic.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: heir-opt --split-input-file --wrap-generic %s | FileCheck %s

// CHECK: module
module {
// CHECK: @main(%arg0: !secret.secret<i32>, %arg1: !secret.secret<i1>) -> !secret.secret<i32>
func.func @main(%value: i32 {secret.secret}, %cond: i1 {secret.secret}) -> (i32) {
// CHECK: %[[V0:.*]] = secret.generic
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c7 = arith.constant 7 : i32
%0 = arith.muli %value, %c7 : i32
%1 = arith.addi %0, %c1 : i32
%2 = arith.muli %1, %1 : i32
%3 = arith.select %cond, %2, %c0 : i32
// CHECK: return %[[V0]] : !secret.secret<i32>
func.return %3 : i32
}
}

// -----

module {
// CHECK: @main(%arg0: !secret.secret<i32>) -> (!secret.secret<i1>, !secret.secret<i32>)
func.func @main(%value: i32 {secret.secret}) -> (i1, i32) {
// CHECK: %[[V0:.*]]:2 = secret.generic
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c7 = arith.constant 7 : i32
%0 = arith.muli %value, %c7 : i32
%1 = arith.addi %0, %c1 : i32
%2 = arith.muli %1, %1 : i32
%3 = arith.cmpi slt, %value, %c0 : i32
// CHECK: return %[[V0]]#0, %[[V0]]#1 : !secret.secret<i1>, !secret.secret<i32>
func.return %3, %2 : i1, i32
}
}
2 changes: 1 addition & 1 deletion tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/Transforms/Passes.h"
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Transforms/Secretize/Secretize.h"
#include "include/Transforms/Secretize/Passes.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
Expand Down

0 comments on commit 3c10203

Please sign in to comment.