From 9afa3e8bb13971c14aec666c6463e5e55048e76e Mon Sep 17 00:00:00 2001 From: Zenithal Date: Fri, 24 Jan 2025 07:50:03 +0000 Subject: [PATCH] Add external debug port during execution --- .../Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp | 82 ++++++++++-- lib/Dialect/LWE/Transforms/AddDebugPort.cpp | 118 ++++++++++++++++++ lib/Dialect/LWE/Transforms/AddDebugPort.h | 17 +++ lib/Dialect/LWE/Transforms/BUILD | 18 +++ lib/Dialect/LWE/Transforms/Passes.h | 1 + lib/Dialect/LWE/Transforms/Passes.td | 33 +++++ .../ArithmeticPipelineRegistration.cpp | 8 ++ .../ArithmeticPipelineRegistration.h | 5 + lib/Pipelines/BUILD | 1 + lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp | 67 ++++++++-- lib/Target/OpenFhePke/OpenFhePkeEmitter.h | 4 + tests/Examples/openfhe/BUILD | 9 ++ .../openfhe/dot_product_8_debug_test.cpp | 47 +++++++ 13 files changed, 390 insertions(+), 20 deletions(-) create mode 100644 lib/Dialect/LWE/Transforms/AddDebugPort.cpp create mode 100644 lib/Dialect/LWE/Transforms/AddDebugPort.h create mode 100644 tests/Examples/openfhe/dot_product_8_debug_test.cpp diff --git a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp index fad6f933d8..96352659e7 100644 --- a/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp @@ -59,6 +59,15 @@ FailureOr getContextualCryptoContext(Operation *op) { return result.value(); } +// NOTE: we can not use containsDialect +// for FuncOp declaration, which does not have a body +template +bool containsArgumentOfDialect(func::FuncOp funcOp) { + return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) { + return DialectEqual()(&argType.getDialect()); + }); +} + struct AddCryptoContextArg : public OpConversionPattern { AddCryptoContextArg(mlir::MLIRContext *context) : OpConversionPattern(context, /* benefit= */ 2) {} @@ -68,8 +77,13 @@ struct AddCryptoContextArg : public OpConversionPattern { LogicalResult matchAndRewrite( func::FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!containsDialects( - op)) { + auto containsCryptoOps = + containsDialects( + op); + auto containsCryptoArg = + containsArgumentOfDialect(op); + if (!(containsCryptoOps || containsCryptoArg)) { return failure(); } @@ -86,15 +100,47 @@ struct AddCryptoContextArg : public OpConversionPattern { rewriter.modifyOpInPlace(op, [&] { op.setType(newFuncType); - Block &block = op.getBody().getBlocks().front(); - block.insertArgument(&block.getArguments().front(), cryptoContextType, - op.getLoc()); + // guard against private FuncOp (i.e. declaration) + if (op.getVisibility() != SymbolTable::Visibility::Private) { + Block &block = op.getBody().getBlocks().front(); + block.insertArgument(&block.getArguments().front(), cryptoContextType, + op.getLoc()); + } }); return success(); } }; +struct ConvertFuncCallOp : public OpConversionPattern { + ConvertFuncCallOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + func::CallOp op, typename func::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + FailureOr result = getContextualCryptoContext(op.getOperation()); + if (failed(result)) return result; + auto cryptoContext = result.value(); + + auto callee = op.getCallee(); + auto operands = adaptor.getOperands(); + auto resultTypes = op.getResultTypes(); + + SmallVector newOperands; + newOperands.push_back(cryptoContext); + for (auto operand : operands) { + newOperands.push_back(operand); + } + + rewriter.replaceOpWithNewOp(op, callee, resultTypes, + newOperands); + return success(); + } +}; + struct ConvertEncryptOp : public OpConversionPattern { ConvertEncryptOp(mlir::MLIRContext *context) : OpConversionPattern(context) {} @@ -275,11 +321,28 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase { bool hasCryptoContextArg = op.getFunctionType().getNumInputs() > 0 && mlir::isa( *op.getFunctionType().getInputs().begin()); + auto containsCryptoOps = + containsDialects( + op); + auto containsCryptoArg = + containsArgumentOfDialect(op); return typeConverter.isSignatureLegal(op.getFunctionType()) && typeConverter.isLegal(&op.getBody()) && - (!containsDialects(op) || - hasCryptoContextArg); + (!(containsCryptoOps || containsCryptoArg) || hasCryptoContextArg); + }); + + // Ensures that callee function signature is consistent + target.addDynamicallyLegalOp([&](func::CallOp op) { + auto operandTypes = op.getCalleeType().getInputs(); + auto containsCryptoArg = llvm::any_of(operandTypes, [&](Type argType) { + return DialectEqual()(&argType.getDialect()); + }); + auto hasCryptoContextArg = + !operandTypes.empty() && + mlir::isa(*operandTypes.begin()); + return (!containsCryptoArg || hasCryptoContextArg); }); patterns.add< @@ -290,6 +353,9 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase { // Update Func Op Signature AddCryptoContextArg, + // Update Func CallOp Signature + ConvertFuncCallOp, + // Handle LWE encode and en/decrypt // Note: `lwe.decode` is handled directly by the OpenFHE emitter ConvertEncodeOp, ConvertEncryptOp, ConvertDecryptOp, diff --git a/lib/Dialect/LWE/Transforms/AddDebugPort.cpp b/lib/Dialect/LWE/Transforms/AddDebugPort.cpp new file mode 100644 index 0000000000..46557fa878 --- /dev/null +++ b/lib/Dialect/LWE/Transforms/AddDebugPort.cpp @@ -0,0 +1,118 @@ +#include "lib/Dialect/LWE/Transforms/AddDebugPort.h" + +#include "lib/Dialect/LWE/IR/LWEOps.h" +#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace lwe { + +#define GEN_PASS_DEF_ADDDEBUGPORT +#include "lib/Dialect/LWE/Transforms/Passes.h.inc" + +FailureOr getPrivateKeyType(func::FuncOp op) { + const auto *type = llvm::find_if(op.getArgumentTypes(), [](Type type) { + return mlir::isa(type); + }); + + if (type == op.getArgumentTypes().end()) { + return op.emitError( + "Function does not have an argument of LWECiphertextType"); + } + + auto lweCiphertextType = cast(*type); + + auto lwePrivateKeyType = NewLWESecretKeyType::get( + op.getContext(), lweCiphertextType.getKey(), + lweCiphertextType.getCiphertextSpace().getRing()); + return lwePrivateKeyType; +} + +func::FuncOp getOrCreateExternalDebugFunc( + ModuleOp module, Type lwePrivateKeyType, + NewLWECiphertextType lweCiphertextType, + const DenseMap &typeToInt) { + std::string funcName = + "__heir_debug_" + std::to_string(typeToInt.at(lweCiphertextType)); + + auto *context = module.getContext(); + auto lookup = module.lookupSymbol(funcName); + if (lookup) return lookup; + + auto debugFuncType = + FunctionType::get(context, {lwePrivateKeyType, lweCiphertextType}, {}); + + ImplicitLocOpBuilder b = + ImplicitLocOpBuilder::atBlockBegin(module.getLoc(), module.getBody()); + auto funcOp = b.create(funcName, debugFuncType); + // required for external func call + funcOp.setPrivate(); + return funcOp; +} + +LogicalResult insertExternalCall(func::FuncOp op, Type lwePrivateKeyType) { + auto module = op->getParentOfType(); + + // map ciphertext type to unique int + DenseMap typeToInt; + + // implicit assumption the first argument is private key + auto privateKey = op.getArgument(0); + + ImplicitLocOpBuilder b = + ImplicitLocOpBuilder::atBlockBegin(module.getLoc(), module.getBody()); + op.walk([&](Operation *op) { + b.setInsertionPointAfter(op); + for (Value result : op->getResults()) { + Type resultType = result.getType(); + if (auto lweCiphertextType = dyn_cast(resultType)) { + // update typeToInt + if (!typeToInt.count(resultType)) { + typeToInt[resultType] = typeToInt.size(); + } + b.create( + getOrCreateExternalDebugFunc(module, lwePrivateKeyType, + lweCiphertextType, typeToInt), + ArrayRef{privateKey, result}); + } + } + return WalkResult::advance(); + }); + return success(); +} + +LogicalResult convertFunc(func::FuncOp op) { + auto type = getPrivateKeyType(op); + if (failed(type)) return failure(); + auto lwePrivateKeyType = type.value(); + + op.insertArgument(0, lwePrivateKeyType, nullptr, op.getLoc()); + if (failed(insertExternalCall(op, lwePrivateKeyType))) { + return failure(); + } + return success(); +} + +struct AddDebugPort : impl::AddDebugPortBase { + using AddDebugPortBase::AddDebugPortBase; + + void runOnOperation() override { + auto result = + getOperation()->walk([&](func::FuncOp op) { + if (op.getSymName() == entryFunction && failed(convertFunc(op))) { + op->emitError("Failed to add client interface for func"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) signalPassFailure(); + } +}; +} // namespace lwe +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/LWE/Transforms/AddDebugPort.h b/lib/Dialect/LWE/Transforms/AddDebugPort.h new file mode 100644 index 0000000000..860e279163 --- /dev/null +++ b/lib/Dialect/LWE/Transforms/AddDebugPort.h @@ -0,0 +1,17 @@ +#ifndef LIB_DIALECT_LWE_TRANSFORMS_ADDDEBUGPORT_H_ +#define LIB_DIALECT_LWE_TRANSFORMS_ADDDEBUGPORT_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace lwe { + +#define GEN_PASS_DECL_ADDDEBUGPORT +#include "lib/Dialect/LWE/Transforms/Passes.h.inc" + +} // namespace lwe +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_LWE_TRANSFORMS_ADDDEBUGPORT_H_ diff --git a/lib/Dialect/LWE/Transforms/BUILD b/lib/Dialect/LWE/Transforms/BUILD index b540ccaa49..59327f6e73 100644 --- a/lib/Dialect/LWE/Transforms/BUILD +++ b/lib/Dialect/LWE/Transforms/BUILD @@ -12,6 +12,7 @@ cc_library( ], deps = [ ":AddClientInterface", + ":AddDebugPort", ":SetDefaultParameters", ":pass_inc_gen", "@heir//lib/Dialect/LWE/IR:Dialect", @@ -35,6 +36,23 @@ cc_library( ], ) +cc_library( + name = "AddDebugPort", + srcs = ["AddDebugPort.cpp"], + hdrs = [ + "AddDebugPort.h", + ], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/LWE/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "SetDefaultParameters", srcs = ["SetDefaultParameters.cpp"], diff --git a/lib/Dialect/LWE/Transforms/Passes.h b/lib/Dialect/LWE/Transforms/Passes.h index 72527511d6..57e14c39ad 100644 --- a/lib/Dialect/LWE/Transforms/Passes.h +++ b/lib/Dialect/LWE/Transforms/Passes.h @@ -3,6 +3,7 @@ #include "lib/Dialect/LWE/IR/LWEDialect.h" #include "lib/Dialect/LWE/Transforms/AddClientInterface.h" +#include "lib/Dialect/LWE/Transforms/AddDebugPort.h" #include "lib/Dialect/LWE/Transforms/SetDefaultParameters.h" namespace mlir { diff --git a/lib/Dialect/LWE/Transforms/Passes.td b/lib/Dialect/LWE/Transforms/Passes.td index 51b9890337..9a6869c266 100644 --- a/lib/Dialect/LWE/Transforms/Passes.td +++ b/lib/Dialect/LWE/Transforms/Passes.td @@ -22,6 +22,39 @@ def AddClientInterface : Pass<"lwe-add-client-interface"> { ]; } +def AddDebugPort : Pass<"lwe-add-debug-port"> { + let summary = "Add debug port to (R)LWE encrypted functions"; + let description = [{ + This pass adds debug ports to the specified function in the IR. The debug ports + are prefixed with "__heir_debug" and are invoked after each homomorphic operation in the + function. The debug ports are declarations and user should provide functions with + the same name in their code. + + For example, if the function is called "foo", the secret key is added to its + arguments, and the debug port is called after each homomorphic operation: + ```mlir + // declaration of external debug function + func.func private @__heir_debug(%sk : !sk, %ct : !ct) + + // secret key added as function arg + func.func @foo(%sk : !sk, ...) { + %ct = lwe.radd ... + // invoke external debug function + __heir_debug(%sk, %ct) + %ct1 = lwe.rmul ... + __heir_debug(%sk, %ct1) + ... + } + ``` + }]; + let dependentDialects = ["mlir::heir::lwe::LWEDialect"]; + let options = [ + Option<"entryFunction", "entry-function", "std::string", + /*default=*/"", "Default entry function " + "name of entry function.">, + ]; +} + def SetDefaultParameters : Pass<"lwe-set-default-parameters"> { let summary = "Set default parameters for LWE ops"; let description = [{ diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index bf86400706..2115bc778c 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -8,6 +8,7 @@ #include "lib/Dialect/CKKS/Conversions/CKKSToLWE/CKKSToLWE.h" #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" #include "lib/Dialect/LWE/Transforms/AddClientInterface.h" +#include "lib/Dialect/LWE/Transforms/AddDebugPort.h" #include "lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.h" #include "lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.h" #include "lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.h" @@ -229,6 +230,13 @@ RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(const RLWEScheme scheme) { exit(EXIT_FAILURE); } + // insert debug handler calls + if (options.debug) { + lwe::AddDebugPortOptions addDebugPortOptions; + addDebugPortOptions.entryFunction = options.entryFunction; + pm.addPass(lwe::createAddDebugPort(addDebugPortOptions)); + } + // Convert to OpenFHE pm.addPass(lwe::createLWEToOpenfhe()); diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.h b/lib/Pipelines/ArithmeticPipelineRegistration.h index 06ba75b924..1300343832 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.h +++ b/lib/Pipelines/ArithmeticPipelineRegistration.h @@ -43,6 +43,11 @@ struct MlirToRLWEPipelineOptions llvm::cl::desc("Modulus switching right before the first multiplication " "(default to false)"), llvm::cl::init(false)}; + PassOptions::Option debug{ + *this, "insert-debug-handler-calls", + llvm::cl::desc("Insert function calls to an externally-defined debug " + "function (cf. --lwe-add-debug-port)"), + llvm::cl::init(false)}; }; using RLWEPipelineBuilder = diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index 0552bd70fb..28c1f9eaaa 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -91,6 +91,7 @@ cc_library( "@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe", "@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial", "@heir//lib/Dialect/LWE/Transforms:AddClientInterface", + "@heir//lib/Dialect/LWE/Transforms:AddDebugPort", "@heir//lib/Dialect/Lattigo/Transforms:ConfigureCryptoContext", "@heir//lib/Dialect/LinAlg/Conversions/LinalgToTensorExt", "@heir//lib/Dialect/Openfhe/Transforms:ConfigureCryptoContext", diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp index 162050d9b2..bce5f5a793 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.cpp @@ -80,7 +80,7 @@ LogicalResult OpenFhePkeEmitter::translate(Operation &op) { // Builtin ops .Case([&](auto op) { return printOperation(op); }) // Func ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) // Arith ops .Case 1) { return emitError(funcOp.getLoc(), llvm::formatv("Only functions with a single return type " "are supported, but this function has ", @@ -130,13 +137,17 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) { return failure(); } - Type result = funcOp.getResultTypes()[0]; - if (failed(emitType(result, funcOp->getLoc()))) { - return emitError(funcOp.getLoc(), - llvm::formatv("Failed to emit type {0}", result)); + if (funcOp.getNumResults() == 1) { + Type result = funcOp.getResultTypes()[0]; + if (failed(emitType(result, funcOp->getLoc()))) { + return emitError(funcOp.getLoc(), + llvm::formatv("Failed to emit type {0}", result)); + } + } else { + os << "void"; } - os << " " << funcOp.getName() << "("; + os << " " << canonicalizeDebugPort(funcOp.getName()) << "("; os.indent(); // Check the types without printing to enable failure outside of @@ -150,12 +161,27 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) { } } - os << commaSeparatedValues(funcOp.getArguments(), [&](Value value) { - return convertType(value.getType(), funcOp->getLoc()).value() + " " + - variableNames->getNameForValue(value); - }); + if (funcOp.getVisibility() == SymbolTable::Visibility::Private) { + // function declaration + os << commaSeparatedTypes(funcOp.getArgumentTypes(), [&](Type type) { + return convertType(type, funcOp->getLoc()).value(); + }); + } else { + os << commaSeparatedValues(funcOp.getArguments(), [&](Value value) { + return convertType(value.getType(), funcOp->getLoc()).value() + " " + + variableNames->getNameForValue(value); + }); + } os.unindent(); - os << ") {\n"; + os << ")"; + + // function declaration + if (funcOp.getVisibility() == SymbolTable::Visibility::Private) { + os << ";\n"; + return success(); + } + + os << " {\n"; os.indent(); for (Block &block : funcOp.getBlocks()) { @@ -171,6 +197,23 @@ LogicalResult OpenFhePkeEmitter::printOperation(func::FuncOp funcOp) { return success(); } +LogicalResult OpenFhePkeEmitter::printOperation(func::CallOp op) { + if (op.getNumResults() > 1) { + return emitError(op.getLoc(), "Only one return value supported"); + } + + if (op.getNumResults() != 0) { + os << variableNames->getNameForValue(op.getResult(0)) << " = "; + } + + os << canonicalizeDebugPort(op.getCallee()) << "("; + os << commaSeparatedValues(op.getOperands(), [&](Value value) { + return variableNames->getNameForValue(value); + }); + os << ");\n"; + return success(); +} + LogicalResult OpenFhePkeEmitter::printOperation(func::ReturnOp op) { if (op.getNumOperands() != 1) { return emitError(op.getLoc(), "Only one return value supported"); diff --git a/lib/Target/OpenFhePke/OpenFhePkeEmitter.h b/lib/Target/OpenFhePke/OpenFhePkeEmitter.h index 9a7f8ce29b..7999be1705 100644 --- a/lib/Target/OpenFhePke/OpenFhePkeEmitter.h +++ b/lib/Target/OpenFhePke/OpenFhePkeEmitter.h @@ -60,6 +60,7 @@ class OpenFhePkeEmitter { LogicalResult printOperation(::mlir::tensor::InsertOp op); LogicalResult printOperation(::mlir::tensor::SplatOp op); LogicalResult printOperation(::mlir::func::FuncOp op); + LogicalResult printOperation(::mlir::func::CallOp op); LogicalResult printOperation(::mlir::func::ReturnOp op); LogicalResult printOperation(::mlir::heir::lwe::RLWEDecodeOp op); LogicalResult printOperation( @@ -100,6 +101,9 @@ class OpenFhePkeEmitter { // Emit an OpenFhe type LogicalResult emitType(::mlir::Type type, ::mlir::Location loc); + // Canonicalize Debug Port + ::llvm::StringRef canonicalizeDebugPort(::llvm::StringRef debugPortName); + void emitAutoAssignPrefix(::mlir::Value result); LogicalResult emitTypedAssignPrefix(::mlir::Value result, ::mlir::Location loc); diff --git a/tests/Examples/openfhe/BUILD b/tests/Examples/openfhe/BUILD index b2116db73b..e2cc038247 100644 --- a/tests/Examples/openfhe/BUILD +++ b/tests/Examples/openfhe/BUILD @@ -32,6 +32,15 @@ openfhe_end_to_end_test( test_src = "dot_product_8_test.cpp", ) +openfhe_end_to_end_test( + name = "dot_product_8_debug_test", + generated_lib_header = "dot_product_8_debug_lib.h", + heir_opt_flags = ["--mlir-to-openfhe-bgv=entry-function=dot_product ciphertext-degree=8 insert-debug-handler-calls=true"], + mlir_src = "dot_product_8.mlir", + tags = ["notap"], + test_src = "dot_product_8_debug_test.cpp", +) + openfhe_end_to_end_test( name = "box_blur_64x64_test", generated_lib_header = "box_blur_64x64_lib.h", diff --git a/tests/Examples/openfhe/dot_product_8_debug_test.cpp b/tests/Examples/openfhe/dot_product_8_debug_test.cpp new file mode 100644 index 0000000000..85ba13b09a --- /dev/null +++ b/tests/Examples/openfhe/dot_product_8_debug_test.cpp @@ -0,0 +1,47 @@ +#include +#include + +#include "gtest/gtest.h" // from @googletest +#include "src/pke/include/openfhe.h" // from @openfhe + +// Generated headers (block clang-format from messing up order) +#include "tests/Examples/openfhe/dot_product_8_debug_lib.h" + +void __heir_debug(CryptoContextT cc, PrivateKeyT sk, CiphertextT ct) { + PlaintextT ptxt; + cc->Decrypt(sk, ct, &ptxt); + ptxt->SetLength(8); + std::cout << ptxt << std::endl; +} + +namespace mlir { +namespace heir { +namespace openfhe { + +TEST(DotProduct8Test, RunTest) { + auto cryptoContext = dot_product__generate_crypto_context(); + auto keyPair = cryptoContext->KeyGen(); + auto publicKey = keyPair.publicKey; + auto secretKey = keyPair.secretKey; + cryptoContext = + dot_product__configure_crypto_context(cryptoContext, secretKey); + + std::vector arg0 = {1, 2, 3, 4, 5, 6, 7, 8}; + std::vector arg1 = {2, 3, 4, 5, 6, 7, 8, 9}; + int64_t expected = 240; + + auto arg0Encrypted = + dot_product__encrypt__arg0(cryptoContext, arg0, publicKey); + auto arg1Encrypted = + dot_product__encrypt__arg1(cryptoContext, arg1, publicKey); + auto outputEncrypted = + dot_product(cryptoContext, secretKey, arg0Encrypted, arg1Encrypted); + auto actual = + dot_product__decrypt__result0(cryptoContext, outputEncrypted, secretKey); + + EXPECT_EQ(expected, actual); +} + +} // namespace openfhe +} // namespace heir +} // namespace mlir