Skip to content

Commit

Permalink
Add external debug port during execution
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Jan 25, 2025
1 parent 3f984f7 commit 9afa3e8
Show file tree
Hide file tree
Showing 13 changed files with 390 additions and 20 deletions.
82 changes: 74 additions & 8 deletions lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ FailureOr<Value> getContextualCryptoContext(Operation *op) {
return result.value();
}

// NOTE: we can not use containsDialect
// for FuncOp declaration, which does not have a body
template <typename... Dialects>
bool containsArgumentOfDialect(func::FuncOp funcOp) {
return llvm::any_of(funcOp.getArgumentTypes(), [&](Type argType) {
return DialectEqual<Dialects...>()(&argType.getDialect());
});
}

struct AddCryptoContextArg : public OpConversionPattern<func::FuncOp> {
AddCryptoContextArg(mlir::MLIRContext *context)
: OpConversionPattern<func::FuncOp>(context, /* benefit= */ 2) {}
Expand All @@ -68,8 +77,13 @@ struct AddCryptoContextArg : public OpConversionPattern<func::FuncOp> {
LogicalResult matchAndRewrite(
func::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!containsDialects<lwe::LWEDialect, bgv::BGVDialect, ckks::CKKSDialect>(
op)) {
auto containsCryptoOps =
containsDialects<lwe::LWEDialect, bgv::BGVDialect, ckks::CKKSDialect>(
op);
auto containsCryptoArg =
containsArgumentOfDialect<lwe::LWEDialect, bgv::BGVDialect,
ckks::CKKSDialect>(op);
if (!(containsCryptoOps || containsCryptoArg)) {
return failure();
}

Expand All @@ -86,15 +100,47 @@ struct AddCryptoContextArg : public OpConversionPattern<func::FuncOp> {
rewriter.modifyOpInPlace(op, [&] {
op.setType(newFuncType);

Block &block = op.getBody().getBlocks().front();
block.insertArgument(&block.getArguments().front(), cryptoContextType,
op.getLoc());
// guard against private FuncOp (i.e. declaration)
if (op.getVisibility() != SymbolTable::Visibility::Private) {
Block &block = op.getBody().getBlocks().front();
block.insertArgument(&block.getArguments().front(), cryptoContextType,
op.getLoc());
}
});

return success();
}
};

struct ConvertFuncCallOp : public OpConversionPattern<func::CallOp> {
ConvertFuncCallOp(mlir::MLIRContext *context)
: OpConversionPattern<func::CallOp>(context) {}

using OpConversionPattern<func::CallOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
func::CallOp op, typename func::CallOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<Value> result = getContextualCryptoContext(op.getOperation());
if (failed(result)) return result;
auto cryptoContext = result.value();

auto callee = op.getCallee();
auto operands = adaptor.getOperands();
auto resultTypes = op.getResultTypes();

SmallVector<Value> newOperands;
newOperands.push_back(cryptoContext);
for (auto operand : operands) {
newOperands.push_back(operand);
}

rewriter.replaceOpWithNewOp<func::CallOp>(op, callee, resultTypes,
newOperands);
return success();
}
};

struct ConvertEncryptOp : public OpConversionPattern<lwe::RLWEEncryptOp> {
ConvertEncryptOp(mlir::MLIRContext *context)
: OpConversionPattern<lwe::RLWEEncryptOp>(context) {}
Expand Down Expand Up @@ -275,11 +321,28 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase<LWEToOpenfhe> {
bool hasCryptoContextArg = op.getFunctionType().getNumInputs() > 0 &&
mlir::isa<openfhe::CryptoContextType>(
*op.getFunctionType().getInputs().begin());
auto containsCryptoOps =
containsDialects<lwe::LWEDialect, bgv::BGVDialect, ckks::CKKSDialect>(
op);
auto containsCryptoArg =
containsArgumentOfDialect<lwe::LWEDialect, bgv::BGVDialect,
ckks::CKKSDialect>(op);
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody()) &&
(!containsDialects<lwe::LWEDialect, bgv::BGVDialect,
ckks::CKKSDialect>(op) ||
hasCryptoContextArg);
(!(containsCryptoOps || containsCryptoArg) || hasCryptoContextArg);
});

// Ensures that callee function signature is consistent
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
auto operandTypes = op.getCalleeType().getInputs();
auto containsCryptoArg = llvm::any_of(operandTypes, [&](Type argType) {
return DialectEqual<lwe::LWEDialect, bgv::BGVDialect,
ckks::CKKSDialect>()(&argType.getDialect());
});
auto hasCryptoContextArg =
!operandTypes.empty() &&
mlir::isa<openfhe::CryptoContextType>(*operandTypes.begin());
return (!containsCryptoArg || hasCryptoContextArg);
});

patterns.add<
Expand All @@ -290,6 +353,9 @@ struct LWEToOpenfhe : public impl::LWEToOpenfheBase<LWEToOpenfhe> {
// Update Func Op Signature
AddCryptoContextArg,

// Update Func CallOp Signature
ConvertFuncCallOp,

// Handle LWE encode and en/decrypt
// Note: `lwe.decode` is handled directly by the OpenFHE emitter
ConvertEncodeOp, ConvertEncryptOp, ConvertDecryptOp,
Expand Down
118 changes: 118 additions & 0 deletions lib/Dialect/LWE/Transforms/AddDebugPort.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#include "lib/Dialect/LWE/Transforms/AddDebugPort.h"

#include "lib/Dialect/LWE/IR/LWEOps.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace lwe {

#define GEN_PASS_DEF_ADDDEBUGPORT
#include "lib/Dialect/LWE/Transforms/Passes.h.inc"

FailureOr<Type> getPrivateKeyType(func::FuncOp op) {
const auto *type = llvm::find_if(op.getArgumentTypes(), [](Type type) {
return mlir::isa<NewLWECiphertextType>(type);
});

if (type == op.getArgumentTypes().end()) {
return op.emitError(
"Function does not have an argument of LWECiphertextType");
}

auto lweCiphertextType = cast<NewLWECiphertextType>(*type);

auto lwePrivateKeyType = NewLWESecretKeyType::get(
op.getContext(), lweCiphertextType.getKey(),
lweCiphertextType.getCiphertextSpace().getRing());
return lwePrivateKeyType;
}

func::FuncOp getOrCreateExternalDebugFunc(
ModuleOp module, Type lwePrivateKeyType,
NewLWECiphertextType lweCiphertextType,
const DenseMap<Type, int> &typeToInt) {
std::string funcName =
"__heir_debug_" + std::to_string(typeToInt.at(lweCiphertextType));

auto *context = module.getContext();
auto lookup = module.lookupSymbol<func::FuncOp>(funcName);
if (lookup) return lookup;

auto debugFuncType =
FunctionType::get(context, {lwePrivateKeyType, lweCiphertextType}, {});

ImplicitLocOpBuilder b =
ImplicitLocOpBuilder::atBlockBegin(module.getLoc(), module.getBody());
auto funcOp = b.create<func::FuncOp>(funcName, debugFuncType);
// required for external func call
funcOp.setPrivate();
return funcOp;
}

LogicalResult insertExternalCall(func::FuncOp op, Type lwePrivateKeyType) {
auto module = op->getParentOfType<ModuleOp>();

// map ciphertext type to unique int
DenseMap<Type, int> typeToInt;

// implicit assumption the first argument is private key
auto privateKey = op.getArgument(0);

ImplicitLocOpBuilder b =
ImplicitLocOpBuilder::atBlockBegin(module.getLoc(), module.getBody());
op.walk([&](Operation *op) {
b.setInsertionPointAfter(op);
for (Value result : op->getResults()) {
Type resultType = result.getType();
if (auto lweCiphertextType = dyn_cast<NewLWECiphertextType>(resultType)) {
// update typeToInt
if (!typeToInt.count(resultType)) {
typeToInt[resultType] = typeToInt.size();
}
b.create<func::CallOp>(
getOrCreateExternalDebugFunc(module, lwePrivateKeyType,
lweCiphertextType, typeToInt),
ArrayRef<Value>{privateKey, result});
}
}
return WalkResult::advance();
});
return success();
}

LogicalResult convertFunc(func::FuncOp op) {
auto type = getPrivateKeyType(op);
if (failed(type)) return failure();
auto lwePrivateKeyType = type.value();

op.insertArgument(0, lwePrivateKeyType, nullptr, op.getLoc());
if (failed(insertExternalCall(op, lwePrivateKeyType))) {
return failure();
}
return success();
}

struct AddDebugPort : impl::AddDebugPortBase<AddDebugPort> {
using AddDebugPortBase::AddDebugPortBase;

void runOnOperation() override {
auto result =
getOperation()->walk<WalkOrder::PreOrder>([&](func::FuncOp op) {
if (op.getSymName() == entryFunction && failed(convertFunc(op))) {
op->emitError("Failed to add client interface for func");
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (result.wasInterrupted()) signalPassFailure();
}
};
} // namespace lwe
} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions lib/Dialect/LWE/Transforms/AddDebugPort.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef LIB_DIALECT_LWE_TRANSFORMS_ADDDEBUGPORT_H_
#define LIB_DIALECT_LWE_TRANSFORMS_ADDDEBUGPORT_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace lwe {

#define GEN_PASS_DECL_ADDDEBUGPORT
#include "lib/Dialect/LWE/Transforms/Passes.h.inc"

} // namespace lwe
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_LWE_TRANSFORMS_ADDDEBUGPORT_H_
18 changes: 18 additions & 0 deletions lib/Dialect/LWE/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ cc_library(
],
deps = [
":AddClientInterface",
":AddDebugPort",
":SetDefaultParameters",
":pass_inc_gen",
"@heir//lib/Dialect/LWE/IR:Dialect",
Expand All @@ -35,6 +36,23 @@ cc_library(
],
)

cc_library(
name = "AddDebugPort",
srcs = ["AddDebugPort.cpp"],
hdrs = [
"AddDebugPort.h",
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)

cc_library(
name = "SetDefaultParameters",
srcs = ["SetDefaultParameters.cpp"],
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/LWE/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "lib/Dialect/LWE/IR/LWEDialect.h"
#include "lib/Dialect/LWE/Transforms/AddClientInterface.h"
#include "lib/Dialect/LWE/Transforms/AddDebugPort.h"
#include "lib/Dialect/LWE/Transforms/SetDefaultParameters.h"

namespace mlir {
Expand Down
33 changes: 33 additions & 0 deletions lib/Dialect/LWE/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,39 @@ def AddClientInterface : Pass<"lwe-add-client-interface"> {
];
}

def AddDebugPort : Pass<"lwe-add-debug-port"> {
let summary = "Add debug port to (R)LWE encrypted functions";
let description = [{
This pass adds debug ports to the specified function in the IR. The debug ports
are prefixed with "__heir_debug" and are invoked after each homomorphic operation in the
function. The debug ports are declarations and user should provide functions with
the same name in their code.

For example, if the function is called "foo", the secret key is added to its
arguments, and the debug port is called after each homomorphic operation:
```mlir
// declaration of external debug function
func.func private @__heir_debug(%sk : !sk, %ct : !ct)

// secret key added as function arg
func.func @foo(%sk : !sk, ...) {
%ct = lwe.radd ...
// invoke external debug function
__heir_debug(%sk, %ct)
%ct1 = lwe.rmul ...
__heir_debug(%sk, %ct1)
...
}
```
}];
let dependentDialects = ["mlir::heir::lwe::LWEDialect"];
let options = [
Option<"entryFunction", "entry-function", "std::string",
/*default=*/"", "Default entry function "
"name of entry function.">,
];
}

def SetDefaultParameters : Pass<"lwe-set-default-parameters"> {
let summary = "Set default parameters for LWE ops";
let description = [{
Expand Down
8 changes: 8 additions & 0 deletions lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "lib/Dialect/CKKS/Conversions/CKKSToLWE/CKKSToLWE.h"
#include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h"
#include "lib/Dialect/LWE/Transforms/AddClientInterface.h"
#include "lib/Dialect/LWE/Transforms/AddDebugPort.h"
#include "lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.h"
#include "lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.h"
#include "lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.h"
Expand Down Expand Up @@ -229,6 +230,13 @@ RLWEPipelineBuilder mlirToOpenFheRLWEPipelineBuilder(const RLWEScheme scheme) {
exit(EXIT_FAILURE);
}

// insert debug handler calls
if (options.debug) {
lwe::AddDebugPortOptions addDebugPortOptions;
addDebugPortOptions.entryFunction = options.entryFunction;
pm.addPass(lwe::createAddDebugPort(addDebugPortOptions));
}

// Convert to OpenFHE
pm.addPass(lwe::createLWEToOpenfhe());

Expand Down
5 changes: 5 additions & 0 deletions lib/Pipelines/ArithmeticPipelineRegistration.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ struct MlirToRLWEPipelineOptions
llvm::cl::desc("Modulus switching right before the first multiplication "
"(default to false)"),
llvm::cl::init(false)};
PassOptions::Option<bool> debug{
*this, "insert-debug-handler-calls",
llvm::cl::desc("Insert function calls to an externally-defined debug "
"function (cf. --lwe-add-debug-port)"),
llvm::cl::init(false)};
};

using RLWEPipelineBuilder =
Expand Down
1 change: 1 addition & 0 deletions lib/Pipelines/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ cc_library(
"@heir//lib/Dialect/LWE/Conversions/LWEToOpenfhe",
"@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial",
"@heir//lib/Dialect/LWE/Transforms:AddClientInterface",
"@heir//lib/Dialect/LWE/Transforms:AddDebugPort",
"@heir//lib/Dialect/Lattigo/Transforms:ConfigureCryptoContext",
"@heir//lib/Dialect/LinAlg/Conversions/LinalgToTensorExt",
"@heir//lib/Dialect/Openfhe/Transforms:ConfigureCryptoContext",
Expand Down
Loading

0 comments on commit 9afa3e8

Please sign in to comment.