Skip to content

Commit

Permalink
Merge pull request #1308 from AlexanderViand-Intel:decouple-pipelines
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720340588
  • Loading branch information
copybara-github committed Jan 28, 2025
2 parents b9500b6 + eb1eb1c commit b64a2c3
Show file tree
Hide file tree
Showing 18 changed files with 103 additions and 82 deletions.
8 changes: 4 additions & 4 deletions docs/content/en/docs/Design/simd.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ and index-based accesses into tensors (e.g., `tensor.extract` and
entire tensors. While its implementation does not depend on any FHE-specific
details or even the Secret dialect, this transformation is likely only useful
when lowering a high-level program to an arithmetic-circuit-based FHE scheme
(e.g., B/FV, BGV, or CKKS). The `-mlir-to-openfhe-bgv` pipeline demonstrates the
intended flow: augmenting a high-level program with `secret` annotations, then
applying the SIMD optimization (and any other high-level optimizations) before
lowering to BGV operations and then exiting to OpenFHE.
(e.g., B/FV, BGV, or CKKS). The `--mlir-to-bgv --scheme-to-openfhe` pipeline
demonstrates the intended flow: augmenting a high-level program with `secret`
annotations, then applying the SIMD optimization (and any other high-level
optimizations) before lowering to BGV operations and then exiting to OpenFHE.

> **Warning** The current SIMD vectorizer pipeline supports only one-dimensional
> tensors. As a workaround, one could reshape all multi-dimensional tensors into
Expand Down
3 changes: 2 additions & 1 deletion docs/content/en/docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ Now we run the `heir-opt` command to optimize and compile the program.

```bash
bazel run //tools:heir-opt -- \
--mlir-to-openfhe-bgv='entry-function=dot_product ciphertext-degree=8' \
--mlir-to-bgv='ciphertext-degree=8'\
--scheme-to-openfhe='entry-function=dot_product' \
$PWD/tests/Examples/openfhe/dot_product_8.mlir > output.mlir
```

Expand Down
6 changes: 2 additions & 4 deletions heir_py/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,8 @@ def run_compiler(
# TODO(#1162): construct heir-opt pipeline options from decorator
heir_opt_options = [
f"--secretize=function={func_name}",
(
"--mlir-to-openfhe-bgv="
f"entry-function={func_name} ciphertext-degree=32"
),
"--mlir-to-bgv=ciphertext-degree=32",
f"--scheme-to-openfhe=entry-function={func_name}"
]
heir_opt_output = heir_opt.run_binary(
input=mlir_textual,
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
if (failed(walkAndValidateTypes<secret::GenericOp>(
module, disallowFloatlike,
"Floating point types are not supported in BGV. Maybe you meant "
"to use a CKKS pipeline like --mlir-to-openfhe-ckks?"))) {
"to use a CKKS pipeline like --mlir-to-ckks?"))) {
signalPassFailure();
return;
}
Expand Down
49 changes: 17 additions & 32 deletions lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ void mlirToRLWEPipeline(OpPassManager &pm,
break;
}
default:
break;
llvm::errs() << "Unsupported RLWE scheme: " << scheme;
exit(EXIT_FAILURE);
}

// Optimize relinearization at mgmt dialect level
Expand Down Expand Up @@ -173,27 +174,11 @@ RLWEPipelineBuilder mlirToRLWEPipelineBuilder(const RLWEScheme scheme) {
};
}

RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(const RLWEScheme scheme) {
return [=](OpPassManager &pm, const MlirToRLWEPipelineOptions &options) {
// lower to RLWE scheme
mlirToRLWEPipeline(pm, options, scheme);

// Convert to (common trivial subset of) LWE
switch (scheme) {
case RLWEScheme::bgvScheme: {
// TODO (#1193): Replace `--bgv-to-lwe` with `--bgv-common-to-lwe`
pm.addPass(bgv::createBGVToLWE());
break;
}
case RLWEScheme::ckksScheme: {
// TODO (#1193): Replace `--ckks-to-lwe` with `--ckks-common-to-lwe`
pm.addPass(ckks::createCKKSToLWE());
break;
}
default:
llvm::errs() << "Unsupported RLWE scheme: " << scheme;
exit(EXIT_FAILURE);
}
BackendPipelineBuilder toOpenFhePipelineBuilder() {
return [=](OpPassManager &pm, const OpenfheOptions &options) {
// Convert the common trivial subset of CKKS/BGV to LWE
pm.addPass(bgv::createBGVToLWE());
pm.addPass(ckks::createCKKSToLWE());

// insert debug handler calls
if (options.debug) {
Expand All @@ -202,7 +187,7 @@ RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(const RLWEScheme scheme) {
pm.addPass(lwe::createAddDebugPort(addDebugPortOptions));
}

// Convert to OpenFHE
// Convert LWE (and scheme-specific CKKS/BGV ops) to OpenFHE
pm.addPass(lwe::createLWEToOpenfhe());

// Simplify, in case the lowering revealed redundancy
Expand All @@ -219,18 +204,18 @@ RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(const RLWEScheme scheme) {
};
}

RLWEPipelineBuilder mlirToLattigoRLWEPipelineBuilder(const RLWEScheme scheme) {
return [=](OpPassManager &pm, const MlirToRLWEPipelineOptions &options) {
LattigoPipelineBuilder mlirToLattigoRLWEPipelineBuilder(
const RLWEScheme scheme) {
return [=](OpPassManager &pm, const LattigoOptions &options) {
// lower to RLWE scheme
MlirToRLWEPipelineOptions overrideOptions;
overrideOptions.entryFunction = options.entryFunction;
overrideOptions.ciphertextDegree = options.ciphertextDegree;
overrideOptions.modulusSwitchBeforeFirstMul =
MlirToRLWEPipelineOptions rlweOptions;
rlweOptions.ciphertextDegree = options.ciphertextDegree;
rlweOptions.modulusSwitchBeforeFirstMul =
options.modulusSwitchBeforeFirstMul;
// use simpler client interface for Lattigo
overrideOptions.usePublicKey = false;
overrideOptions.oneValuePerHelperFn = false;
mlirToRLWEPipeline(pm, overrideOptions, scheme);
rlweOptions.usePublicKey = false;
rlweOptions.oneValuePerHelperFn = false;
mlirToRLWEPipeline(pm, rlweOptions, scheme);

// Convert to (common trivial subset of) LWE
switch (scheme) {
Expand Down
36 changes: 31 additions & 5 deletions lib/Pipelines/ArithmeticPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ void heirSIMDVectorizerPipelineBuilder(OpPassManager &manager);

struct MlirToRLWEPipelineOptions
: public PassPipelineOptions<MlirToRLWEPipelineOptions> {
PassOptions::Option<std::string> entryFunction{
*this, "entry-function", llvm::cl::desc("Entry function to secretize"),
llvm::cl::init("main")};
PassOptions::Option<int> ciphertextDegree{
*this, "ciphertext-degree",
llvm::cl::desc("The degree of the polynomials to use for ciphertexts; "
Expand All @@ -43,16 +40,45 @@ struct MlirToRLWEPipelineOptions
llvm::cl::desc("Modulus switching right before the first multiplication "
"(default to false)"),
llvm::cl::init(false)};
};

struct OpenfheOptions : public PassPipelineOptions<OpenfheOptions> {
PassOptions::Option<std::string> entryFunction{
*this, "entry-function", llvm::cl::desc("Entry function"),
llvm::cl::init("main")};
PassOptions::Option<bool> debug{
*this, "insert-debug-handler-calls",
llvm::cl::desc("Insert function calls to an externally-defined debug "
"function (cf. --lwe-add-debug-port)"),
llvm::cl::init(false)};
};

struct LattigoOptions : public PassPipelineOptions<LattigoOptions> {
PassOptions::Option<int> ciphertextDegree{
*this, "ciphertext-degree",
llvm::cl::desc("The degree of the polynomials to use for ciphertexts; "
"equivalently, the number of messages that can be packed "
"into a single ciphertext."),
llvm::cl::init(1024)};
PassOptions::Option<std::string> entryFunction{
*this, "entry-function", llvm::cl::desc("Entry function"),
llvm::cl::init("main")};
PassOptions::Option<bool> modulusSwitchBeforeFirstMul{
*this, "modulus-switch-before-first-mul",
llvm::cl::desc("Modulus switching right before the first multiplication "
"(default to false)"),
llvm::cl::init(false)};
};

using RLWEPipelineBuilder =
std::function<void(OpPassManager &, const MlirToRLWEPipelineOptions &)>;

using BackendPipelineBuilder =
std::function<void(OpPassManager &, const OpenfheOptions &)>;

using LattigoPipelineBuilder =
std::function<void(OpPassManager &, const LattigoOptions &)>;

void mlirToRLWEPipeline(OpPassManager &pm,
const MlirToRLWEPipelineOptions &options,
RLWEScheme scheme);
Expand All @@ -61,9 +87,9 @@ void mlirToSecretArithmeticPipelineBuilder(OpPassManager &pm);

RLWEPipelineBuilder mlirToRLWEPipelineBuilder(RLWEScheme scheme);

RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(RLWEScheme scheme);
BackendPipelineBuilder toOpenFhePipelineBuilder();

RLWEPipelineBuilder mlirToLattigoRLWEPipelineBuilder(RLWEScheme scheme);
LattigoPipelineBuilder mlirToLattigoRLWEPipelineBuilder(RLWEScheme scheme);

} // namespace mlir::heir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

// This example was converted from the dot_product_8f.mlir example via
//
// heir-opt --mlir-print-ir-before-all --mlir-to-openfhe-bgv='entry-function=dot_product ciphertext-degree=8' tests/Examples/openfhe/dot_product_8f.mlir
// heir-opt --mlir-print-ir-before-all --mlir-to-bgv='ciphertext-degree=8' --scheme-to-openfhe='entry-function=dot_product' tests/Examples/openfhe/dot_product_8f.mlir

module {
func.func @dot_product(%arg0: !secret.secret<tensor<8xf16>> {mgmt.mgmt = #mgmt.mgmt<level = 2>}, %arg1: !secret.secret<tensor<8xf16>> {mgmt.mgmt = #mgmt.mgmt<level = 2>}) -> !secret.secret<tensor<8xf16>> {
%cst = arith.constant dense<9.997550e-02> : tensor<8xf16>
// expected-error@below {{Floating point types are not supported in BGV. Maybe you meant to use a CKKS pipeline like --mlir-to-openfhe-ckks?}}
// expected-error@below {{Floating point types are not supported in BGV. Maybe you meant to use a CKKS pipeline like --mlir-to-ckks?}}
%0 = secret.generic ins(%arg0, %arg1 : !secret.secret<tensor<8xf16>>, !secret.secret<tensor<8xf16>>) attrs = {mgmt.mgmt = #mgmt.mgmt<level = 2, dimension = 3>} {
^bb0(%arg2: tensor<8xf16>, %arg3: tensor<8xf16>):
%17 = arith.mulf %arg2, %arg3 : tensor<8xf16>
Expand Down
35 changes: 28 additions & 7 deletions tests/Examples/openfhe/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ openfhe_end_to_end_test(
openfhe_end_to_end_test(
name = "simple_sum_test",
generated_lib_header = "simple_sum_lib.h",
heir_opt_flags = ["--mlir-to-openfhe-bgv=entry-function=simple_sum ciphertext-degree=32"],
heir_opt_flags = [
"--mlir-to-bgv=ciphertext-degree=32",
"--scheme-to-openfhe=entry-function=simple_sum",
],
mlir_src = "simple_sum.mlir",
tags = ["notap"],
test_src = "simple_sum_test.cpp",
Expand All @@ -34,7 +37,10 @@ openfhe_end_to_end_test(
openfhe_end_to_end_test(
name = "dot_product_8_test",
generated_lib_header = "dot_product_8_lib.h",
heir_opt_flags = ["--mlir-to-openfhe-bgv=entry-function=dot_product ciphertext-degree=8"],
heir_opt_flags = [
"--mlir-to-bgv=ciphertext-degree=8",
"--scheme-to-openfhe=entry-function=dot_product",
],
mlir_src = "dot_product_8.mlir",
tags = ["notap"],
test_src = "dot_product_8_test.cpp",
Expand All @@ -43,7 +49,10 @@ openfhe_end_to_end_test(
openfhe_end_to_end_test(
name = "dot_product_8_debug_test",
generated_lib_header = "dot_product_8_debug_lib.h",
heir_opt_flags = ["--mlir-to-openfhe-bgv=entry-function=dot_product ciphertext-degree=8 insert-debug-handler-calls=true"],
heir_opt_flags = [
"--mlir-to-bgv=ciphertext-degree=8",
"--scheme-to-openfhe=entry-function=dot_product insert-debug-handler-calls=true",
],
mlir_src = "dot_product_8.mlir",
tags = ["notap"],
test_src = "dot_product_8_debug_test.cpp",
Expand All @@ -52,7 +61,10 @@ openfhe_end_to_end_test(
openfhe_end_to_end_test(
name = "box_blur_64x64_test",
generated_lib_header = "box_blur_64x64_lib.h",
heir_opt_flags = ["--mlir-to-openfhe-bgv=entry-function=box_blur ciphertext-degree=4096"],
heir_opt_flags = [
"--mlir-to-bgv=ciphertext-degree=4096",
"--scheme-to-openfhe=entry-function=box_blur",
],
mlir_src = "box_blur_64x64.mlir",
tags = ["notap"],
test_src = "box_blur_test.cpp",
Expand All @@ -61,7 +73,10 @@ openfhe_end_to_end_test(
openfhe_end_to_end_test(
name = "roberts_cross_64x64_test",
generated_lib_header = "roberts_cross_64x64_lib.h",
heir_opt_flags = ["--mlir-to-openfhe-bgv=entry-function=roberts_cross ciphertext-degree=4096"],
heir_opt_flags = [
"--mlir-to-bgv=ciphertext-degree=4096",
"--scheme-to-openfhe=entry-function=roberts_cross",
],
mlir_src = "roberts_cross_64x64.mlir",
tags = ["notap"],
test_src = "roberts_cross_test.cpp",
Expand All @@ -72,7 +87,10 @@ openfhe_end_to_end_test(
openfhe_end_to_end_test(
name = "dot_product_8f_test",
generated_lib_header = "dot_product_8f_lib.h",
heir_opt_flags = ["--mlir-to-openfhe-ckks=entry-function=dot_product ciphertext-degree=8"],
heir_opt_flags = [
"--mlir-to-ckks=ciphertext-degree=8",
"--scheme-to-openfhe=entry-function=dot_product",
],
heir_translate_flags = [],
mlir_src = "dot_product_8f.mlir",
tags = ["notap"],
Expand Down Expand Up @@ -102,7 +120,10 @@ openfhe_end_to_end_test(
openfhe_end_to_end_test(
name = "halevi_shoup_matmul_test",
generated_lib_header = "halevi_shoup_matmul_lib.h",
heir_opt_flags = ["--mlir-to-openfhe-ckks=entry-function=matmul ciphertext-degree=16"],
heir_opt_flags = [
"--mlir-to-ckks=ciphertext-degree=16",
"--scheme-to-openfhe=entry-function=matmul",
],
heir_translate_flags = [
"--openfhe-include-type=source-relative",
],
Expand Down
4 changes: 2 additions & 2 deletions tests/Examples/openfhe/errors/dot_product_8f_type_error.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: heir-opt --verify-diagnostics --mlir-to-openfhe-bgv='entry-function=dot_product ciphertext-degree=8' %s
// RUN: heir-opt --verify-diagnostics --mlir-to-bgv='ciphertext-degree=8' --scheme-to-openfhe='entry-function=dot_product' %s

// expected-error@below {{Floating point types are not supported in BGV. Maybe you meant to use a CKKS pipeline like --mlir-to-openfhe-ckks?}}
// expected-error@below {{Floating point types are not supported in BGV. Maybe you meant to use a CKKS pipeline like --mlir-to-ckks?}}
func.func @dot_product(%arg0: tensor<8xf16> {secret.secret}, %arg1: tensor<8xf16> {secret.secret}) -> f16 {
%c0 = arith.constant 0 : index
%c0_sf16 = arith.constant 0.1 : f16
Expand Down
2 changes: 1 addition & 1 deletion tests/Examples/openfhe/naive_matmul.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// The end to end mlir-to-openfhe-ckks pipeline will automatically rewrite a
// The --mlir-to-ckks pipeline will automatically rewrite a
// matrix multiplication into a Halevi-Shoup diagonalized matrix multiplication.
// This example is preserved to show a speedup of that pass from a naive
// implementation. We start at the CKKS dialect to avoid the automatic
Expand Down
2 changes: 1 addition & 1 deletion tests/Examples/openfhe/simple_sum.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// in the BUILD file for this directory, and the openfhe_end_to_end_test macro
// in test.bzl
//
// heir-opt --mlir-to-openfhe-bgv='entry-function=simple_sum ciphertext-degree=32' %s | bazel-bin/tools/heir-translate --emit-openfhe-pke
// heir-opt --mlir-to-bgv='ciphertext-degree=32' --scheme-to-openfhe='entry-function=simple_sum' %s | bazel-bin/tools/heir-translate --emit-openfhe-pke

func.func @simple_sum(%arg0: tensor<32xi16> {secret.secret}) -> i16 {
%c0 = arith.constant 0 : index
Expand Down
2 changes: 1 addition & 1 deletion tests/Transforms/mlir_to_openfhe_bgv/simple_sum.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --mlir-to-openfhe-bgv='entry-function=simple_sum ciphertext-degree=32' %s | FileCheck %s
// RUN: heir-opt --mlir-to-bgv='ciphertext-degree=32' --scheme-to-openfhe='entry-function=simple_sum' %s | FileCheck %s

// CHECK-LABEL: @simple_sum
// CHECK: openfhe
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-ckks=bootstrap-waterline=3 --mlir-to-openfhe-ckks %s | FileCheck %s
// RUN: heir-opt --mlir-to-secret-arithmetic --secret-insert-mgmt-ckks=bootstrap-waterline=3 --mlir-to-ckks --scheme-to-openfhe %s | FileCheck %s

// CHECK: func.func @bootstrap_waterline
// CHECK: openfhe.bootstrap
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --mlir-to-openfhe-ckks='entry-function=dot_product ciphertext-degree=8' %s | FileCheck %s
// RUN: heir-opt --mlir-to-ckks='ciphertext-degree=8' --scheme-to-openfhe='entry-function=dot_product' %s | FileCheck %s

// CHECK-LABEL: @dot_product
// CHECK-COUNT-3: openfhe.rot
Expand Down
2 changes: 1 addition & 1 deletion tests/Transforms/mlir_to_openfhe_ckks/matmul.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --mlir-print-local-scope --affine-loop-normalize='promote-single-iter=1' --mlir-to-openfhe-ckks %s | FileCheck %s
// RUN: heir-opt --mlir-print-local-scope --affine-loop-normalize='promote-single-iter=1' --mlir-to-ckks --scheme-to-openfhe %s | FileCheck %s

// This pipeline fully loop unrolls the matmul.

Expand Down
2 changes: 1 addition & 1 deletion tests/Transforms/mlir_to_openfhe_ckks/naive_matmul.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --mlir-to-openfhe-ckks='ciphertext-degree=16 entry-function=matmul' %s | heir-translate --emit-openfhe-pke | FileCheck %s
// RUN: heir-opt --mlir-to-ckks='ciphertext-degree=16' --scheme-to-openfhe='entry-function=matmul' %s | heir-translate --emit-openfhe-pke | FileCheck %s

// CHECK-LABEL: std::vector<CiphertextT> matmul(
// CHECK-SAME: CryptoContextT [[v0:[^,]*]],
Expand Down
2 changes: 1 addition & 1 deletion tests/Transforms/mlir_to_openfhe_ckks/simple_sum.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --mlir-to-openfhe-ckks='entry-function=simple_sum ciphertext-degree=32' %s | FileCheck %s
// RUN: heir-opt --mlir-to-ckks='ciphertext-degree=32' --scheme-to-openfhe='entry-function=simple_sum' %s | FileCheck %s

// CHECK-LABEL: @simple_sum
// CHECK-COUNT-6: openfhe.rot
Expand Down
22 changes: 6 additions & 16 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,15 @@ int main(int argc, char **argv) {
"BGV.",
mlirToRLWEPipelineBuilder(mlir::heir::RLWEScheme::bgvScheme));

PassPipelineRegistration<mlir::heir::MlirToRLWEPipelineOptions>(
"mlir-to-openfhe-bgv",
"Convert a func using standard MLIR dialects to FHE using BGV and "
"export "
"to OpenFHE C++ code.",
mlirToOpenFheRLWEPipelineBuilder(mlir::heir::RLWEScheme::bgvScheme));
PassPipelineRegistration<mlir::heir::OpenfheOptions>(
"scheme-to-openfhe",
"Convert code expressed at FHE scheme level to OpenFHE C++ code.",
toOpenFhePipelineBuilder());

PassPipelineRegistration<mlir::heir::MlirToRLWEPipelineOptions>(
PassPipelineRegistration<mlir::heir::LattigoOptions>(
"mlir-to-lattigo-bgv",
"Convert a func using standard MLIR dialects to FHE using BGV and "
"export "
"to Lattigo GO code.",
"export to Lattigo GO code.",
mlirToLattigoRLWEPipelineBuilder(mlir::heir::RLWEScheme::bgvScheme));

PassPipelineRegistration<mlir::heir::MlirToRLWEPipelineOptions>(
Expand All @@ -386,13 +383,6 @@ int main(int argc, char **argv) {
"CKKS.",
mlirToRLWEPipelineBuilder(mlir::heir::RLWEScheme::ckksScheme));

PassPipelineRegistration<mlir::heir::MlirToRLWEPipelineOptions>(
"mlir-to-openfhe-ckks",
"Convert a func using standard MLIR dialects to FHE using CKKS and "
"export "
"to OpenFHE C++ code.",
mlirToOpenFheRLWEPipelineBuilder(mlir::heir::RLWEScheme::ckksScheme));

PassPipelineRegistration<>(
"convert-to-data-oblivious",
"Transforms a native program to data-oblivious program",
Expand Down

0 comments on commit b64a2c3

Please sign in to comment.