diff --git a/include/Dialect/Secret/IR/SecretDialect.td b/include/Dialect/Secret/IR/SecretDialect.td index 72ce6835b8..a7f3a5963c 100644 --- a/include/Dialect/Secret/IR/SecretDialect.td +++ b/include/Dialect/Secret/IR/SecretDialect.td @@ -15,6 +15,13 @@ def Secret_Dialect : Dialect { custom types for arithmetic on secret integers of various bit widths. }]; + let extraClassDeclaration = [{ + /// Name of the attribute indicate whether an argument of a function is a + //secret. + constexpr const static ::llvm::StringLiteral + kArgSecretAttrName = "secret.secret"; + }]; + let cppNamespace = "::mlir::heir::secret"; let useDefaultTypePrinterParser = 1; } diff --git a/include/Transforms/Secretize/BUILD b/include/Transforms/Secretize/BUILD new file mode 100644 index 0000000000..2e41683ff1 --- /dev/null +++ b/include/Transforms/Secretize/BUILD @@ -0,0 +1,35 @@ +# Secretize tablegen and headers. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files([ + "Secretize.h", +]) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=Secretize", + ], + "Secretize.h.inc", + ), + ( + ["-gen-pass-doc"], + "Secretize.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Secretize.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Transforms/Secretize/Secretize.h b/include/Transforms/Secretize/Secretize.h new file mode 100644 index 0000000000..7098c1fbc5 --- /dev/null +++ b/include/Transforms/Secretize/Secretize.h @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_ +#define INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "include/Transforms/Secretize/Secretize.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Transforms/Secretize/Secretize.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_H_ diff --git a/include/Transforms/Secretize/Secretize.td b/include/Transforms/Secretize/Secretize.td new file mode 100644 index 0000000000..125f8ba9f2 --- /dev/null +++ b/include/Transforms/Secretize/Secretize.td @@ -0,0 +1,25 @@ +#ifndef INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ +#define INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ + +include "mlir/Pass/PassBase.td" + +def Secretize : Pass<"secretize", "ModuleOp"> { + let summary = "Adds secret argument attributes to entry function"; + + let description = [{ + Adds a secret.secret attribute argument to each argument in the entry + function of an MLIR module. By default, the function is `main`. This may be + overrided with the option -entry-function=top_level_func. + }]; + + let dependentDialects = [ + "mlir::heir::secret::SecretDialect", + "mlir::func::FuncDialect" + ]; + + let options = [ + Option<"entryFunction", "entry-function", "std::string", "\"main\"", "entry function of the module"> + ]; +} + +#endif // INCLUDE_TRANSFORMS_SECRETIZE_SECRETIZE_TD_ diff --git a/lib/Transforms/Secretize/BUILD b/lib/Transforms/Secretize/BUILD new file mode 100644 index 0000000000..e8c6528999 --- /dev/null +++ b/lib/Transforms/Secretize/BUILD @@ -0,0 +1,21 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Secretize", + srcs = ["Secretize.cpp"], + hdrs = [ + "@heir//include/Transforms/Secretize:Secretize.h", + ], + deps = [ + "@heir//include/Transforms/Secretize:pass_inc_gen", + "@heir//lib/Dialect/Secret/IR:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Transforms/Secretize/Secretize.cpp b/lib/Transforms/Secretize/Secretize.cpp new file mode 100644 index 0000000000..21bcf99b9c --- /dev/null +++ b/lib/Transforms/Secretize/Secretize.cpp @@ -0,0 +1,39 @@ +#include "include/Transforms/Secretize/Secretize.h" + +#include "include/Dialect/Secret/IR/SecretDialect.h" +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_SECRETIZE +#include "include/Transforms/Secretize/Secretize.h.inc" + +struct Secretize : impl::SecretizeBase { + using SecretizeBase::SecretizeBase; + + void runOnOperation() override { + MLIRContext* ctx = &getContext(); + ModuleOp module = getOperation(); + OpBuilder builder(module); + + auto mainFunction = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, entryFunction)); + if (!mainFunction) { + module.emitError("could not find entry point function"); + signalPassFailure(); + return; + } + + auto secretArgAttr = + StringAttr::get(ctx, secret::SecretDialect::kArgSecretAttrName); + for (unsigned i = 0; i < mainFunction.getNumArguments(); i++) { + mainFunction.setArgAttr(i, secretArgAttr, UnitAttr::get(ctx)); + } + } +}; + +} // namespace heir +} // namespace mlir diff --git a/tests/secretize/BUILD b/tests/secretize/BUILD new file mode 100644 index 0000000000..6c9032391a --- /dev/null +++ b/tests/secretize/BUILD @@ -0,0 +1,13 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/secretize/main.mlir b/tests/secretize/main.mlir new file mode 100644 index 0000000000..86b4d74531 --- /dev/null +++ b/tests/secretize/main.mlir @@ -0,0 +1,22 @@ +// RUN: heir-opt -secretize %s | FileCheck %s + +module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + // CHECK: func.func @main(%arg0: tensor<1x1xi8> {secret.secret, tf_saved_model.index_path = ["dense_input"]}) + func.func @main(%arg0: tensor<1x1xi8> {tf_saved_model.index_path = ["dense_input"]}) -> (tensor<1x1xi8> {tf_saved_model.index_path = ["dense_2"]}) { + %0 = "tosa.const"() <{value = dense<429> : tensor<1xi32>}> : () -> tensor<1xi32> + %1 = "tosa.const"() <{value = dense<[[-39, 59, 39, 21, 28, -32, -34, -35, 15, 27, -59, -41, 18, -35, -7, 127]]> : tensor<1x16xi8>}> : () -> tensor<1x16xi8> + %2 = "tosa.const"() <{value = dense<[-729, 1954, 610, 0, 241, -471, -35, -867, 571, 581, 4260, 3943, 591, 0, -889, -5103]> : tensor<16xi32>}> : () -> tensor<16xi32> + %3 = "tosa.const"() <{value = dense<"0xF41AED091921F424E021EFBCF7F5FA1903DCD20206F9F402FFFAEFF1EFD327E1FB27DDEBDBE4051A17FC241215EF1EE410FE14DA1CF8F3F1EFE2F309E3E9EDE3E415070B041B1AFEEB01DE21E60BEC03230A22241E2703E60324FFC011F8FCF1110CF5E0F30717E5E8EDFADCE823FB07DDFBFD0014261117E7F111EA0226040425211D0ADB1DDC2001FAE3370BF11A16EF1CE703E01602032118092ED9E5140BEA1AFCD81300C4D8ECD9FE0D1920D8D6E21FE9D7CAE2DDC613E7043E000114C7DBE71515F506D61ADC0922FE080213EF191EE209FDF314DDDA20D90FE3F9F7EEE924E629000716E21E0D23D3DDF714FA0822262109080F0BE012F47FDC58E526"> : tensor<16x16xi8>}> : () -> tensor<16x16xi8> + %4 = "tosa.const"() <{value = dense<[0, 0, -5438, -5515, -1352, -1500, -4152, -84, 3396, 0, 1981, -5581, 0, -6964, 3407, -7217]> : tensor<16xi32>}> : () -> tensor<16xi32> + %5 = "tosa.const"() <{value = dense<[[-9], [-54], [57], [71], [104], [115], [98], [99], [64], [-26], [127], [25], [-82], [68], [95], [86]]> : tensor<16x1xi8>}> : () -> tensor<16x1xi8> + %6 = "tosa.fully_connected"(%arg0, %5, %4) <{quantization_info = #tosa.conv_quant}> : (tensor<1x1xi8>, tensor<16x1xi8>, tensor<16xi32>) -> tensor<1x16xi32> + %7 = "tosa.rescale"(%6) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array}> : (tensor<1x16xi32>) -> tensor<1x16xi8> + %8 = "tosa.clamp"(%7) <{max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64}> : (tensor<1x16xi8>) -> tensor<1x16xi8> + %9 = "tosa.fully_connected"(%8, %3, %2) <{quantization_info = #tosa.conv_quant}> : (tensor<1x16xi8>, tensor<16x16xi8>, tensor<16xi32>) -> tensor<1x16xi32> + %10 = "tosa.rescale"(%9) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array}> : (tensor<1x16xi32>) -> tensor<1x16xi8> + %11 = "tosa.clamp"(%10) <{max_fp = 0.000000e+00 : f32, max_int = 127 : i64, min_fp = 0.000000e+00 : f32, min_int = -128 : i64}> : (tensor<1x16xi8>) -> tensor<1x16xi8> + %12 = "tosa.fully_connected"(%11, %1, %0) <{quantization_info = #tosa.conv_quant}> : (tensor<1x16xi8>, tensor<1x16xi8>, tensor<1xi32>) -> tensor<1x1xi32> + %13 = "tosa.rescale"(%12) <{double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array}> : (tensor<1x1xi32>) -> tensor<1x1xi8> + return %13 : tensor<1x1xi8> + } +} diff --git a/tests/secretize/missing.mlir b/tests/secretize/missing.mlir new file mode 100644 index 0000000000..154e15bbfe --- /dev/null +++ b/tests/secretize/missing.mlir @@ -0,0 +1,9 @@ +// RUN: heir-opt -secretize -verify-diagnostics %s + +// expected-error@+1 {{could not find entry point function}} +module { + func.func @comb(%a: i1, %b: i1) -> () { + %0 = comb.truth_table %a, %b -> 6 : ui4 + return + } +} diff --git a/tests/secretize/multiple.mlir b/tests/secretize/multiple.mlir new file mode 100644 index 0000000000..686b94747e --- /dev/null +++ b/tests/secretize/multiple.mlir @@ -0,0 +1,15 @@ +// RUN: heir-opt -secretize %s | FileCheck %s + +module { + // CHECK: func.func @inner(%arg0: i1, %arg1: i1) + func.func @inner(%a: i1, %b: i1) -> () { + %0 = comb.truth_table %a, %b -> 6 : ui4 + return + } + + // CHECK: func.func @main(%arg0: i1 {secret.secret}, %arg1: i1 {secret.secret}) + func.func @main(%a: i1, %b: i1) -> () { + func.call @inner(%a, %b) : (i1, i1) -> () + return + } +} diff --git a/tests/secretize/named.mlir b/tests/secretize/named.mlir new file mode 100644 index 0000000000..26a024281c --- /dev/null +++ b/tests/secretize/named.mlir @@ -0,0 +1,9 @@ +// RUN: heir-opt -secretize=entry-function=comb %s | FileCheck %s + +module { + // CHECK: func.func @comb(%arg0: i1 {secret.secret}, %arg1: i1 {secret.secret}) + func.func @comb(%a: i1, %b: i1) -> () { + %0 = comb.truth_table %a, %b -> 6 : ui4 + return + } +} diff --git a/tools/BUILD b/tools/BUILD index 7bbb92d244..a6af592db0 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -48,6 +48,7 @@ cc_binary( "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Dialect/Secret/Transforms", "@heir//lib/Dialect/TfheRust/IR:Dialect", + "@heir//lib/Transforms/Secretize", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 9113c9f410..3b70db228e 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -14,6 +14,7 @@ #include "include/Dialect/Secret/IR/SecretDialect.h" #include "include/Dialect/Secret/Transforms/Passes.h" #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" +#include "include/Transforms/Secretize/Secretize.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project @@ -182,6 +183,7 @@ int main(int argc, char **argv) { bgv::registerBGVToPolynomialPasses(); comb::registerCombToCGGIPasses(); registerCGGIToTfheRustPasses(); + registerSecretizePasses(); // Register yosys optimizer pipeline if configured. #ifndef HEIR_NO_YOSYS const char *abcEnvPath = std::getenv("HEIR_ABC_BINARY");