From 10c4eb896547edbe779be23b4be2b8223ff1ba64 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 21 Nov 2023 11:12:00 -0800 Subject: [PATCH 01/20] add NoisePropagationInterface --- include/Interfaces/BUILD | 40 ++++++++++++++++++++++++++ include/Interfaces/NoiseInterfaces.h | 17 +++++++++++ include/Interfaces/NoiseInterfaces.td | 41 +++++++++++++++++++++++++++ lib/Interfaces/BUILD | 15 ++++++++++ lib/Interfaces/NoiseInterfaces.cpp | 9 ++++++ 5 files changed, 122 insertions(+) create mode 100644 include/Interfaces/BUILD create mode 100644 include/Interfaces/NoiseInterfaces.h create mode 100644 include/Interfaces/NoiseInterfaces.td create mode 100644 lib/Interfaces/BUILD create mode 100644 lib/Interfaces/NoiseInterfaces.cpp diff --git a/include/Interfaces/BUILD b/include/Interfaces/BUILD new file mode 100644 index 000000000..398b70761 --- /dev/null +++ b/include/Interfaces/BUILD @@ -0,0 +1,40 @@ +# HEIR project-wide interfaces +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +td_library( + name = "NoiseInterfacesTdFiles", + srcs = ["NoiseInterfaces.td"], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "NoiseInterfacesIncGen", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "NoiseInterfaces.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "NoiseInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "NoiseInterfaces.td", + deps = [ + ":NoiseInterfacesTdFiles", + ], +) + +exports_files( + [ + "NoiseInterfaces.h", + ], +) diff --git a/include/Interfaces/NoiseInterfaces.h b/include/Interfaces/NoiseInterfaces.h new file mode 100644 index 000000000..91978f880 --- /dev/null +++ b/include/Interfaces/NoiseInterfaces.h @@ -0,0 +1,17 @@ +#ifndef INCLUDE_INTERFACES_NOISEINTERFACES_H_ +#define INCLUDE_INTERFACES_NOISEINTERFACES_H_ + +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project + +namespace mlir { +namespace heir { + +using SetNoiseFn = function_ref; + +} +} // namespace mlir + +#include "include/Interfaces/NoiseInterfaces.h.inc" + +#endif // INCLUDE_INTERFACES_NOISEINTERFACES_H_ diff --git a/include/Interfaces/NoiseInterfaces.td b/include/Interfaces/NoiseInterfaces.td new file mode 100644 index 000000000..ac864a85b --- /dev/null +++ b/include/Interfaces/NoiseInterfaces.td @@ -0,0 +1,41 @@ +#ifndef INCLUDE_INTERFACES_NOISEINTERFACES_TD_ +#define INCLUDE_INTERFACES_NOISEINTERFACES_TD_ + +include "mlir/IR/OpBase.td" + +def NoisePropagationInterface : OpInterface<"NoisePropagationInterface"> { + let description = [{ + Declares that an operation produces results with noise, and provides an + interface for passes to compute bounds on the noise in the results + from the input noises. + + Here "noise" is defined as the (perhaps upper-bounded) variance of a + Gaussian distribution centered at zero. + }]; + let cppNamespace = "::mlir::heir"; + + let methods = [ + InterfaceMethod<[{ + Infer the noise distribution of the result of this op given the distributions + of its inputs. + + All noise distributions are assumed to be Gaussian centered at zero, and + so the inputs and results are represented by their variances. + + For each result value or block argument (that isn't a branch argument, + since the dataflow analysis handles those case), the method should call + `setValueNoise` with that `Value` as an argument. When `setValueNoise` + is not called for some value, the analysis will raise an error. + + `argNoises` contains one `int64_t` for each operand to the op in ODS + order. Operands that don't have a prior noise associated with them + will have this value set to zero. + }], + "void", "inferResultNoise", (ins + "::llvm::ArrayRef":$argNoises, + "::mlir::heir::SetNoiseFn":$setValueNoise) + >]; +} + + +#endif // INCLUDE_INTERFACES_NOISEINTERFACES_TD_ diff --git a/lib/Interfaces/BUILD b/lib/Interfaces/BUILD new file mode 100644 index 000000000..a49a973dc --- /dev/null +++ b/lib/Interfaces/BUILD @@ -0,0 +1,15 @@ +# HEIR project-wide interface implementations +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "NoiseInterfaces", + srcs = ["NoiseInterfaces.cpp"], + hdrs = ["@heir//include/Interfaces:NoiseInterfaces.h"], + deps = [ + "@heir//include/Interfaces:NoiseInterfacesIncGen", + "@llvm-project//mlir:IR", + ], +) diff --git a/lib/Interfaces/NoiseInterfaces.cpp b/lib/Interfaces/NoiseInterfaces.cpp new file mode 100644 index 000000000..c7d20fffb --- /dev/null +++ b/lib/Interfaces/NoiseInterfaces.cpp @@ -0,0 +1,9 @@ +#ifndef LIB_INTERFACES_NOISEINTERFACES_CPP_ +#define LIB_INTERFACES_NOISEINTERFACES_CPP_ + +#include "include/Interfaces/NoiseInterfaces.h" + +// Import last +#include "include/Interfaces/NoiseInterfaces.cpp.inc" + +#endif // LIB_INTERFACES_NOISEINTERFACES_CPP_ From 38200e5d749f8158c6ca5a902b02621c49c7b1bc Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 21 Nov 2023 11:40:20 -0800 Subject: [PATCH 02/20] add noise propagation interface to CGGI --- include/Dialect/CGGI/IR/BUILD | 1 + include/Dialect/CGGI/IR/CGGIOps.h | 1 + include/Dialect/CGGI/IR/CGGIOps.td | 4 +++- lib/Dialect/CGGI/IR/BUILD | 3 +++ lib/Dialect/CGGI/IR/CGGIOps.cpp | 6 ++++++ 5 files changed, 14 insertions(+), 1 deletion(-) create mode 100644 lib/Dialect/CGGI/IR/CGGIOps.cpp diff --git a/include/Dialect/CGGI/IR/BUILD b/include/Dialect/CGGI/IR/BUILD index 247639231..a06d7c9b9 100644 --- a/include/Dialect/CGGI/IR/BUILD +++ b/include/Dialect/CGGI/IR/BUILD @@ -23,6 +23,7 @@ td_library( # include from the heir-root to enable fully-qualified include-paths includes = ["../../../.."], deps = [ + "@heir//include/Interfaces:NoiseInterfacesTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", diff --git a/include/Dialect/CGGI/IR/CGGIOps.h b/include/Dialect/CGGI/IR/CGGIOps.h index b0e7a842d..892bb54b4 100644 --- a/include/Dialect/CGGI/IR/CGGIOps.h +++ b/include/Dialect/CGGI/IR/CGGIOps.h @@ -3,6 +3,7 @@ #include "include/Dialect/CGGI/IR/CGGIDialect.h" #include "include/Dialect/LWE/IR/LWETypes.h" +#include "include/Interfaces/NoiseInterfaces.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project diff --git a/include/Dialect/CGGI/IR/CGGIOps.td b/include/Dialect/CGGI/IR/CGGIOps.td index 66e374fda..bf7a4d537 100644 --- a/include/Dialect/CGGI/IR/CGGIOps.td +++ b/include/Dialect/CGGI/IR/CGGIOps.td @@ -5,6 +5,7 @@ include "include/Dialect/CGGI/IR/CGGIDialect.td" include "include/Dialect/Polynomial/IR/PolynomialAttributes.td" include "include/Dialect/LWE/IR/LWETypes.td" +include "include/Interfaces/NoiseInterfaces.td" include "mlir/IR/OpBase.td" include "mlir/IR/BuiltinAttributes.td" @@ -12,7 +13,8 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" class CGGI_Op traits = []> : - Op { + Op]> { let assemblyFormat = [{ `(` operands `)` attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results)) }]; diff --git a/lib/Dialect/CGGI/IR/BUILD b/lib/Dialect/CGGI/IR/BUILD index 84f110290..37b336bac 100644 --- a/lib/Dialect/CGGI/IR/BUILD +++ b/lib/Dialect/CGGI/IR/BUILD @@ -7,6 +7,7 @@ cc_library( name = "Dialect", srcs = [ "CGGIDialect.cpp", + "CGGIOps.cpp", ], hdrs = [ "@heir//include/Dialect/CGGI/IR:CGGIAttributes.h", @@ -18,6 +19,8 @@ cc_library( "@heir//include/Dialect/CGGI/IR:dialect_inc_gen", "@heir//include/Dialect/CGGI/IR:ops_inc_gen", "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", + "@heir//lib/Interfaces:NoiseInterfaces", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", diff --git a/lib/Dialect/CGGI/IR/CGGIOps.cpp b/lib/Dialect/CGGI/IR/CGGIOps.cpp new file mode 100644 index 000000000..ca59219e4 --- /dev/null +++ b/lib/Dialect/CGGI/IR/CGGIOps.cpp @@ -0,0 +1,6 @@ +#ifndef LIB_DIALECT_CGGI_IR_CGGIOPS_CPP_ +#define LIB_DIALECT_CGGI_IR_CGGIOPS_CPP_ + +#include "include/Dialect/CGGI/IR/CGGIOps.h" + +#endif // LIB_DIALECT_CGGI_IR_CGGIOPS_CPP_ From c254e45d2a38b42f3e680173947b85d6ae3f7dc3 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 21 Nov 2023 15:40:16 -0800 Subject: [PATCH 03/20] add first-pass implementation of post-bootstrap noise model --- lib/Dialect/CGGI/IR/CGGIOps.cpp | 92 +++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/lib/Dialect/CGGI/IR/CGGIOps.cpp b/lib/Dialect/CGGI/IR/CGGIOps.cpp index ca59219e4..6ef7b8dda 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.cpp +++ b/lib/Dialect/CGGI/IR/CGGIOps.cpp @@ -3,4 +3,96 @@ #include "include/Dialect/CGGI/IR/CGGIOps.h" +#include "include/Dialect/CGGI/IR/CGGIAttributes.h" +#include "include/Dialect/LWE/IR/LWEAttributes.h" +#include "include/Interfaces/NoiseInterfaces.h" + +namespace mlir { +namespace heir { +namespace cggi { + +unsigned maxPerDigitDecompositionError(unsigned baseLog, unsigned numLevels, + unsigned ctBitWidth) { + // FIXME: this needs verification; I struggled to parse what was said in the + // CGGI paper, as well as the original DM paper, so I relied on my own + // analysis in https://jeremykun.com/2022/08/29/key-switching-in-lwe/ + // It aligns roughly with the error analysis in Theorem 4.1 of + // https://eprint.iacr.org/2018/421, but using a different perspective + // on the "precision" parameter t in that paper. + + // maxLevels is the number L such that B^L = lwe_cmod + // a.k.a., L * log2(B) = cmod_bitwidth + // This should be an exact division, since the LWE cmod is always supposed to + // be a power of two. + unsigned maxLevels = ctBitWidth / baseLog; + unsigned lowestLevel = maxLevels - numLevels; + // Regardless of whether the approximation is signed or not, the max error you + // can get per digit is to be off by B-1. + unsigned approximationPerDigitMaxError = (1 << baseLog) - 1; + return (unsigned)pow(approximationPerDigitMaxError, lowestLevel - 1); +} + +/// This function represents one noise model for the output of bootstrap in the +/// simplest CGGI implementation. It is an upper bound estimate of the variance +/// of a ciphertext post bootstrap (including the key switch op). Follows the +/// formula in https://eprint.iacr.org/2018/421, Theorem 6.3. +/// +/// Notes: +/// +/// In the paper, the key-switching key gadget is binary. Here it has an +/// arbitrary base and number of levels. +/// +/// Signed decompositions are used for the gadgets, leading to a multiplicative +/// factor of two difference between the quality "beta" (max digit size) of the +/// gadget and the chosen parameter 2**base_log. +int64_t bootstrapOutputNoise(CGGIParamsAttr attr) { + lwe::LWEParamsAttr lweParams = attr.getLweParams(); + lwe::RLWEParamsAttr rlweParams = attr.getRlweParams(); + unsigned bskNoiseVariance = attr.getBskNoiseVariance(); + unsigned kskNoiseVariance = attr.getKskNoiseVariance(); + + // Mirroring the notation in https://eprint.iacr.org/2018/421, Theorem 6.3. + unsigned logq = lweParams.getCmod().getValue().getBitWidth(); + unsigned n = lweParams.getDimension(); + unsigned k = rlweParams.getDimension(); + unsigned N = rlweParams.getPolyDegree(); + unsigned l = attr.getBskGadgetNumLevels(); + // Beta is the max absolute value of a digit of the signed decomposition + unsigned beta = (1 << attr.getBskGadgetBaseLog()) / 2; + + // Epsilon is the max per-digit error of the approximation introduced by + // having fewer levels in the gadget key. + // FIXME: this needs verification. I think it's the same sort of error as the + // key switching key sampleApproxError below. + unsigned epsilon = maxPerDigitDecompositionError( + attr.getBskGadgetBaseLog(), attr.getBskGadgetNumLevels(), logq); + unsigned externalProductTerm = + (n * (k + 1) * l * N * beta * beta * bskNoiseVariance + + n * (1 + k * N) * epsilon * epsilon); + + // largestDigit depends on a signed decomposition. + unsigned largestDigit = (1 << attr.getKskGadgetBaseLog()) / 2; + unsigned kskSampleApproxError = maxPerDigitDecompositionError( + attr.getKskGadgetBaseLog(), attr.getKskGadgetNumLevels(), logq); + unsigned keySwitchingTerm = + (attr.getKskGadgetNumLevels() * largestDigit * kskNoiseVariance + + n * kskSampleApproxError); + return externalProductTerm + keySwitchingTerm; +} + +void AndOp::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) { + // Have to get a CGGIParams instance here somehow. + // auto resultNoise = bootstrapOutputNoise(...); +} + +void OrOp::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} +void XorOp::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} +void NotOp::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} +void Lut3Op::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} +void Lut2Op::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} + +} // namespace cggi +} // namespace heir +} // namespace mlir + #endif // LIB_DIALECT_CGGI_IR_CGGIOPS_CPP_ From 2793bf480fea17fc8f966e4bccef540a00467011 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 21 Nov 2023 16:06:55 -0800 Subject: [PATCH 04/20] implement noise propagation as if CGGIParams was an available attribute on every op --- lib/Dialect/CGGI/IR/CGGIOps.cpp | 59 +++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/CGGI/IR/CGGIOps.cpp b/lib/Dialect/CGGI/IR/CGGIOps.cpp index 6ef7b8dda..03835eb81 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.cpp +++ b/lib/Dialect/CGGI/IR/CGGIOps.cpp @@ -45,8 +45,8 @@ unsigned maxPerDigitDecompositionError(unsigned baseLog, unsigned numLevels, /// Signed decompositions are used for the gadgets, leading to a multiplicative /// factor of two difference between the quality "beta" (max digit size) of the /// gadget and the chosen parameter 2**base_log. -int64_t bootstrapOutputNoise(CGGIParamsAttr attr) { - lwe::LWEParamsAttr lweParams = attr.getLweParams(); +int64_t bootstrapOutputNoise(CGGIParamsAttr attr, + lwe::LWEParamsAttr lweParams) { lwe::RLWEParamsAttr rlweParams = attr.getRlweParams(); unsigned bskNoiseVariance = attr.getBskNoiseVariance(); unsigned kskNoiseVariance = attr.getKskNoiseVariance(); @@ -80,16 +80,55 @@ int64_t bootstrapOutputNoise(CGGIParamsAttr attr) { return externalProductTerm + keySwitchingTerm; } -void AndOp::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) { - // Have to get a CGGIParams instance here somehow. - // auto resultNoise = bootstrapOutputNoise(...); +void handleSingleResultOp(Operation *op, Value ctValue, + SetNoiseFn setValueNoise) { + auto lweParams = + cast(ctValue.getType()).getLweParams(); + if (!lweParams) { + op->emitOpError() << "lwe_params must be set on the input values to run " + "noise propagation."; + return; + } + + auto attrs = op->getAttrDictionary(); + if (!attrs.contains("cggi_params")) { + op->emitOpError() << "cggi_params must be set to run noise propagation."; + return; + } + auto cggiParams = llvm::cast(attrs.get("cggi_params")); + setValueNoise(op->getResult(0), bootstrapOutputNoise(cggiParams, lweParams)); +} + +void AndOp::inferResultNoise(llvm::ArrayRef argNoises, + SetNoiseFn setValueNoise) { + return handleSingleResultOp(getOperation(), getLhs(), setValueNoise); +} + +void OrOp::inferResultNoise(llvm::ArrayRef argNoises, + SetNoiseFn setValueNoise) { + return handleSingleResultOp(getOperation(), getLhs(), setValueNoise); +} + +void XorOp::inferResultNoise(llvm::ArrayRef argNoises, + SetNoiseFn setValueNoise) { + return handleSingleResultOp(getOperation(), getLhs(), setValueNoise); +} + +void Lut3Op::inferResultNoise(llvm::ArrayRef argNoises, + SetNoiseFn setValueNoise) { + return handleSingleResultOp(getOperation(), getA(), setValueNoise); } -void OrOp::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} -void XorOp::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} -void NotOp::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} -void Lut3Op::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} -void Lut2Op::inferResultNoise(llvm::ArrayRef, SetNoiseFn setValueNoise) {} +void Lut2Op::inferResultNoise(llvm::ArrayRef argNoises, + SetNoiseFn setValueNoise) { + return handleSingleResultOp(getOperation(), getA(), setValueNoise); +} + +void NotOp::inferResultNoise(llvm::ArrayRef argNoises, + SetNoiseFn setValueNoise) { + // This one doesn't use bootstrap, no error change + setValueNoise(getInput(), argNoises[0]); +} } // namespace cggi } // namespace heir From e6b0f74a2036bc41a4756c3cc2226ab634515824 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 22 Nov 2023 12:00:37 -0800 Subject: [PATCH 05/20] add lwe.add op with noise inference --- include/Dialect/LWE/IR/BUILD | 1 + include/Dialect/LWE/IR/LWEDialect.h | 2 ++ include/Dialect/LWE/IR/LWEOps.td | 15 ++++++++++++++- lib/Dialect/LWE/IR/BUILD | 2 ++ lib/Dialect/LWE/IR/LWEDialect.cpp | 6 ++++++ 5 files changed, 25 insertions(+), 1 deletion(-) diff --git a/include/Dialect/LWE/IR/BUILD b/include/Dialect/LWE/IR/BUILD index 82be4e027..8f2794b02 100644 --- a/include/Dialect/LWE/IR/BUILD +++ b/include/Dialect/LWE/IR/BUILD @@ -27,6 +27,7 @@ td_library( # include from the heir-root to enable fully-qualified include-paths includes = ["../../../.."], deps = [ + "@heir//include/Interfaces:NoiseInterfacesTdFiles", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", diff --git a/include/Dialect/LWE/IR/LWEDialect.h b/include/Dialect/LWE/IR/LWEDialect.h index bf202f20c..01d2f9bb0 100644 --- a/include/Dialect/LWE/IR/LWEDialect.h +++ b/include/Dialect/LWE/IR/LWEDialect.h @@ -1,8 +1,10 @@ #ifndef HEIR_INCLUDE_DIALECT_LWE_IR_LWEDIALECT_H_ #define HEIR_INCLUDE_DIALECT_LWE_IR_LWEDIALECT_H_ +#include "include/Interfaces/NoiseInterfaces.h" #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project // Generated headers (block clang-format from messing up order) #include "include/Dialect/LWE/IR/LWEDialect.h.inc" diff --git a/include/Dialect/LWE/IR/LWEOps.td b/include/Dialect/LWE/IR/LWEOps.td index aa8018981..d8e236e6b 100644 --- a/include/Dialect/LWE/IR/LWEOps.td +++ b/include/Dialect/LWE/IR/LWEOps.td @@ -3,10 +3,11 @@ include "include/Dialect/LWE/IR/LWEDialect.td" include "include/Dialect/LWE/IR/LWETypes.td" +include "include/Interfaces/NoiseInterfaces.td" include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/CommonAttrConstraints.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/CommonAttrConstraints.td" class LWE_Op traits = []> : Op { @@ -59,4 +60,16 @@ def LWE_TrivialEncryptOp: LWE_Op<"trivial_encrypt", [Pure]> { let hasVerifier = 1; } +def LWE_AddOp : LWE_Op<"add", [ + Pure, + SameOperandsAndResultType, + DeclareOpInterfaceMethods +]> { + let arguments = (ins + LWECiphertext:$lhs, + LWECiphertext:$rhs + ); + let results = (outs LWECiphertext:$output); +} + #endif // HEIR_INCLUDE_DIALECT_LWE_IR_LWEOPS_TD_ diff --git a/lib/Dialect/LWE/IR/BUILD b/lib/Dialect/LWE/IR/BUILD index 03b2b9617..3add90f59 100644 --- a/lib/Dialect/LWE/IR/BUILD +++ b/lib/Dialect/LWE/IR/BUILD @@ -20,6 +20,8 @@ cc_library( "@heir//include/Dialect/LWE/IR:ops_inc_gen", "@heir//include/Dialect/LWE/IR:types_inc_gen", "@heir//lib/Dialect/Polynomial/IR:Dialect", + "@heir//lib/Dialect/Polynomial/IR:Polynomial", + "@heir//lib/Interfaces:NoiseInterfaces", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/lib/Dialect/LWE/IR/LWEDialect.cpp b/lib/Dialect/LWE/IR/LWEDialect.cpp index 416130a67..b2a11cb3d 100644 --- a/lib/Dialect/LWE/IR/LWEDialect.cpp +++ b/lib/Dialect/LWE/IR/LWEDialect.cpp @@ -6,6 +6,7 @@ #include "include/Dialect/LWE/IR/LWEOps.h" #include "include/Dialect/LWE/IR/LWETypes.h" #include "include/Dialect/Polynomial/IR/PolynomialTypes.h" +#include "include/Interfaces/NoiseInterfaces.h" #include "llvm/include/llvm/ADT/STLFunctionalExtras.h" // from @llvm-project #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project @@ -180,6 +181,11 @@ LogicalResult TrivialEncryptOp::verify() { return success(); } +void AddOp::inferResultNoise(llvm::ArrayRef argNoises, + SetNoiseFn setValueNoise) { + return setValueNoise(getResult(), argNoises[0] + argNoises[1]); +} + } // namespace lwe } // namespace heir } // namespace mlir From 27ac979bf2805d6e42b619db91142bb326e6d627 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 22 Nov 2023 14:47:22 -0800 Subject: [PATCH 06/20] add NoisePropagationAnalysis and Variance --- include/Analysis/NoisePropagation/BUILD | 12 ++++ .../NoisePropagationAnalysis.h | 50 +++++++++++++++ include/Analysis/NoisePropagation/Variance.h | 50 +++++++++++++++ include/Interfaces/NoiseInterfaces.h | 8 ++- include/Interfaces/NoiseInterfaces.td | 2 +- lib/Analysis/NoisePropagation/BUILD | 27 ++++++++ .../NoisePropagationAnalysis.cpp | 63 +++++++++++++++++++ lib/Analysis/NoisePropagation/Variance.cpp | 14 +++++ lib/Interfaces/BUILD | 1 + 9 files changed, 223 insertions(+), 4 deletions(-) create mode 100644 include/Analysis/NoisePropagation/BUILD create mode 100644 include/Analysis/NoisePropagation/NoisePropagationAnalysis.h create mode 100644 include/Analysis/NoisePropagation/Variance.h create mode 100644 lib/Analysis/NoisePropagation/BUILD create mode 100644 lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp create mode 100644 lib/Analysis/NoisePropagation/Variance.cpp diff --git a/include/Analysis/NoisePropagation/BUILD b/include/Analysis/NoisePropagation/BUILD new file mode 100644 index 000000000..d9f8e8966 --- /dev/null +++ b/include/Analysis/NoisePropagation/BUILD @@ -0,0 +1,12 @@ +# NoisePropagationAnalysis analysis pass +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + [ + "NoisePropagationAnalysis.h", + "Variance.h", + ], +) diff --git a/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h b/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h new file mode 100644 index 000000000..05e0838d6 --- /dev/null +++ b/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h @@ -0,0 +1,50 @@ +#ifndef INCLUDE_ANALYSIS_NOISEPROPAGATION_NOISEPROPAGATIONANALYSIS_H_ +#define INCLUDE_ANALYSIS_NOISEPROPAGATION_NOISEPROPAGATIONANALYSIS_H_ + +#include "include/Analysis/NoisePropagation/Variance.h" +#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project + +namespace mlir { +namespace heir { + +/// This lattice element represents the noise distribution of an SSA value. +class VarianceLattice : public dataflow::Lattice { + public: + using Lattice::Lattice; +}; + +/// Noise propagation analysis determines a noise bound for SSA values, +/// represented by the variance of a symmetric Gaussian distribution. This +/// analysis propagates noise across operations that implement +/// `NoisePropagationInterface`, but does not support propagation for SSA +/// values that represent loop bounds or induction variables. It can be viewed +/// as a simplified port of IntegerRangeAnalysis. +class NoisePropagationAnalysis + : public dataflow::SparseForwardDataFlowAnalysis { + public: + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; + + void setToEntryState(VarianceLattice *lattice) override { + // At an entry point, we have no information about the noise. + propagateIfChanged(lattice, lattice->join(Variance(std::nullopt))); + } + + void visitOperation(Operation *op, ArrayRef operands, + ArrayRef results) override; + + /// Visit block arguments or operation results of an operation with region + /// control-flow for which values are not defined by region control-flow. This + /// function calls `InferIntRangeInterface` to provide values for block + /// arguments. + void visitNonControlFlowArguments(Operation *op, + const RegionSuccessor &successor, + ArrayRef argLattices, + unsigned firstIndex) override; +}; + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_ANALYSIS_NOISEPROPAGATION_NOISEPROPAGATIONANALYSIS_H_ diff --git a/include/Analysis/NoisePropagation/Variance.h b/include/Analysis/NoisePropagation/Variance.h new file mode 100644 index 000000000..304c9dfa8 --- /dev/null +++ b/include/Analysis/NoisePropagation/Variance.h @@ -0,0 +1,50 @@ +#ifndef INCLUDE_ANALYSIS_NOISEPROPAGATION_VARIANCE_H_ +#define INCLUDE_ANALYSIS_NOISEPROPAGATION_VARIANCE_H_ + +#include +#include +#include +#include + +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project + +namespace mlir { +namespace heir { + +/// A class representing an optional variance of a noise distribution. +class Variance { + public: + /// Create an integer value range lattice value. + Variance(std::optional value = std::nullopt) : value(value) {} + + bool isKnown() const { return value.has_value(); } + + const int64_t &getValue() const { + assert(isKnown()); + return *value; + } + + bool operator==(const Variance &rhs) const { return value == rhs.value; } + + /// This method represents how to choose a noise from one of two possible + /// branches, when either could be possible. In the case of FHE, we must + /// assume the worse case, so take the max. + static Variance join(const Variance &lhs, const Variance &rhs) { + if (!lhs.isKnown()) return rhs; + if (!rhs.isKnown()) return lhs; + return Variance{std::max(lhs.getValue(), rhs.getValue())}; + } + + void print(llvm::raw_ostream &os) const { os << value; } + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const Variance &variance); + + private: + std::optional value; +}; + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_ANALYSIS_NOISEPROPAGATION_VARIANCE_H_ diff --git a/include/Interfaces/NoiseInterfaces.h b/include/Interfaces/NoiseInterfaces.h index 91978f880..d7fed7bf6 100644 --- a/include/Interfaces/NoiseInterfaces.h +++ b/include/Interfaces/NoiseInterfaces.h @@ -1,15 +1,17 @@ #ifndef INCLUDE_INTERFACES_NOISEINTERFACES_H_ #define INCLUDE_INTERFACES_NOISEINTERFACES_H_ -#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project +#include "include/Analysis/NoisePropagation/Variance.h" +#include "mlir/include/mlir/IR/OpDefinition.h" // trom @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project namespace mlir { namespace heir { -using SetNoiseFn = function_ref; +// Variance is a type defined by NoisePropagationAnalysis +using SetNoiseFn = function_ref; -} +} // namespace heir } // namespace mlir #include "include/Interfaces/NoiseInterfaces.h.inc" diff --git a/include/Interfaces/NoiseInterfaces.td b/include/Interfaces/NoiseInterfaces.td index ac864a85b..308135299 100644 --- a/include/Interfaces/NoiseInterfaces.td +++ b/include/Interfaces/NoiseInterfaces.td @@ -32,7 +32,7 @@ def NoisePropagationInterface : OpInterface<"NoisePropagationInterface"> { will have this value set to zero. }], "void", "inferResultNoise", (ins - "::llvm::ArrayRef":$argNoises, + "::llvm::ArrayRef":$argNoises, "::mlir::heir::SetNoiseFn":$setValueNoise) >]; } diff --git a/lib/Analysis/NoisePropagation/BUILD b/lib/Analysis/NoisePropagation/BUILD new file mode 100644 index 000000000..2b5cc6935 --- /dev/null +++ b/lib/Analysis/NoisePropagation/BUILD @@ -0,0 +1,27 @@ +# NoisePropagationAnalysis analysis pass +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "NoisePropagationAnalysis", + srcs = ["NoisePropagationAnalysis.cpp"], + hdrs = ["@heir//include/Analysis/NoisePropagation:NoisePropagationAnalysis.h"], + deps = [ + ":Variance", + "@heir//lib/Interfaces:NoiseInterfaces", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "Variance", + srcs = ["Variance.cpp"], + hdrs = ["@heir//include/Analysis/NoisePropagation:Variance.h"], + deps = [ + "@llvm-project//llvm:Support", + ], +) diff --git a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp new file mode 100644 index 000000000..6466973e6 --- /dev/null +++ b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp @@ -0,0 +1,63 @@ +#include "include/Analysis/NoisePropagation/NoisePropagationAnalysis.h" + +#include "include/Interfaces/NoiseInterfaces.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project + +#define DEBUG_TYPE "NoisePropagationAnalysis" + +namespace mlir { +namespace heir { + +void NoisePropagationAnalysis::visitOperation( + Operation *op, ArrayRef operands, + ArrayRef results) { + // If the lattice on any operand is unknown, bail out. + if (llvm::any_of(operands, [](const VarianceLattice *lattice) { + return !lattice->getValue().isKnown(); + })) { + return; + } + + auto noisePropagationOp = dyn_cast(op); + if (!noisePropagationOp) return setAllToEntryStates(results); + + LLVM_DEBUG(llvm::dbgs() << "Propagating noise for " << *op << "\n"); + SmallVector argRanges(llvm::map_range( + operands, [](const VarianceLattice *val) { return val->getValue(); })); + + auto joinCallback = [&](Value value, const Variance &variance) { + auto result = dyn_cast(value); + if (!result) return; + assert(llvm::is_contained(op->getResults(), result)); + + LLVM_DEBUG(llvm::dbgs() << "Inferred noise " << variance << "\n"); + VarianceLattice *lattice = results[result.getResultNumber()]; + Variance oldRange = lattice->getValue(); + ChangeResult changed = lattice->join(Variance{variance}); + + // If the result is yielded, then the best we can do is check to see + // if the op producing this value has deterministic noise. If so, + // we can propagate that noise. Otherwise, we must assume the worst + // case scenario of unknown noise. + bool isYieldedResult = llvm::any_of(value.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + // FIXME: add DeterministicNoise trait + if (isYieldedResult && oldRange.isKnown() && + !(lattice->getValue() == oldRange)) { + LLVM_DEBUG( + llvm::dbgs() + << "Non-deterministic noise-propagating op passed to a region " + "terminator. Assuming loop result and marking noise unknown\n"); + changed |= lattice->join(Variance(std::nullopt)); + } + propagateIfChanged(lattice, changed); + }; + + noisePropagationOp.inferResultNoise(argRanges, joinCallback); +} + +} // namespace heir +} // namespace mlir diff --git a/lib/Analysis/NoisePropagation/Variance.cpp b/lib/Analysis/NoisePropagation/Variance.cpp new file mode 100644 index 000000000..ee06a9264 --- /dev/null +++ b/lib/Analysis/NoisePropagation/Variance.cpp @@ -0,0 +1,14 @@ +#include "include/Analysis/NoisePropagation/Variance.h" + +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project + +namespace mlir { +namespace heir { + +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Variance &variance) { + if (!variance.isKnown()) return os << "unknown"; + return os << variance.getValue(); +} + +} // namespace heir +} // namespace mlir diff --git a/lib/Interfaces/BUILD b/lib/Interfaces/BUILD index a49a973dc..6c549bfd0 100644 --- a/lib/Interfaces/BUILD +++ b/lib/Interfaces/BUILD @@ -10,6 +10,7 @@ cc_library( hdrs = ["@heir//include/Interfaces:NoiseInterfaces.h"], deps = [ "@heir//include/Interfaces:NoiseInterfacesIncGen", + "@heir//lib/Analysis/NoisePropagation:Variance", "@llvm-project//mlir:IR", ], ) From d18bef8dd4351d7992cadd3d7067255f3180e25b Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 22 Nov 2023 14:52:10 -0800 Subject: [PATCH 07/20] migrate ops to Variance class --- include/Analysis/NoisePropagation/Variance.h | 2 ++ lib/Dialect/CGGI/IR/BUILD | 2 ++ lib/Dialect/CGGI/IR/CGGIOps.cpp | 16 +++++++++------- lib/Dialect/LWE/IR/BUILD | 1 + lib/Dialect/LWE/IR/LWEDialect.cpp | 11 +++++++++-- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/include/Analysis/NoisePropagation/Variance.h b/include/Analysis/NoisePropagation/Variance.h index 304c9dfa8..daec9ce2d 100644 --- a/include/Analysis/NoisePropagation/Variance.h +++ b/include/Analysis/NoisePropagation/Variance.h @@ -14,6 +14,8 @@ namespace heir { /// A class representing an optional variance of a noise distribution. class Variance { public: + static Variance unknown() { return Variance(); } + /// Create an integer value range lattice value. Variance(std::optional value = std::nullopt) : value(value) {} diff --git a/lib/Dialect/CGGI/IR/BUILD b/lib/Dialect/CGGI/IR/BUILD index 37b336bac..de47739e9 100644 --- a/lib/Dialect/CGGI/IR/BUILD +++ b/lib/Dialect/CGGI/IR/BUILD @@ -18,6 +18,8 @@ cc_library( "@heir//include/Dialect/CGGI/IR:attributes_inc_gen", "@heir//include/Dialect/CGGI/IR:dialect_inc_gen", "@heir//include/Dialect/CGGI/IR:ops_inc_gen", + "@heir//include/Dialect/Polynomial/IR:attributes_inc_gen", + "@heir//lib/Analysis/NoisePropagation:Variance", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Polynomial/IR:PolynomialAttributes", "@heir//lib/Interfaces:NoiseInterfaces", diff --git a/lib/Dialect/CGGI/IR/CGGIOps.cpp b/lib/Dialect/CGGI/IR/CGGIOps.cpp index 03835eb81..358b52dc1 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.cpp +++ b/lib/Dialect/CGGI/IR/CGGIOps.cpp @@ -3,6 +3,7 @@ #include "include/Dialect/CGGI/IR/CGGIOps.h" +#include "include/Analysis/NoisePropagation/Variance.h" #include "include/Dialect/CGGI/IR/CGGIAttributes.h" #include "include/Dialect/LWE/IR/LWEAttributes.h" #include "include/Interfaces/NoiseInterfaces.h" @@ -96,35 +97,36 @@ void handleSingleResultOp(Operation *op, Value ctValue, return; } auto cggiParams = llvm::cast(attrs.get("cggi_params")); - setValueNoise(op->getResult(0), bootstrapOutputNoise(cggiParams, lweParams)); + setValueNoise(op->getResult(0), + Variance(bootstrapOutputNoise(cggiParams, lweParams))); } -void AndOp::inferResultNoise(llvm::ArrayRef argNoises, +void AndOp::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { return handleSingleResultOp(getOperation(), getLhs(), setValueNoise); } -void OrOp::inferResultNoise(llvm::ArrayRef argNoises, +void OrOp::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { return handleSingleResultOp(getOperation(), getLhs(), setValueNoise); } -void XorOp::inferResultNoise(llvm::ArrayRef argNoises, +void XorOp::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { return handleSingleResultOp(getOperation(), getLhs(), setValueNoise); } -void Lut3Op::inferResultNoise(llvm::ArrayRef argNoises, +void Lut3Op::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { return handleSingleResultOp(getOperation(), getA(), setValueNoise); } -void Lut2Op::inferResultNoise(llvm::ArrayRef argNoises, +void Lut2Op::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { return handleSingleResultOp(getOperation(), getA(), setValueNoise); } -void NotOp::inferResultNoise(llvm::ArrayRef argNoises, +void NotOp::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { // This one doesn't use bootstrap, no error change setValueNoise(getInput(), argNoises[0]); diff --git a/lib/Dialect/LWE/IR/BUILD b/lib/Dialect/LWE/IR/BUILD index 3add90f59..5aa6783fc 100644 --- a/lib/Dialect/LWE/IR/BUILD +++ b/lib/Dialect/LWE/IR/BUILD @@ -13,6 +13,7 @@ cc_library( "@heir//include/Dialect/LWE/IR:LWEDialect.h", "@heir//include/Dialect/LWE/IR:LWEOps.h", "@heir//include/Dialect/LWE/IR:LWETypes.h", + "@heir//lib/Analysis/NoisePropagation:Variance", ], deps = [ "@heir//include/Dialect/LWE/IR:attributes_inc_gen", diff --git a/lib/Dialect/LWE/IR/LWEDialect.cpp b/lib/Dialect/LWE/IR/LWEDialect.cpp index b2a11cb3d..31c2c9321 100644 --- a/lib/Dialect/LWE/IR/LWEDialect.cpp +++ b/lib/Dialect/LWE/IR/LWEDialect.cpp @@ -2,6 +2,7 @@ #include +#include "include/Analysis/NoisePropagation/Variance.h" #include "include/Dialect/LWE/IR/LWEAttributes.h" #include "include/Dialect/LWE/IR/LWEOps.h" #include "include/Dialect/LWE/IR/LWETypes.h" @@ -181,9 +182,15 @@ LogicalResult TrivialEncryptOp::verify() { return success(); } -void AddOp::inferResultNoise(llvm::ArrayRef argNoises, +void AddOp::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { - return setValueNoise(getResult(), argNoises[0] + argNoises[1]); + Variance result; + if (!argNoises[0].isKnown() || !argNoises[1].isKnown()) + result = Variance::unknown(); + else + result = Variance{argNoises[0].getValue() + argNoises[1].getValue()}; + + return setValueNoise(getResult(), result); } } // namespace lwe From 8a0f2687a27911bffd603cbaa90bc827966c0f46 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 22 Nov 2023 15:19:34 -0800 Subject: [PATCH 08/20] add template for propagate_noise test --- tests/noise/propagate_noise.mlir | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 tests/noise/propagate_noise.mlir diff --git a/tests/noise/propagate_noise.mlir b/tests/noise/propagate_noise.mlir new file mode 100644 index 000000000..646d93cfe --- /dev/null +++ b/tests/noise/propagate_noise.mlir @@ -0,0 +1,21 @@ +// RUN: heir-opt --cggi-set-default-parameters --lwe-set-default-parameters --validate-noise %s + +#encoding = #lwe.bit_field_encoding +#poly = #polynomial.polynomial<1 + x**1024> +!plaintext = !lwe.lwe_plaintext +!ciphertext = !lwe.lwe_ciphertext + +// CHECK-LABEL: @test_adds_attribute +func.func @test_adds_attribute(%arg0 : !ciphertext) -> !ciphertext { + %0 = arith.constant 0 : i1 + %1 = arith.constant 1 : i1 + %2 = lwe.encode %0 { encoding = #encoding }: i1 to !plaintext + %3 = lwe.encode %1 { encoding = #encoding }: i1 to !plaintext + // CHECK: lwe.trivial_encrypt + %4 = lwe.trivial_encrypt %2 { params = #params } : !plaintext to !ciphertext + // CHECK: lwe.trivial_encrypt + %5 = lwe.trivial_encrypt %3 { params = #params } : !plaintext to !ciphertext + %6 = lwe.add %4, %5 : !ciphertext + %6 = cggi.lut3 (%arg0, %6, %5) {lookup_table = 127 : index} : !ciphertext + return %4 : !ciphertext +} From 7138a29819ba82150625656d07ed4073b545337d Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 22 Nov 2023 17:40:41 -0800 Subject: [PATCH 09/20] add LWE set default parameters --- include/Dialect/LWE/IR/LWEOps.td | 3 +++ .../{propagate_noise.mlir => validate_noise.mlir} | 10 +++++----- 2 files changed, 8 insertions(+), 5 deletions(-) rename tests/noise/{propagate_noise.mlir => validate_noise.mlir} (67%) diff --git a/include/Dialect/LWE/IR/LWEOps.td b/include/Dialect/LWE/IR/LWEOps.td index d8e236e6b..cd75a89b3 100644 --- a/include/Dialect/LWE/IR/LWEOps.td +++ b/include/Dialect/LWE/IR/LWEOps.td @@ -70,6 +70,9 @@ def LWE_AddOp : LWE_Op<"add", [ LWECiphertext:$rhs ); let results = (outs LWECiphertext:$output); + let assemblyFormat = [{ + operands attr-dict `:` qualified(type($output)) + }]; } #endif // HEIR_INCLUDE_DIALECT_LWE_IR_LWEOPS_TD_ diff --git a/tests/noise/propagate_noise.mlir b/tests/noise/validate_noise.mlir similarity index 67% rename from tests/noise/propagate_noise.mlir rename to tests/noise/validate_noise.mlir index 646d93cfe..39c530f67 100644 --- a/tests/noise/propagate_noise.mlir +++ b/tests/noise/validate_noise.mlir @@ -3,7 +3,7 @@ #encoding = #lwe.bit_field_encoding #poly = #polynomial.polynomial<1 + x**1024> !plaintext = !lwe.lwe_plaintext -!ciphertext = !lwe.lwe_ciphertext +!ciphertext = !lwe.lwe_ciphertext // CHECK-LABEL: @test_adds_attribute func.func @test_adds_attribute(%arg0 : !ciphertext) -> !ciphertext { @@ -12,10 +12,10 @@ func.func @test_adds_attribute(%arg0 : !ciphertext) -> !ciphertext { %2 = lwe.encode %0 { encoding = #encoding }: i1 to !plaintext %3 = lwe.encode %1 { encoding = #encoding }: i1 to !plaintext // CHECK: lwe.trivial_encrypt - %4 = lwe.trivial_encrypt %2 { params = #params } : !plaintext to !ciphertext + %4 = lwe.trivial_encrypt %2 : !plaintext to !ciphertext // CHECK: lwe.trivial_encrypt - %5 = lwe.trivial_encrypt %3 { params = #params } : !plaintext to !ciphertext + %5 = lwe.trivial_encrypt %3 : !plaintext to !ciphertext %6 = lwe.add %4, %5 : !ciphertext - %6 = cggi.lut3 (%arg0, %6, %5) {lookup_table = 127 : index} : !ciphertext - return %4 : !ciphertext + %7 = cggi.lut3 (%arg0, %6, %5) {lookup_table = 127 : index} : !ciphertext + return %7 : !ciphertext } From bfdff75d7c18e35fffcd61de54ec01dbbc93170f Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 23 Nov 2023 00:15:20 -0800 Subject: [PATCH 10/20] add ValidateNoise pass --- include/Transforms/ValidateNoise/BUILD | 35 ++++++++++ .../Transforms/ValidateNoise/ValidateNoise.h | 18 +++++ .../Transforms/ValidateNoise/ValidateNoise.td | 17 +++++ lib/Transforms/ValidateNoise/BUILD | 21 ++++++ .../ValidateNoise/ValidateNoise.cpp | 65 +++++++++++++++++++ tools/BUILD | 1 + tools/heir-opt.cpp | 2 + 7 files changed, 159 insertions(+) create mode 100644 include/Transforms/ValidateNoise/BUILD create mode 100644 include/Transforms/ValidateNoise/ValidateNoise.h create mode 100644 include/Transforms/ValidateNoise/ValidateNoise.td create mode 100644 lib/Transforms/ValidateNoise/BUILD create mode 100644 lib/Transforms/ValidateNoise/ValidateNoise.cpp diff --git a/include/Transforms/ValidateNoise/BUILD b/include/Transforms/ValidateNoise/BUILD new file mode 100644 index 000000000..67201a56d --- /dev/null +++ b/include/Transforms/ValidateNoise/BUILD @@ -0,0 +1,35 @@ +# ValidateNoise tablegen and headers. + +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files([ + "ValidateNoise.h", +]) + +gentbl_cc_library( + name = "pass_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=ValidateNoise", + ], + "ValidateNoise.h.inc", + ), + ( + ["-gen-pass-doc"], + "ValidateNoise.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ValidateNoise.td", + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/include/Transforms/ValidateNoise/ValidateNoise.h b/include/Transforms/ValidateNoise/ValidateNoise.h new file mode 100644 index 000000000..85d51855e --- /dev/null +++ b/include/Transforms/ValidateNoise/ValidateNoise.h @@ -0,0 +1,18 @@ +#ifndef INCLUDE_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_H_ +#define INCLUDE_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DECL +#include "include/Transforms/ValidateNoise/ValidateNoise.h.inc" + +#define GEN_PASS_REGISTRATION +#include "include/Transforms/ValidateNoise/ValidateNoise.h.inc" + +} // namespace heir +} // namespace mlir + +#endif // INCLUDE_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_H_ diff --git a/include/Transforms/ValidateNoise/ValidateNoise.td b/include/Transforms/ValidateNoise/ValidateNoise.td new file mode 100644 index 000000000..37b81d398 --- /dev/null +++ b/include/Transforms/ValidateNoise/ValidateNoise.td @@ -0,0 +1,17 @@ +#ifndef INCLUDE_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_TD_ +#define INCLUDE_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_TD_ + +include "mlir/Pass/PassBase.td" + +def ValidateNoise : Pass<"validate-noise"> { + let summary = "Validate the FHE noise growth in the IR"; + let description = [{ + This pass applies a noise propagation analysis to the IR and checks that + the noise does not grow beyond a feasible maximum, based on the paraneters + chosen for the FHE dialects involved. + + Requires the ops in the IR to implement `NoisePropagationInterface`. + }]; +} + +#endif // INCLUDE_TRANSFORMS_VALIDATENOISE_VALIDATENOISE_TD_ diff --git a/lib/Transforms/ValidateNoise/BUILD b/lib/Transforms/ValidateNoise/BUILD new file mode 100644 index 000000000..171f22f5c --- /dev/null +++ b/lib/Transforms/ValidateNoise/BUILD @@ -0,0 +1,21 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "ValidateNoise", + srcs = ["ValidateNoise.cpp"], + hdrs = [ + "@heir//include/Transforms/ValidateNoise:ValidateNoise.h", + ], + deps = [ + "@heir//include/Transforms/ValidateNoise:pass_inc_gen", + "@heir//lib/Analysis/NoisePropagation:NoisePropagationAnalysis", + "@heir//lib/Analysis/NoisePropagation:Variance", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp new file mode 100644 index 000000000..bcf2a1950 --- /dev/null +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -0,0 +1,65 @@ +#include "include/Transforms/ValidateNoise/ValidateNoise.h" + +#include "lib/Analysis/NoisePropagation/NoisePropagationAnalysis.h" +#include "lib/Analysis/NoisePropagation/Variance.h" +#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project +#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +namespace mlir { +namespace heir { + +#define GEN_PASS_DEF_VALIDATENOISE +#include "include/Transforms/ValidateNoise/ValidateNoise.h.inc" + +struct ValidateNoise : impl::ValidateNoiseBase { + using ValidateNoiseBase::ValidateNoiseBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + + DataFlowSolver solver; + // FIXME: do I still need DeadCodeAnalysis? + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(module))) { + getOperation()->emitOpError() << "Failed to run the analysis.\n"; + signalPassFailure(); + return; + } + + auto result = module->walk([&](Operation *op) { + const VarianceLattice *opRange = + solver.lookupState(op->getResult(0)); + // FIXME: should be OK for some places to now know the noise. + if (!opRange || !opRange->getValue().isKnown()) { + op->emitOpError() << "Found op without a known noise variance; did the " + "analysis fail?"; + return WalkResult::interrupt(); + } + + int64_t var = opRange->getValue().getValue(); + int64_t maxNoise = 0; // FIXME: infer from the parameters? + if (var > maxNoise) { + op->emitOpError() << "Found op after which the noise exceeds the " + "allowable maximum of " + << maxNoise << "; it was: " << var << "\n"; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + getOperation()->emitOpError() + << "Detected error in the noise analysis.\n"; + signalPassFailure(); + } + } +}; + +} // namespace heir +} // namespace mlir diff --git a/tools/BUILD b/tools/BUILD index b227a8642..b32275840 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -51,6 +51,7 @@ cc_binary( "@heir//lib/Dialect/Secret/Transforms", "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Transforms/Secretize", + "@heir//lib/Transforms/ValidateNoise", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 909065421..80bf89f9c 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -17,6 +17,7 @@ #include "include/Dialect/Secret/Transforms/Passes.h" #include "include/Dialect/TfheRust/IR/TfheRustDialect.h" #include "include/Transforms/Secretize/Secretize.h" +#include "include/Transforms/ValidateNoise/ValidateNoise.h" #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project @@ -180,6 +181,7 @@ int main(int argc, char **argv) { registerAllPasses(); // Custom passes in HEIR + registerValidateNoise(); cggi::registerCGGIPasses(); lwe::registerLWEPasses(); secret::registerSecretPasses(); From 598f47960d6d0256ce0c046c2143c673699dda1c Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 28 Nov 2023 17:12:47 -0800 Subject: [PATCH 11/20] fix build errors --- .../Analysis/NoisePropagation/NoisePropagationAnalysis.h | 9 --------- lib/Transforms/ValidateNoise/ValidateNoise.cpp | 6 +++--- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h b/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h index 05e0838d6..0abfc2d86 100644 --- a/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h +++ b/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h @@ -33,15 +33,6 @@ class NoisePropagationAnalysis void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override; - - /// Visit block arguments or operation results of an operation with region - /// control-flow for which values are not defined by region control-flow. This - /// function calls `InferIntRangeInterface` to provide values for block - /// arguments. - void visitNonControlFlowArguments(Operation *op, - const RegionSuccessor &successor, - ArrayRef argLattices, - unsigned firstIndex) override; }; } // namespace heir diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp index bcf2a1950..a732e51de 100644 --- a/lib/Transforms/ValidateNoise/ValidateNoise.cpp +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -1,7 +1,7 @@ #include "include/Transforms/ValidateNoise/ValidateNoise.h" -#include "lib/Analysis/NoisePropagation/NoisePropagationAnalysis.h" -#include "lib/Analysis/NoisePropagation/Variance.h" +#include "include/Analysis/NoisePropagation/NoisePropagationAnalysis.h" +#include "include/Analysis/NoisePropagation/Variance.h" #include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project @@ -19,7 +19,7 @@ struct ValidateNoise : impl::ValidateNoiseBase { using ValidateNoiseBase::ValidateNoiseBase; void runOnOperation() override { - MLIRContext *context = &getContext(); + auto *module = getOperation(); DataFlowSolver solver; // FIXME: do I still need DeadCodeAnalysis? From 00bafd454bf75bf18f79f01d418658ab541402d5 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 28 Nov 2023 17:38:27 -0800 Subject: [PATCH 12/20] add hasDeterministicResultNoise to interface --- include/Dialect/LWE/IR/LWEOps.td | 5 ++++- include/Interfaces/NoiseInterfaces.td | 8 ++++++++ .../NoisePropagation/NoisePropagationAnalysis.cpp | 2 +- lib/Dialect/CGGI/IR/CGGIOps.cpp | 7 +++++++ lib/Dialect/LWE/IR/LWEDialect.cpp | 8 ++++++++ 5 files changed, 28 insertions(+), 2 deletions(-) diff --git a/include/Dialect/LWE/IR/LWEOps.td b/include/Dialect/LWE/IR/LWEOps.td index cd75a89b3..f49389c27 100644 --- a/include/Dialect/LWE/IR/LWEOps.td +++ b/include/Dialect/LWE/IR/LWEOps.td @@ -41,7 +41,10 @@ def LWE_EncodeOp : LWE_Op<"encode", [Pure]> { let hasVerifier = 1; } -def LWE_TrivialEncryptOp: LWE_Op<"trivial_encrypt", [Pure]> { +def LWE_TrivialEncryptOp: LWE_Op<"trivial_encrypt", [ + Pure, + DeclareOpInterfaceMethods +]> { let summary = "Create a trivial encryption of a plaintext."; let arguments = (ins diff --git a/include/Interfaces/NoiseInterfaces.td b/include/Interfaces/NoiseInterfaces.td index 308135299..461ef3a85 100644 --- a/include/Interfaces/NoiseInterfaces.td +++ b/include/Interfaces/NoiseInterfaces.td @@ -34,6 +34,14 @@ def NoisePropagationInterface : OpInterface<"NoisePropagationInterface"> { "void", "inferResultNoise", (ins "::llvm::ArrayRef":$argNoises, "::mlir::heir::SetNoiseFn":$setValueNoise) + + >, + InterfaceMethod<[{ + Returns true if the noise in the result op is independent of the noise in + its inputs. This is suitable for ops like bootstrap and initial + encryption. + }], + "bool", "hasDeterministicResultNoise", (ins) >]; } diff --git a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp index 6466973e6..241368678 100644 --- a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp +++ b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp @@ -44,7 +44,7 @@ void NoisePropagationAnalysis::visitOperation( bool isYieldedResult = llvm::any_of(value.getUsers(), [](Operation *op) { return op->hasTrait(); }); - // FIXME: add DeterministicNoise trait + // FIXME: incorporate deterministic noise check if (isYieldedResult && oldRange.isKnown() && !(lattice->getValue() == oldRange)) { LLVM_DEBUG( diff --git a/lib/Dialect/CGGI/IR/CGGIOps.cpp b/lib/Dialect/CGGI/IR/CGGIOps.cpp index 358b52dc1..184a3996f 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.cpp +++ b/lib/Dialect/CGGI/IR/CGGIOps.cpp @@ -132,6 +132,13 @@ void NotOp::inferResultNoise(llvm::ArrayRef argNoises, setValueNoise(getInput(), argNoises[0]); } +bool AndOp::hasDeterministicResultNoise() { return true; } +bool OrOp::hasDeterministicResultNoise() { return true; } +bool XorOp::hasDeterministicResultNoise() { return true; } +bool Lut3Op::hasDeterministicResultNoise() { return true; } +bool Lut2Op::hasDeterministicResultNoise() { return true; } +bool NotOp::hasDeterministicResultNoise() { return false; } + } // namespace cggi } // namespace heir } // namespace mlir diff --git a/lib/Dialect/LWE/IR/LWEDialect.cpp b/lib/Dialect/LWE/IR/LWEDialect.cpp index 31c2c9321..c479c46ee 100644 --- a/lib/Dialect/LWE/IR/LWEDialect.cpp +++ b/lib/Dialect/LWE/IR/LWEDialect.cpp @@ -193,6 +193,14 @@ void AddOp::inferResultNoise(llvm::ArrayRef argNoises, return setValueNoise(getResult(), result); } +void TrivialEncryptOp::inferResultNoise(llvm::ArrayRef argNoises, + SetNoiseFn setValueNoise) { + return setValueNoise(getResult(), Variance(0)); +} + +bool AddOp::hasDeterministicResultNoise() { return false; } +bool TrivialEncryptOp::hasDeterministicResultNoise() { return true; } + } // namespace lwe } // namespace heir } // namespace mlir From a122c6cc57d9253676783c9cb99f65122db3b44e Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 29 Nov 2023 07:29:53 -0800 Subject: [PATCH 13/20] rename from deterministic to argument-independent --- include/Interfaces/NoiseInterfaces.td | 6 +++--- lib/Dialect/CGGI/IR/CGGIOps.cpp | 12 ++++++------ lib/Dialect/LWE/IR/LWEDialect.cpp | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/include/Interfaces/NoiseInterfaces.td b/include/Interfaces/NoiseInterfaces.td index 461ef3a85..d1fdda850 100644 --- a/include/Interfaces/NoiseInterfaces.td +++ b/include/Interfaces/NoiseInterfaces.td @@ -16,8 +16,8 @@ def NoisePropagationInterface : OpInterface<"NoisePropagationInterface"> { let methods = [ InterfaceMethod<[{ - Infer the noise distribution of the result of this op given the distributions - of its inputs. + Infers the noise distribution of the result of this op given the + distributions of its inputs. All noise distributions are assumed to be Gaussian centered at zero, and so the inputs and results are represented by their variances. @@ -41,7 +41,7 @@ def NoisePropagationInterface : OpInterface<"NoisePropagationInterface"> { its inputs. This is suitable for ops like bootstrap and initial encryption. }], - "bool", "hasDeterministicResultNoise", (ins) + "bool", "hasArgumentIndependentResultNoise", (ins) >]; } diff --git a/lib/Dialect/CGGI/IR/CGGIOps.cpp b/lib/Dialect/CGGI/IR/CGGIOps.cpp index 184a3996f..29bfedf9b 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.cpp +++ b/lib/Dialect/CGGI/IR/CGGIOps.cpp @@ -132,12 +132,12 @@ void NotOp::inferResultNoise(llvm::ArrayRef argNoises, setValueNoise(getInput(), argNoises[0]); } -bool AndOp::hasDeterministicResultNoise() { return true; } -bool OrOp::hasDeterministicResultNoise() { return true; } -bool XorOp::hasDeterministicResultNoise() { return true; } -bool Lut3Op::hasDeterministicResultNoise() { return true; } -bool Lut2Op::hasDeterministicResultNoise() { return true; } -bool NotOp::hasDeterministicResultNoise() { return false; } +bool AndOp::hasArgumentIndependentResultNoise() { return true; } +bool OrOp::hasArgumentIndependentResultNoise() { return true; } +bool XorOp::hasArgumentIndependentResultNoise() { return true; } +bool Lut3Op::hasArgumentIndependentResultNoise() { return true; } +bool Lut2Op::hasArgumentIndependentResultNoise() { return true; } +bool NotOp::hasArgumentIndependentResultNoise() { return false; } } // namespace cggi } // namespace heir diff --git a/lib/Dialect/LWE/IR/LWEDialect.cpp b/lib/Dialect/LWE/IR/LWEDialect.cpp index c479c46ee..edb967fb9 100644 --- a/lib/Dialect/LWE/IR/LWEDialect.cpp +++ b/lib/Dialect/LWE/IR/LWEDialect.cpp @@ -198,8 +198,8 @@ void TrivialEncryptOp::inferResultNoise(llvm::ArrayRef argNoises, return setValueNoise(getResult(), Variance(0)); } -bool AddOp::hasDeterministicResultNoise() { return false; } -bool TrivialEncryptOp::hasDeterministicResultNoise() { return true; } +bool AddOp::hasArgumentIndependentResultNoise() { return false; } +bool TrivialEncryptOp::hasArgumentIndependentResultNoise() { return true; } } // namespace lwe } // namespace heir From a4e9e1e58821d0955ee2e604e1c55d8fcbf43c44 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 29 Nov 2023 07:30:22 -0800 Subject: [PATCH 14/20] incorporate arg-independent noise method into passes --- .../NoisePropagationAnalysis.cpp | 19 +++++++---- .../ValidateNoise/ValidateNoise.cpp | 34 ++++++++++++++++--- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp index 241368678..071756482 100644 --- a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp +++ b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp @@ -37,19 +37,24 @@ void NoisePropagationAnalysis::visitOperation( Variance oldRange = lattice->getValue(); ChangeResult changed = lattice->join(Variance{variance}); - // If the result is yielded, then the best we can do is check to see - // if the op producing this value has deterministic noise. If so, - // we can propagate that noise. Otherwise, we must assume the worst - // case scenario of unknown noise. + // If the result is yielded, then the best we can do is check to see if the + // op producing this value has argument-independent noise. If so, we can + // propagate that noise. Otherwise, we must assume the worst case scenario + // of unknown noise. bool isYieldedResult = llvm::any_of(value.getUsers(), [](Operation *op) { return op->hasTrait(); }); - // FIXME: incorporate deterministic noise check + // The check !(lattice->getValue() == oldRange) would fail if the noise + // depends on its arguments, but we add the extra check for + // hasArgumentIndependentResultNoise to make it easier for humans to + // determine where in the codebase one should look for stuff related to + // this method. if (isYieldedResult && oldRange.isKnown() && - !(lattice->getValue() == oldRange)) { + !(lattice->getValue() == oldRange) && + !noisePropagationOp.hasArgumentIndependentResultNoise()) { LLVM_DEBUG( llvm::dbgs() - << "Non-deterministic noise-propagating op passed to a region " + << "Non-constant noise-propagating op passed to a region " "terminator. Assuming loop result and marking noise unknown\n"); changed |= lattice->join(Variance(std::nullopt)); } diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp index a732e51de..5bb7a6461 100644 --- a/lib/Transforms/ValidateNoise/ValidateNoise.cpp +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -2,6 +2,7 @@ #include "include/Analysis/NoisePropagation/NoisePropagationAnalysis.h" #include "include/Analysis/NoisePropagation/Variance.h" +#include "include/Interfaces/NoiseInterfaces.h" #include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project @@ -34,11 +35,36 @@ struct ValidateNoise : impl::ValidateNoiseBase { auto result = module->walk([&](Operation *op) { const VarianceLattice *opRange = solver.lookupState(op->getResult(0)); - // FIXME: should be OK for some places to now know the noise. + // It's OK for some places to not know the noise, so long as the only + // user of that value is a bootstrap-like op. if (!opRange || !opRange->getValue().isKnown()) { - op->emitOpError() << "Found op without a known noise variance; did the " - "analysis fail?"; - return WalkResult::interrupt(); + // One might expect a check for hasSingleUse, but there could + // potentially be multiple downstream users, each applying a different + // kind of programmable bootstrap to compute different functions, so we + // loop over all users. + for (auto result : op->getResults()) { + for (auto *user : result.getUsers()) { + auto noisePropagationOp = dyn_cast(user); + // If the cast fails, then we can still proceed. The user could be + // control flow like a func.call or a loop. In such cases, the + // dataflow solver should propagate the value through the control + // flow already, so we don't need to check it. It could also be a + // decryption op, which doesn't implement the interface but is + // valid. + if (noisePropagationOp && + !noisePropagationOp.hasArgumentIndependentResultNoise()) { + op->emitOpError() + << "Found op unknown noise variance, and it has a user with " + "non-constant noise propagation. This can happen when an " + "SSA value is part of control flow, such as a loop or an " + "entrypoint to a function with multiple callers. In such " + "cases, an extra bootstrap is required to ensure the " + "value does not exceed its noise bound, or the control " + "flow must be removed."; + return WalkResult::interrupt(); + } + } + } } int64_t var = opRange->getValue().getValue(); From fb9d14be5a03607efc77b8503cfc46e5c8f7855d Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 29 Nov 2023 08:12:39 -0800 Subject: [PATCH 15/20] add trivial noiseless test --- tests/validate_noise/BUILD | 13 +++++++++++++ tests/{noise => validate_noise}/validate_noise.mlir | 12 ++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) create mode 100644 tests/validate_noise/BUILD rename tests/{noise => validate_noise}/validate_noise.mlir (71%) diff --git a/tests/validate_noise/BUILD b/tests/validate_noise/BUILD new file mode 100644 index 000000000..6c9032391 --- /dev/null +++ b/tests/validate_noise/BUILD @@ -0,0 +1,13 @@ +load("//bazel:lit.bzl", "glob_lit_tests") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +glob_lit_tests( + name = "all_tests", + data = ["@heir//tests:test_utilities"], + driver = "@heir//tests:run_lit.sh", + test_file_exts = ["mlir"], +) diff --git a/tests/noise/validate_noise.mlir b/tests/validate_noise/validate_noise.mlir similarity index 71% rename from tests/noise/validate_noise.mlir rename to tests/validate_noise/validate_noise.mlir index 39c530f67..9064c473b 100644 --- a/tests/noise/validate_noise.mlir +++ b/tests/validate_noise/validate_noise.mlir @@ -5,17 +5,17 @@ !plaintext = !lwe.lwe_plaintext !ciphertext = !lwe.lwe_ciphertext -// CHECK-LABEL: @test_adds_attribute -func.func @test_adds_attribute(%arg0 : !ciphertext) -> !ciphertext { +// TODO(https://github.com/google/heir/issues/296): use lwe.encrypt with +// realistic initial noise. + +// CHECK-LABEL: @test_defaults_are_valid_for_single_add +func.func @test_defaults_are_valid_for_single_add() -> !ciphertext { %0 = arith.constant 0 : i1 %1 = arith.constant 1 : i1 %2 = lwe.encode %0 { encoding = #encoding }: i1 to !plaintext %3 = lwe.encode %1 { encoding = #encoding }: i1 to !plaintext - // CHECK: lwe.trivial_encrypt %4 = lwe.trivial_encrypt %2 : !plaintext to !ciphertext - // CHECK: lwe.trivial_encrypt %5 = lwe.trivial_encrypt %3 : !plaintext to !ciphertext %6 = lwe.add %4, %5 : !ciphertext - %7 = cggi.lut3 (%arg0, %6, %5) {lookup_table = 127 : index} : !ciphertext - return %7 : !ciphertext + return %6 : !ciphertext } From d0eeb2aa3b9382e60b6c1ab55d1b881ca3c9c067 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 29 Nov 2023 08:13:05 -0800 Subject: [PATCH 16/20] add debug statements, loop over all results --- .../NoisePropagationAnalysis.cpp | 22 +++-- .../ValidateNoise/ValidateNoise.cpp | 99 +++++++++++-------- 2 files changed, 73 insertions(+), 48 deletions(-) diff --git a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp index 071756482..f302fdac5 100644 --- a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp +++ b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp @@ -13,17 +13,25 @@ namespace heir { void NoisePropagationAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { - // If the lattice on any operand is unknown, bail out. - if (llvm::any_of(operands, [](const VarianceLattice *lattice) { + auto noisePropagationOp = dyn_cast(op); + if (!noisePropagationOp) return setAllToEntryStates(results); + + LLVM_DEBUG(llvm::dbgs() << "Propagating noise for " << noisePropagationOp + << "\n"); + + // Ops with argument-independent noise propagation can work with unknown noise + // arguments, but others cannot. + if (!noisePropagationOp.hasArgumentIndependentResultNoise() && + llvm::any_of(operands, [](const VarianceLattice *lattice) { return !lattice->getValue().isKnown(); })) { - return; + LLVM_DEBUG(llvm::dbgs() + << "Op " << noisePropagationOp->getName() + << "with argument-dependent noise propagation encountered input " + "with unknown noise. Marking result noise as unknown.\n"); + return setAllToEntryStates(results); } - auto noisePropagationOp = dyn_cast(op); - if (!noisePropagationOp) return setAllToEntryStates(results); - - LLVM_DEBUG(llvm::dbgs() << "Propagating noise for " << *op << "\n"); SmallVector argRanges(llvm::map_range( operands, [](const VarianceLattice *val) { return val->getValue(); })); diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp index 5bb7a6461..133a89ec0 100644 --- a/lib/Transforms/ValidateNoise/ValidateNoise.cpp +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -3,12 +3,14 @@ #include "include/Analysis/NoisePropagation/NoisePropagationAnalysis.h" #include "include/Analysis/NoisePropagation/Variance.h" #include "include/Interfaces/NoiseInterfaces.h" -#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project -#include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project// from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" // from @llvm-projectject #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project #include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project + +#define DEBUG_TYPE "ValidateNoise" namespace mlir { namespace heir { @@ -33,49 +35,64 @@ struct ValidateNoise : impl::ValidateNoiseBase { } auto result = module->walk([&](Operation *op) { - const VarianceLattice *opRange = - solver.lookupState(op->getResult(0)); - // It's OK for some places to not know the noise, so long as the only - // user of that value is a bootstrap-like op. - if (!opRange || !opRange->getValue().isKnown()) { - // One might expect a check for hasSingleUse, but there could - // potentially be multiple downstream users, each applying a different - // kind of programmable bootstrap to compute different functions, so we - // loop over all users. - for (auto result : op->getResults()) { - for (auto *user : result.getUsers()) { - auto noisePropagationOp = dyn_cast(user); - // If the cast fails, then we can still proceed. The user could be - // control flow like a func.call or a loop. In such cases, the - // dataflow solver should propagate the value through the control - // flow already, so we don't need to check it. It could also be a - // decryption op, which doesn't implement the interface but is - // valid. - if (noisePropagationOp && - !noisePropagationOp.hasArgumentIndependentResultNoise()) { - op->emitOpError() - << "Found op unknown noise variance, and it has a user with " - "non-constant noise propagation. This can happen when an " - "SSA value is part of control flow, such as a loop or an " - "entrypoint to a function with multiple callers. In such " - "cases, an extra bootstrap is required to ensure the " - "value does not exceed its noise bound, or the control " - "flow must be removed."; - return WalkResult::interrupt(); + for (Value result : op->getResults()) { + const VarianceLattice *opRange = + solver.lookupState(result); + if (!opRange) { + LLVM_DEBUG(llvm::dbgs() + << "Solver did not assign noise to op " << *op << "\n"); + return WalkResult::advance(); + } + LLVM_DEBUG(llvm::dbgs() << "Found noise " << opRange->getValue() + << " at op " << *op << "\n"); + // It's OK for some places to not know the noise, so long as the only + // user of that value is a bootstrap-like op. + if (!opRange->getValue().isKnown()) { + // One might expect a check for hasSingleUse, but there could + // potentially be multiple downstream users, each applying a different + // kind of programmable bootstrap to compute different functions, so + // we loop over all users. + for (auto result : op->getResults()) { + for (Operation *user : result.getUsers()) { + auto noisePropagationOp = + dyn_cast(user); + // If the cast fails, then we can still proceed. The user could be + // control flow like a func.call or a loop. In such cases, the + // dataflow solver should propagate the value through the control + // flow already, so we don't need to check it. It could also be a + // decryption op, which doesn't implement the interface but is + // valid. + if (noisePropagationOp && + !noisePropagationOp.hasArgumentIndependentResultNoise()) { + user->emitOpError() + << "uses SSA value with unknown noise variance, but the op " + "has non-constant noise propagation. This can happen " + "when " + "an SSA value is part of control flow, such as a loop " + "or " + "an entrypoint to a function with multiple callers. In " + "such cases, an extra bootstrap is required to ensure " + "the " + "value does not exceed its noise bound, or the control " + "flow must be removed. SSA value was: \n\n" + << result << "\n\n"; + return WalkResult::interrupt(); + } } } + + return WalkResult::advance(); } - } - int64_t var = opRange->getValue().getValue(); - int64_t maxNoise = 0; // FIXME: infer from the parameters? - if (var > maxNoise) { - op->emitOpError() << "Found op after which the noise exceeds the " - "allowable maximum of " - << maxNoise << "; it was: " << var << "\n"; - return WalkResult::interrupt(); + int64_t var = opRange->getValue().getValue(); + int64_t maxNoise = 0; // FIXME: infer from the parameters? + if (var > maxNoise) { + op->emitOpError() << "Found op after which the noise exceeds the " + "allowable maximum of " + << maxNoise << "; it was: " << var << "\n"; + return WalkResult::interrupt(); + } } - return WalkResult::advance(); }); From a3b5a98e5926212dfdd1ecca725d8a619f4ed8b0 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 29 Nov 2023 08:34:31 -0800 Subject: [PATCH 17/20] improve diagnostics and debugs --- include/Analysis/NoisePropagation/Variance.h | 4 ++++ lib/Analysis/NoisePropagation/BUILD | 1 + lib/Analysis/NoisePropagation/Variance.cpp | 5 +++++ lib/Transforms/ValidateNoise/ValidateNoise.cpp | 17 ++++++++++------- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/include/Analysis/NoisePropagation/Variance.h b/include/Analysis/NoisePropagation/Variance.h index daec9ce2d..cb382e073 100644 --- a/include/Analysis/NoisePropagation/Variance.h +++ b/include/Analysis/NoisePropagation/Variance.h @@ -7,6 +7,7 @@ #include #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project namespace mlir { namespace heir { @@ -42,6 +43,9 @@ class Variance { friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Variance &variance); + friend Diagnostic &operator<<(Diagnostic &diagnostic, + const Variance &variance); + private: std::optional value; }; diff --git a/lib/Analysis/NoisePropagation/BUILD b/lib/Analysis/NoisePropagation/BUILD index 2b5cc6935..b3fdfc619 100644 --- a/lib/Analysis/NoisePropagation/BUILD +++ b/lib/Analysis/NoisePropagation/BUILD @@ -23,5 +23,6 @@ cc_library( hdrs = ["@heir//include/Analysis/NoisePropagation:Variance.h"], deps = [ "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", ], ) diff --git a/lib/Analysis/NoisePropagation/Variance.cpp b/lib/Analysis/NoisePropagation/Variance.cpp index ee06a9264..f065f2530 100644 --- a/lib/Analysis/NoisePropagation/Variance.cpp +++ b/lib/Analysis/NoisePropagation/Variance.cpp @@ -10,5 +10,10 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Variance &variance) { return os << variance.getValue(); } +Diagnostic &operator<<(Diagnostic &diagnostic, const Variance &variance) { + if (!variance.isKnown()) return diagnostic << "unknown"; + return diagnostic << variance.getValue(); +} + } // namespace heir } // namespace mlir diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp index 133a89ec0..ae2a03bb4 100644 --- a/lib/Transforms/ValidateNoise/ValidateNoise.cpp +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -25,7 +25,7 @@ struct ValidateNoise : impl::ValidateNoiseBase { auto *module = getOperation(); DataFlowSolver solver; - // FIXME: do I still need DeadCodeAnalysis? + // The dataflow solver needs DeadCodeAnalysis to run the other analyses solver.load(); solver.load(); if (failed(solver.initializeAndRun(module))) { @@ -35,16 +35,19 @@ struct ValidateNoise : impl::ValidateNoiseBase { } auto result = module->walk([&](Operation *op) { - for (Value result : op->getResults()) { + for (OpResult result : op->getResults()) { const VarianceLattice *opRange = solver.lookupState(result); if (!opRange) { - LLVM_DEBUG(llvm::dbgs() - << "Solver did not assign noise to op " << *op << "\n"); - return WalkResult::advance(); + LLVM_DEBUG(op->emitOpError() + << "Solver did not assign noise to op, suggesting the " + "noise propagation analysis did not run properly or at " + "all."); + return WalkResult::interrupt(); } - LLVM_DEBUG(llvm::dbgs() << "Found noise " << opRange->getValue() - << " at op " << *op << "\n"); + LLVM_DEBUG(op->emitRemark() + << "Found noise " << (opRange->getValue()) + << " for op result " << result.getResultNumber()); // It's OK for some places to not know the noise, so long as the only // user of that value is a bootstrap-like op. if (!opRange->getValue().isKnown()) { From c46d28dc9411527832772f8b5f3650ac91ee3543 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 29 Nov 2023 12:07:01 -0800 Subject: [PATCH 18/20] cop out: use smaller noise params to get validate_noise to pass --- include/Dialect/CGGI/IR/CGGIAttributes.td | 12 ++-- lib/Dialect/CGGI/IR/CGGIOps.cpp | 55 +++++++++++-------- .../CGGI/Transforms/SetDefaultParameters.cpp | 27 ++++++--- lib/Transforms/ValidateNoise/BUILD | 1 + .../ValidateNoise/ValidateNoise.cpp | 28 +++++++--- tests/validate_noise/validate_noise.mlir | 13 +++++ 6 files changed, 92 insertions(+), 44 deletions(-) diff --git a/include/Dialect/CGGI/IR/CGGIAttributes.td b/include/Dialect/CGGI/IR/CGGIAttributes.td index 99e124b29..612db4e43 100644 --- a/include/Dialect/CGGI/IR/CGGIAttributes.td +++ b/include/Dialect/CGGI/IR/CGGIAttributes.td @@ -11,12 +11,12 @@ def CGGI_CGGIParams : AttrDef { // to lwe dialect? let parameters = (ins "::mlir::heir::lwe::RLWEParamsAttr": $rlweParams, - "unsigned": $bsk_noise_variance, - "unsigned": $bsk_gadget_base_log, - "unsigned": $bsk_gadget_num_levels, - "unsigned": $ksk_noise_variance, - "unsigned": $ksk_gadget_base_log, - "unsigned": $ksk_gadget_num_levels + "int64_t": $bsk_noise_variance, + "int64_t": $bsk_gadget_base_log, + "int64_t": $bsk_gadget_num_levels, + "int64_t": $ksk_noise_variance, + "int64_t": $ksk_gadget_base_log, + "int64_t": $ksk_gadget_num_levels ); let assemblyFormat = "`<` struct(params) `>`"; diff --git a/lib/Dialect/CGGI/IR/CGGIOps.cpp b/lib/Dialect/CGGI/IR/CGGIOps.cpp index 29bfedf9b..9b8500580 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.cpp +++ b/lib/Dialect/CGGI/IR/CGGIOps.cpp @@ -7,6 +7,9 @@ #include "include/Dialect/CGGI/IR/CGGIAttributes.h" #include "include/Dialect/LWE/IR/LWEAttributes.h" #include "include/Interfaces/NoiseInterfaces.h" +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project + +#define DEBUG_TYPE "CGGIOps" namespace mlir { namespace heir { @@ -14,12 +17,13 @@ namespace cggi { unsigned maxPerDigitDecompositionError(unsigned baseLog, unsigned numLevels, unsigned ctBitWidth) { - // FIXME: this needs verification; I struggled to parse what was said in the - // CGGI paper, as well as the original DM paper, so I relied on my own - // analysis in https://jeremykun.com/2022/08/29/key-switching-in-lwe/ - // It aligns roughly with the error analysis in Theorem 4.1 of - // https://eprint.iacr.org/2018/421, but using a different perspective - // on the "precision" parameter t in that paper. + // TODO(https://github.com/google/heir/issues/297): This needs verification; + // I struggled to parse what was said in the CGGI paper, as well as the + // original DM paper, so I relied on my own analysis in + // https://jeremykun.com/2022/08/29/key-switching-in-lwe/ It aligns roughly + // with the error analysis in Theorem 4.1 of + // https://eprint.iacr.org/2018/421, but using a different perspective on the + // "precision" parameter t in that paper. // maxLevels is the number L such that B^L = lwe_cmod // a.k.a., L * log2(B) = cmod_bitwidth @@ -53,32 +57,39 @@ int64_t bootstrapOutputNoise(CGGIParamsAttr attr, unsigned kskNoiseVariance = attr.getKskNoiseVariance(); // Mirroring the notation in https://eprint.iacr.org/2018/421, Theorem 6.3. - unsigned logq = lweParams.getCmod().getValue().getBitWidth(); - unsigned n = lweParams.getDimension(); - unsigned k = rlweParams.getDimension(); - unsigned N = rlweParams.getPolyDegree(); - unsigned l = attr.getBskGadgetNumLevels(); + int64_t logq = lweParams.getCmod().getValue().logBase2(); + int64_t n = lweParams.getDimension(); + int64_t k = rlweParams.getDimension(); + int64_t N = rlweParams.getPolyDegree(); + int64_t l = attr.getBskGadgetNumLevels(); // Beta is the max absolute value of a digit of the signed decomposition - unsigned beta = (1 << attr.getBskGadgetBaseLog()) / 2; + int64_t beta = (1 << attr.getBskGadgetBaseLog()) / 2; // Epsilon is the max per-digit error of the approximation introduced by // having fewer levels in the gadget key. - // FIXME: this needs verification. I think it's the same sort of error as the - // key switching key sampleApproxError below. - unsigned epsilon = maxPerDigitDecompositionError( + // TODO(https://github.com/google/heir/issues/297): This needs verification. + // I think it's the same sort of error as the key switching key + // sampleApproxError. + int64_t epsilon = maxPerDigitDecompositionError( attr.getBskGadgetBaseLog(), attr.getBskGadgetNumLevels(), logq); - unsigned externalProductTerm = - (n * (k + 1) * l * N * beta * beta * bskNoiseVariance + - n * (1 + k * N) * epsilon * epsilon); + + int64_t blindRotateTerm1 = + n * (k + 1) * l * N * beta * beta * bskNoiseVariance; + int64_t blindRotateTerm2 = n * (1 + k * N) * epsilon * epsilon; + int64_t blindRotateTerm = blindRotateTerm1 + blindRotateTerm2; // largestDigit depends on a signed decomposition. - unsigned largestDigit = (1 << attr.getKskGadgetBaseLog()) / 2; - unsigned kskSampleApproxError = maxPerDigitDecompositionError( + int64_t largestDigit = (1 << attr.getKskGadgetBaseLog()) / 2; + int64_t kskSampleApproxError = maxPerDigitDecompositionError( attr.getKskGadgetBaseLog(), attr.getKskGadgetNumLevels(), logq); - unsigned keySwitchingTerm = + int64_t keySwitchingTerm = (attr.getKskGadgetNumLevels() * largestDigit * kskNoiseVariance + n * kskSampleApproxError); - return externalProductTerm + keySwitchingTerm; + + LLVM_DEBUG(llvm::dbgs() << "blindRotateTerm1: " << blindRotateTerm1 << "\n"); + LLVM_DEBUG(llvm::dbgs() << "blindRotateTerm2: " << blindRotateTerm2 << "\n"); + LLVM_DEBUG(llvm::dbgs() << "keySwitchingTerm: " << keySwitchingTerm << "\n"); + return blindRotateTerm + keySwitchingTerm; } void handleSingleResultOp(Operation *op, Value ctValue, diff --git a/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp b/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp index 90902bd13..45a5d95d7 100644 --- a/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp +++ b/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp @@ -27,13 +27,26 @@ struct SetDefaultParameters IntegerAttr defaultCmodAttr = IntegerAttr::get(IntegerType::get(&context, 64), defaultCmod); - // https://github.com/google/jaxite/blob/main/jaxite/jaxite_bool/bool_params.py - unsigned defaultBskNoiseVariance = 65536; // stdev = 2**8, var = 2**16 - unsigned defaultBskGadgetBaseLog = 4; - unsigned defaultBskGadgetNumLevels = 6; - unsigned defaultKskNoiseVariance = 268435456; // stdev = 2**14, var = 2**28 - unsigned defaultKskGadgetBaseLog = 4; - unsigned defaultKskGadgetNumLevels = 5; + // TODO(https://github.com/google/heir/issues/297): This needs fixing. I + // tried setting these parameters to the same values from + // https://github.com/google/jaxite/blob/main/jaxite/jaxite_bool/bool_params.py, + // but the formula for the bootstrap and key switch noises in CGGIOps.cpp + // both exceeds 30 bits so the verification fails trivially. I wonder if + // that bound is tighter in some follow-up papers? + // + // For now, setting to much smaller values so that we can get the noise + // propagation infrastructure checked in, and leaving the noise model fix + // to the linked issue. + // + // int64_t defaultBskNoiseVariance = 65536; // stdev = 2**8, var = 2**16 + // int64_t defaultKskNoiseVariance = 268435456; // stdev = 2**14, var = + // 2**28 + int64_t defaultBskNoiseVariance = 2; + int64_t defaultBskGadgetBaseLog = 2; + int64_t defaultBskGadgetNumLevels = 16; + int64_t defaultKskNoiseVariance = 1048576; // stdev = 2**10, var = 2**20 + int64_t defaultKskGadgetBaseLog = 4; + int64_t defaultKskGadgetNumLevels = 5; lwe::RLWEParamsAttr defaultRlweParams = lwe::RLWEParamsAttr::get( &context, defaultCmodAttr, defaultRlweDimension, defaultPolyDegree); diff --git a/lib/Transforms/ValidateNoise/BUILD b/lib/Transforms/ValidateNoise/BUILD index 171f22f5c..1f39a22c7 100644 --- a/lib/Transforms/ValidateNoise/BUILD +++ b/lib/Transforms/ValidateNoise/BUILD @@ -13,6 +13,7 @@ cc_library( "@heir//include/Transforms/ValidateNoise:pass_inc_gen", "@heir//lib/Analysis/NoisePropagation:NoisePropagationAnalysis", "@heir//lib/Analysis/NoisePropagation:Variance", + "@heir//lib/Dialect/LWE/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:Pass", diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp index ae2a03bb4..6377ff146 100644 --- a/lib/Transforms/ValidateNoise/ValidateNoise.cpp +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -2,6 +2,7 @@ #include "include/Analysis/NoisePropagation/NoisePropagationAnalysis.h" #include "include/Analysis/NoisePropagation/Variance.h" +#include "include/Dialect/LWE/IR/LWETypes.h" #include "include/Interfaces/NoiseInterfaces.h" #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project// from @llvm-project @@ -18,6 +19,16 @@ namespace heir { #define GEN_PASS_DEF_VALIDATENOISE #include "include/Transforms/ValidateNoise/ValidateNoise.h.inc" +// TODO(https://github.com/google/heir/issues/297): fix likely +// mistakes in the maximum formula +int64_t maxLweNoise(lwe::LWECiphertextType type) { + auto encoding = type.getEncoding().cast(); + // The cleartext start is the lowest bit of the plaintext space that contains + // the message. One less is the highest bit that contains noise. + int64_t max = 1 << (encoding.getCleartextStart() - 1); + return max; +} + struct ValidateNoise : impl::ValidateNoiseBase { using ValidateNoiseBase::ValidateNoiseBase; @@ -70,14 +81,12 @@ struct ValidateNoise : impl::ValidateNoiseBase { user->emitOpError() << "uses SSA value with unknown noise variance, but the op " "has non-constant noise propagation. This can happen " - "when " - "an SSA value is part of control flow, such as a loop " - "or " - "an entrypoint to a function with multiple callers. In " - "such cases, an extra bootstrap is required to ensure " - "the " - "value does not exceed its noise bound, or the control " - "flow must be removed. SSA value was: \n\n" + "when an SSA value is part of control flow, such as a " + "loop or an entrypoint to a function with multiple " + "callers. In such cases, an extra bootstrap is required " + "to ensure the value does not exceed its noise bound, " + "or the control flow must be removed. SSA value was: " + "\n\n" << result << "\n\n"; return WalkResult::interrupt(); } @@ -88,7 +97,8 @@ struct ValidateNoise : impl::ValidateNoiseBase { } int64_t var = opRange->getValue().getValue(); - int64_t maxNoise = 0; // FIXME: infer from the parameters? + int64_t maxNoise = + maxLweNoise(result.getType().cast()); if (var > maxNoise) { op->emitOpError() << "Found op after which the noise exceeds the " "allowable maximum of " diff --git a/tests/validate_noise/validate_noise.mlir b/tests/validate_noise/validate_noise.mlir index 9064c473b..912739ce3 100644 --- a/tests/validate_noise/validate_noise.mlir +++ b/tests/validate_noise/validate_noise.mlir @@ -19,3 +19,16 @@ func.func @test_defaults_are_valid_for_single_add() -> !ciphertext { %6 = lwe.add %4, %5 : !ciphertext return %6 : !ciphertext } + +// CHECK-LABEL: @test_boostrap_unknown_noise_input +func.func @test_boostrap_unknown_noise_input(%0 : !ciphertext) -> !ciphertext { + %1 = cggi.lut2(%0, %0) {lookup_table = 1 : ui4} : !ciphertext + return %1 : !ciphertext +} + +// CHECK-LABEL: @test_add_post_bootstrap +func.func @test_add_post_bootstrap(%0 : !ciphertext) -> !ciphertext { + %1 = cggi.lut2(%0, %0) {lookup_table = 1 : ui4} : !ciphertext + %2 = lwe.add %1, %1 : !ciphertext + return %2 : !ciphertext +} From 8b47076664d329669782a0266a3012b6788ebced Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 29 Nov 2023 19:13:36 -0800 Subject: [PATCH 19/20] squashme: wip --- include/Analysis/NoisePropagation/Variance.h | 15 +++++-- .../NoisePropagationAnalysis.cpp | 12 +++--- .../CGGI/Transforms/SetDefaultParameters.cpp | 1 - .../ValidateNoise/ValidateNoise.cpp | 8 +++- .../validate_noise/validate_noise_errors.mlir | 39 +++++++++++++++++++ 5 files changed, 64 insertions(+), 11 deletions(-) create mode 100644 tests/validate_noise/validate_noise_errors.mlir diff --git a/include/Analysis/NoisePropagation/Variance.h b/include/Analysis/NoisePropagation/Variance.h index cb382e073..720778a63 100644 --- a/include/Analysis/NoisePropagation/Variance.h +++ b/include/Analysis/NoisePropagation/Variance.h @@ -12,6 +12,15 @@ namespace mlir { namespace heir { +enum VarianceType { + UNSET, // A min value for the lattice, i.e., discarable when joined with + // anything else. + KNOWN, // A known value for the lattice, i.e., when noise can be inferred. + INDEPENDENT, // A known value for the lattice, independent of + MAX // A max value for the lattice, i.e., when noise cannot be inferred and a + // bootstrap must be forced. +}; + /// A class representing an optional variance of a noise distribution. class Variance { public: @@ -31,10 +40,10 @@ class Variance { /// This method represents how to choose a noise from one of two possible /// branches, when either could be possible. In the case of FHE, we must - /// assume the worse case, so take the max. + /// assume the worst case. If either is unknown, assume unknown, otherwise + /// take the max. static Variance join(const Variance &lhs, const Variance &rhs) { - if (!lhs.isKnown()) return rhs; - if (!rhs.isKnown()) return lhs; + if (!lhs.isKnown() || !rhs.isKnown()) return Variance::unknown(); return Variance{std::max(lhs.getValue(), rhs.getValue())}; } diff --git a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp index f302fdac5..05213258f 100644 --- a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp +++ b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp @@ -27,8 +27,8 @@ void NoisePropagationAnalysis::visitOperation( })) { LLVM_DEBUG(llvm::dbgs() << "Op " << noisePropagationOp->getName() - << "with argument-dependent noise propagation encountered input " - "with unknown noise. Marking result noise as unknown.\n"); + << " with argument-dependent noise propagation encountered input" + " with unknown noise. Marking result noise as unknown.\n"); return setAllToEntryStates(results); } @@ -45,6 +45,8 @@ void NoisePropagationAnalysis::visitOperation( Variance oldRange = lattice->getValue(); ChangeResult changed = lattice->join(Variance{variance}); + // FIXME: does this even make sense as a lattice?? + // // If the result is yielded, then the best we can do is check to see if the // op producing this value has argument-independent noise. If so, we can // propagate that noise. Otherwise, we must assume the worst case scenario @@ -58,15 +60,15 @@ void NoisePropagationAnalysis::visitOperation( // determine where in the codebase one should look for stuff related to // this method. if (isYieldedResult && oldRange.isKnown() && - !(lattice->getValue() == oldRange) && + !(lattice.getValue() == oldRange) && !noisePropagationOp.hasArgumentIndependentResultNoise()) { LLVM_DEBUG( llvm::dbgs() << "Non-constant noise-propagating op passed to a region " "terminator. Assuming loop result and marking noise unknown\n"); - changed |= lattice->join(Variance(std::nullopt)); + changed |= lattice.join(Variance::unknown()); } - propagateIfChanged(lattice, changed); + propagateIfChanged(&lattice, changed); }; noisePropagationOp.inferResultNoise(argRanges, joinCallback); diff --git a/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp b/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp index 45a5d95d7..86c988e63 100644 --- a/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp +++ b/lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp @@ -4,7 +4,6 @@ #include "include/Dialect/CGGI/IR/CGGIOps.h" #include "include/Dialect/LWE/IR/LWEAttributes.h" #include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project namespace mlir { diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp index 6377ff146..a9de5cadf 100644 --- a/lib/Transforms/ValidateNoise/ValidateNoise.cpp +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -5,7 +5,8 @@ #include "include/Dialect/LWE/IR/LWETypes.h" #include "include/Interfaces/NoiseInterfaces.h" #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project -#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project// from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project +#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" // from @llvm-projectject #include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project @@ -36,8 +37,11 @@ struct ValidateNoise : impl::ValidateNoiseBase { auto *module = getOperation(); DataFlowSolver solver; - // The dataflow solver needs DeadCodeAnalysis to run the other analyses + // The dataflow solver needs DeadCodeAnalysis and SparseConstantPropagation + // to run pretty much any data flow analysis, see + // https://discourse.llvm.org/t/mlir-dead-code-analysis/67568/8 solver.load(); + solver.load(); solver.load(); if (failed(solver.initializeAndRun(module))) { getOperation()->emitOpError() << "Failed to run the analysis.\n"; diff --git a/tests/validate_noise/validate_noise_errors.mlir b/tests/validate_noise/validate_noise_errors.mlir new file mode 100644 index 000000000..3ab773961 --- /dev/null +++ b/tests/validate_noise/validate_noise_errors.mlir @@ -0,0 +1,39 @@ +// RUN: heir-opt --split-input-file --cggi-set-default-parameters --lwe-set-default-parameters --validate-noise --verify-diagnostics %s + +// TODO(https://github.com/google/heir/issues/296): use lwe.encrypt with +// realistic initial noise. + +// #encoding = #lwe.bit_field_encoding +// #poly = #polynomial.polynomial<1 + x**1024> +// !plaintext = !lwe.lwe_plaintext +// !ciphertext = !lwe.lwe_ciphertext +// +// func.func @test_cant_add_unknown_value(%arg0 : !ciphertext) -> !ciphertext { +// // expected-error@below {{uses SSA value with unknown noise variance}} +// %1 = lwe.add %arg0, %arg0 : !ciphertext +// return %1 : !ciphertext +// } +// +// // ----- + +#encoding = #lwe.bit_field_encoding +#poly = #polynomial.polynomial<1 + x**1024> +!plaintext = !lwe.lwe_plaintext +!ciphertext = !lwe.lwe_ciphertext + +func.func @unknown_value_from_loop_result() -> !ciphertext { + %0 = arith.constant 0 : i1 + %2 = lwe.encode %0 { encoding = #encoding }: i1 to !plaintext + %3 = lwe.trivial_encrypt %2 : !plaintext to !ciphertext + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c5 = arith.constant 5 : index + + %5 = scf.for %arg1 = %c1 to %c5 step %c1 iter_args(%iter_arg = %3) -> !ciphertext { + // expected-error@below {{uses SSA value with unknown noise variance}} + %6 = lwe.add %iter_arg, %iter_arg : !ciphertext + scf.yield %6 : !ciphertext + } + return %5 : !ciphertext +} From 23610892de4f00fea6c0f31aa4fba2ed2dedfa92 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 30 Nov 2023 16:05:11 -0800 Subject: [PATCH 20/20] rewrite lattice with distinct uninitialized and unbounded states --- .../NoisePropagationAnalysis.h | 2 +- include/Analysis/NoisePropagation/Variance.h | 71 ++++++++++++++----- .../NoisePropagationAnalysis.cpp | 42 ++++++----- lib/Analysis/NoisePropagation/Variance.cpp | 17 +++-- .../MemrefToArith/UnrollAndForward.cpp | 2 +- lib/Dialect/CGGI/IR/CGGIOps.cpp | 2 +- lib/Dialect/LWE/IR/LWEDialect.cpp | 18 ++--- .../ValidateNoise/ValidateNoise.cpp | 20 ++++-- tests/validate_noise/validate_noise.mlir | 19 +++++ .../validate_noise/validate_noise_errors.mlir | 55 +++++++------- 10 files changed, 166 insertions(+), 82 deletions(-) diff --git a/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h b/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h index 0abfc2d86..aff7600cb 100644 --- a/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h +++ b/include/Analysis/NoisePropagation/NoisePropagationAnalysis.h @@ -28,7 +28,7 @@ class NoisePropagationAnalysis void setToEntryState(VarianceLattice *lattice) override { // At an entry point, we have no information about the noise. - propagateIfChanged(lattice, lattice->join(Variance(std::nullopt))); + propagateIfChanged(lattice, lattice->join(Variance::uninitialized())); } void visitOperation(Operation *op, ArrayRef operands, diff --git a/include/Analysis/NoisePropagation/Variance.h b/include/Analysis/NoisePropagation/Variance.h index 720778a63..48a7b7a10 100644 --- a/include/Analysis/NoisePropagation/Variance.h +++ b/include/Analysis/NoisePropagation/Variance.h @@ -13,42 +13,80 @@ namespace mlir { namespace heir { enum VarianceType { - UNSET, // A min value for the lattice, i.e., discarable when joined with - // anything else. - KNOWN, // A known value for the lattice, i.e., when noise can be inferred. - INDEPENDENT, // A known value for the lattice, independent of - MAX // A max value for the lattice, i.e., when noise cannot be inferred and a - // bootstrap must be forced. + // A min value for the lattice, discarable when joined with anything else. + UNINITIALIZED, + // A known value for the lattice, when noise can be inferred. + SET, + // A max value for the lattice, when noise cannot be inferred and a bootstrap + // must be forced. + UNBOUNDED }; /// A class representing an optional variance of a noise distribution. class Variance { public: - static Variance unknown() { return Variance(); } + static Variance uninitialized() { + return Variance(VarianceType::UNINITIALIZED, std::nullopt); + } + static Variance unbounded() { + return Variance(VarianceType::UNBOUNDED, std::nullopt); + } + static Variance of(int64_t value) { + return Variance(VarianceType::SET, value); + } /// Create an integer value range lattice value. - Variance(std::optional value = std::nullopt) : value(value) {} + /// The default constructor must be equivalent to the "entry state" of the + /// lattice, i.e., an uninitialized noise variance. + Variance(VarianceType varianceType = VarianceType::UNINITIALIZED, + std::optional value = std::nullopt) + : varianceType(varianceType), value(value) {} + + bool isKnown() const { return varianceType == VarianceType::SET; } - bool isKnown() const { return value.has_value(); } + bool isInitialized() const { + return varianceType != VarianceType::UNINITIALIZED; + } + + bool isBounded() const { return varianceType != VarianceType::UNBOUNDED; } const int64_t &getValue() const { assert(isKnown()); return *value; } - bool operator==(const Variance &rhs) const { return value == rhs.value; } + bool operator==(const Variance &rhs) const { + return varianceType == rhs.varianceType && value == rhs.value; + } - /// This method represents how to choose a noise from one of two possible - /// branches, when either could be possible. In the case of FHE, we must - /// assume the worst case. If either is unknown, assume unknown, otherwise - /// take the max. static Variance join(const Variance &lhs, const Variance &rhs) { - if (!lhs.isKnown() || !rhs.isKnown()) return Variance::unknown(); - return Variance{std::max(lhs.getValue(), rhs.getValue())}; + // Uninitialized variances correspond to values that are not secret, + // which may be the inputs to an encryption operation. + if (lhs.varianceType == VarianceType::UNINITIALIZED) { + return rhs; + } + if (rhs.varianceType == VarianceType::UNINITIALIZED) { + return lhs; + } + + // Unbounded represents a pessimistic worst case, and so it must be + // preserved no matter the other operand. + if (lhs.varianceType == VarianceType::UNBOUNDED) { + return lhs; + } + if (rhs.varianceType == VarianceType::UNBOUNDED) { + return rhs; + } + + assert(lhs.varianceType == VarianceType::SET && + rhs.varianceType == VarianceType::SET); + return Variance::of(std::max(lhs.getValue(), rhs.getValue())); } void print(llvm::raw_ostream &os) const { os << value; } + std::string toString() const; + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Variance &variance); @@ -56,6 +94,7 @@ class Variance { const Variance &variance); private: + VarianceType varianceType; std::optional value; }; diff --git a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp index 05213258f..c17399d54 100644 --- a/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp +++ b/lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp @@ -19,21 +19,29 @@ void NoisePropagationAnalysis::visitOperation( LLVM_DEBUG(llvm::dbgs() << "Propagating noise for " << noisePropagationOp << "\n"); - // Ops with argument-independent noise propagation can work with unknown noise - // arguments, but others cannot. + SmallVector argRanges(llvm::map_range( + operands, [](const VarianceLattice *val) { return val->getValue(); })); + + // Ops with argument-independent noise propagation can work with unbounded + // and uninitialized noise arguments, but others cannot. If we encounter + // a situation where an argument-noise-dependent op processes unknown args, + // we set a pessimistic unbounded variance on the results. if (!noisePropagationOp.hasArgumentIndependentResultNoise() && - llvm::any_of(operands, [](const VarianceLattice *lattice) { - return !lattice->getValue().isKnown(); + llvm::any_of(argRanges, [](const Variance variance) { + return !variance.isKnown(); })) { LLVM_DEBUG(llvm::dbgs() << "Op " << noisePropagationOp->getName() << " with argument-dependent noise propagation encountered input" " with unknown noise. Marking result noise as unknown.\n"); - return setAllToEntryStates(results); - } - SmallVector argRanges(llvm::map_range( - operands, [](const VarianceLattice *val) { return val->getValue(); })); + for (auto result : op->getResults()) { + VarianceLattice *lattice = results[result.getResultNumber()]; + ChangeResult changed = lattice->join(Variance::unbounded()); + propagateIfChanged(lattice, changed); + } + return; + } auto joinCallback = [&](Value value, const Variance &variance) { auto result = dyn_cast(value); @@ -43,14 +51,12 @@ void NoisePropagationAnalysis::visitOperation( LLVM_DEBUG(llvm::dbgs() << "Inferred noise " << variance << "\n"); VarianceLattice *lattice = results[result.getResultNumber()]; Variance oldRange = lattice->getValue(); - ChangeResult changed = lattice->join(Variance{variance}); + ChangeResult changed = lattice->join(variance); - // FIXME: does this even make sense as a lattice?? - // - // If the result is yielded, then the best we can do is check to see if the - // op producing this value has argument-independent noise. If so, we can - // propagate that noise. Otherwise, we must assume the worst case scenario - // of unknown noise. + // If the result is yielded, then it is assumed to be in a loop yield. Then + // the best we can do is check to see if the op producing this value has + // argument-independent noise. If so, we can propagate that noise. + // Otherwise, we must assume the worst case scenario of unbounded noise. bool isYieldedResult = llvm::any_of(value.getUsers(), [](Operation *op) { return op->hasTrait(); }); @@ -60,15 +66,15 @@ void NoisePropagationAnalysis::visitOperation( // determine where in the codebase one should look for stuff related to // this method. if (isYieldedResult && oldRange.isKnown() && - !(lattice.getValue() == oldRange) && + !(lattice->getValue() == oldRange) && !noisePropagationOp.hasArgumentIndependentResultNoise()) { LLVM_DEBUG( llvm::dbgs() << "Non-constant noise-propagating op passed to a region " "terminator. Assuming loop result and marking noise unknown\n"); - changed |= lattice.join(Variance::unknown()); + changed |= lattice->join(Variance::unbounded()); } - propagateIfChanged(&lattice, changed); + propagateIfChanged(lattice, changed); }; noisePropagationOp.inferResultNoise(argRanges, joinCallback); diff --git a/lib/Analysis/NoisePropagation/Variance.cpp b/lib/Analysis/NoisePropagation/Variance.cpp index f065f2530..5f30422f1 100644 --- a/lib/Analysis/NoisePropagation/Variance.cpp +++ b/lib/Analysis/NoisePropagation/Variance.cpp @@ -5,14 +5,23 @@ namespace mlir { namespace heir { +std::string Variance::toString() const { + switch (varianceType) { + case (VarianceType::UNINITIALIZED): + return "Variance(uninitialized)"; + case (VarianceType::UNBOUNDED): + return "Variance(unbounded)"; + case (VarianceType::SET): + return "Variance(" + std::to_string(getValue()) + ")"; + } +} + llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Variance &variance) { - if (!variance.isKnown()) return os << "unknown"; - return os << variance.getValue(); + return os << variance.toString(); } Diagnostic &operator<<(Diagnostic &diagnostic, const Variance &variance) { - if (!variance.isKnown()) return diagnostic << "unknown"; - return diagnostic << variance.getValue(); + return diagnostic << variance.toString(); } } // namespace heir diff --git a/lib/Conversion/MemrefToArith/UnrollAndForward.cpp b/lib/Conversion/MemrefToArith/UnrollAndForward.cpp index da8dc5f45..185f7c519 100644 --- a/lib/Conversion/MemrefToArith/UnrollAndForward.cpp +++ b/lib/Conversion/MemrefToArith/UnrollAndForward.cpp @@ -357,7 +357,7 @@ struct UnrollAndForwardPass mlir::arith::ArithDialect, mlir::scf::SCFDialect>(); } - void runOnOperation(); + void runOnOperation() override; StringRef getArgument() const final { return "unroll-and-forward"; } diff --git a/lib/Dialect/CGGI/IR/CGGIOps.cpp b/lib/Dialect/CGGI/IR/CGGIOps.cpp index 9b8500580..0064f9739 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.cpp +++ b/lib/Dialect/CGGI/IR/CGGIOps.cpp @@ -109,7 +109,7 @@ void handleSingleResultOp(Operation *op, Value ctValue, } auto cggiParams = llvm::cast(attrs.get("cggi_params")); setValueNoise(op->getResult(0), - Variance(bootstrapOutputNoise(cggiParams, lweParams))); + Variance::of(bootstrapOutputNoise(cggiParams, lweParams))); } void AndOp::inferResultNoise(llvm::ArrayRef argNoises, diff --git a/lib/Dialect/LWE/IR/LWEDialect.cpp b/lib/Dialect/LWE/IR/LWEDialect.cpp index edb967fb9..dc890fd91 100644 --- a/lib/Dialect/LWE/IR/LWEDialect.cpp +++ b/lib/Dialect/LWE/IR/LWEDialect.cpp @@ -184,18 +184,20 @@ LogicalResult TrivialEncryptOp::verify() { void AddOp::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { - Variance result; - if (!argNoises[0].isKnown() || !argNoises[1].isKnown()) - result = Variance::unknown(); - else - result = Variance{argNoises[0].getValue() + argNoises[1].getValue()}; - - return setValueNoise(getResult(), result); + if (!argNoises[0].isInitialized() || !argNoises[1].isInitialized()) { + emitOpError() << "uses SSA value with uninitialized noise variance."; + return setValueNoise(getResult(), Variance::unbounded()); + } + return setValueNoise( + getResult(), + (argNoises[0].isBounded() && argNoises[1].isBounded()) + ? Variance::of(argNoises[0].getValue() + argNoises[1].getValue()) + : Variance::unbounded()); } void TrivialEncryptOp::inferResultNoise(llvm::ArrayRef argNoises, SetNoiseFn setValueNoise) { - return setValueNoise(getResult(), Variance(0)); + return setValueNoise(getResult(), Variance::of(0)); } bool AddOp::hasArgumentIndependentResultNoise() { return false; } diff --git a/lib/Transforms/ValidateNoise/ValidateNoise.cpp b/lib/Transforms/ValidateNoise/ValidateNoise.cpp index a9de5cadf..23859fe64 100644 --- a/lib/Transforms/ValidateNoise/ValidateNoise.cpp +++ b/lib/Transforms/ValidateNoise/ValidateNoise.cpp @@ -60,16 +60,23 @@ struct ValidateNoise : impl::ValidateNoiseBase { "all."); return WalkResult::interrupt(); } + LLVM_DEBUG(op->emitRemark() << "Found noise " << (opRange->getValue()) << " for op result " << result.getResultNumber()); - // It's OK for some places to not know the noise, so long as the only + if (!opRange->getValue().isInitialized()) { + LLVM_DEBUG(llvm::dbgs() + << "Skipping check due to uninitialized noise.\n"); + return WalkResult::advance(); + } + + // It's OK for some places to have unbounded noise, so long as the only // user of that value is a bootstrap-like op. - if (!opRange->getValue().isKnown()) { + if (!opRange->getValue().isBounded()) { // One might expect a check for hasSingleUse, but there could - // potentially be multiple downstream users, each applying a different - // kind of programmable bootstrap to compute different functions, so - // we loop over all users. + // potentially be multiple downstream users, each applying a + // different kind of programmable bootstrap to compute different + // functions, so we loop over all users. for (auto result : op->getResults()) { for (Operation *user : result.getUsers()) { auto noisePropagationOp = @@ -83,7 +90,8 @@ struct ValidateNoise : impl::ValidateNoiseBase { if (noisePropagationOp && !noisePropagationOp.hasArgumentIndependentResultNoise()) { user->emitOpError() - << "uses SSA value with unknown noise variance, but the op " + << "uses SSA value with unbounded noise variance, but the " + "op " "has non-constant noise propagation. This can happen " "when an SSA value is part of control flow, such as a " "loop or an entrypoint to a function with multiple " diff --git a/tests/validate_noise/validate_noise.mlir b/tests/validate_noise/validate_noise.mlir index 912739ce3..32bdbf772 100644 --- a/tests/validate_noise/validate_noise.mlir +++ b/tests/validate_noise/validate_noise.mlir @@ -32,3 +32,22 @@ func.func @test_add_post_bootstrap(%0 : !ciphertext) -> !ciphertext { %2 = lwe.add %1, %1 : !ciphertext return %2 : !ciphertext } + +// CHECK-LABEL: @test_loop_result_with_zero_noise +func.func @test_loop_result_with_zero_noise() -> !ciphertext { + %0 = arith.constant 0 : i1 + %2 = lwe.encode %0 { encoding = #encoding }: i1 to !plaintext + %3 = lwe.trivial_encrypt %2 : !plaintext to !ciphertext + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c5 = arith.constant 5 : index + + // This is OK because the input noise is zero and stays zero through each + // iteration, so the data flow solver reaches a fixed point. + %5 = scf.for %arg1 = %c1 to %c5 step %c1 iter_args(%iter_arg = %3) -> !ciphertext { + %6 = lwe.add %iter_arg, %iter_arg : !ciphertext + scf.yield %6 : !ciphertext + } + return %5 : !ciphertext +} diff --git a/tests/validate_noise/validate_noise_errors.mlir b/tests/validate_noise/validate_noise_errors.mlir index 3ab773961..0964d699e 100644 --- a/tests/validate_noise/validate_noise_errors.mlir +++ b/tests/validate_noise/validate_noise_errors.mlir @@ -3,37 +3,38 @@ // TODO(https://github.com/google/heir/issues/296): use lwe.encrypt with // realistic initial noise. -// #encoding = #lwe.bit_field_encoding -// #poly = #polynomial.polynomial<1 + x**1024> -// !plaintext = !lwe.lwe_plaintext -// !ciphertext = !lwe.lwe_ciphertext -// -// func.func @test_cant_add_unknown_value(%arg0 : !ciphertext) -> !ciphertext { -// // expected-error@below {{uses SSA value with unknown noise variance}} -// %1 = lwe.add %arg0, %arg0 : !ciphertext -// return %1 : !ciphertext -// } -// -// // ----- - #encoding = #lwe.bit_field_encoding #poly = #polynomial.polynomial<1 + x**1024> !plaintext = !lwe.lwe_plaintext !ciphertext = !lwe.lwe_ciphertext -func.func @unknown_value_from_loop_result() -> !ciphertext { - %0 = arith.constant 0 : i1 - %2 = lwe.encode %0 { encoding = #encoding }: i1 to !plaintext - %3 = lwe.trivial_encrypt %2 : !plaintext to !ciphertext +func.func @test_cant_add_unknown_value(%arg0 : !ciphertext) -> !ciphertext { + // expected-error@below {{uses SSA value with uninitialized noise variance}} + %1 = lwe.add %arg0, %arg0 : !ciphertext + return %1 : !ciphertext +} - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c5 = arith.constant 5 : index +// ----- - %5 = scf.for %arg1 = %c1 to %c5 step %c1 iter_args(%iter_arg = %3) -> !ciphertext { - // expected-error@below {{uses SSA value with unknown noise variance}} - %6 = lwe.add %iter_arg, %iter_arg : !ciphertext - scf.yield %6 : !ciphertext - } - return %5 : !ciphertext -} +// #encoding = #lwe.bit_field_encoding +// #poly = #polynomial.polynomial<1 + x**1024> +// !plaintext = !lwe.lwe_plaintext +// !ciphertext = !lwe.lwe_ciphertext +// +// func.func @unknown_value_from_loop_result() -> !ciphertext { +// %0 = arith.constant 0 : i1 +// %1 = lwe.encode %0 { encoding = #encoding }: i1 to !plaintext +// %2 = lwe.trivial_encrypt %1 : !plaintext to !ciphertext +// %3 = cggi.and %2, %2 : !ciphertext +// +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %c5 = arith.constant 5 : index +// +// %5 = scf.for %arg1 = %c1 to %c5 step %c1 iter_args(%iter_arg = %3) -> !ciphertext { +// // expected-error@below {{uses SSA value with unbounded noise variance}} +// %6 = lwe.add %iter_arg, %iter_arg : !ciphertext +// scf.yield %6 : !ciphertext +// } +// return %5 : !ciphertext +// }