Skip to content

Commit

Permalink
Merge pull request #251 from asraa:more-truth-table
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 582068988
  • Loading branch information
copybara-github committed Nov 13, 2023
2 parents 8bfa175 + 1dfd60e commit a5b965c
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 1 deletion.
127 changes: 126 additions & 1 deletion lib/Conversion/CombToCGGI/CombToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
#include "include/Dialect/CGGI/IR/CGGIOps.h"
#include "include/Dialect/Comb/IR/CombDialect.h"
#include "include/Dialect/Comb/IR/CombOps.h"
#include "include/Dialect/LWE/IR/LWEAttributes.h"
#include "include/Dialect/LWE/IR/LWEOps.h"
#include "include/Dialect/LWE/IR/LWETypes.h"
#include "include/Dialect/Secret/IR/SecretDialect.h"
#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/Secret/IR/SecretTypes.h"
#include "lib/Conversion/Utils.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
Expand Down Expand Up @@ -34,16 +37,138 @@ class SecretTypeConverter : public TypeConverter {
}
};

class SecretGenericOpTypeConversion
: public OpConversionPattern<secret::GenericOp> {
public:
using OpConversionPattern<secret::GenericOp>::OpConversionPattern;

LogicalResult matchAndRewrite(
secret::GenericOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Block *originalBlock = op->getBlock();
Block &opEntryBlock = op.getRegion().front();

secret::YieldOp yieldOp =
dyn_cast<secret::YieldOp>(op.getRegion().back().getTerminator());

// Split the parent block of the generic op, so that we have a
// clear insertion point for inlining.
Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));

// mergeBlocks does not replace the original block values with the inputs to
// secret.generic, so we manually replace them here. This lifts the internal
// plaintext integer values within the secret.generic body to their original
// secret values.
auto genericInputs = op.getInputs();
for (int i = 0; i < opEntryBlock.getNumArguments(); i++) {
rewriter.replaceAllUsesWith(opEntryBlock.getArgument(i),
genericInputs[i]);
}

// In addition to lifting the plaintext arguments, we also lift the output
// arguments to secrets. This is required for any truth tables that have
// secret inputs.
// For some reason, if this doesn't occur, the type conversion framework is
// unable to update the uses of converted truth table results.
rewriter.startRootUpdate(op);
opEntryBlock.walk([&](comb::TruthTableOp op) {
bool ciphertextArg =
std::any_of(op.getOperands().begin(), op.getOperands().end(),
[&](const Value &val) {
return isa<secret::SecretType>(val.getType());
});
if (ciphertextArg) {
op->getResults()[0].setType(lwe::LWECiphertextType::get(
getContext(),
lwe::UnspecifiedBitFieldEncodingAttr::get(
getContext(), op.getResult().getType().getWidth()),
lwe::LWEParamsAttr()));
}
});
rewriter.finalizeRootUpdate(op);

// Inline the secret.generic internal region, moving all of the operations
// to the parent region.
rewriter.inlineRegionBefore(op.getRegion(), newBlock);
rewriter.replaceOp(op, yieldOp->getOperands());
rewriter.mergeBlocks(&opEntryBlock, originalBlock, genericInputs);
rewriter.mergeBlocks(newBlock, originalBlock, {});

rewriter.eraseOp(yieldOp);
return success();
}
};

// ConvertTruthTableOp converts op arguments to trivially encoded LWE
// ciphertexts when at least one argument is an LWE ciphertext.
struct ConvertTruthTableOp : public OpConversionPattern<TruthTableOp> {
ConvertTruthTableOp(mlir::MLIRContext *context)
: OpConversionPattern<TruthTableOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
TruthTableOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (op->getNumOperands() != 3) {
op->emitError() << "expected 3 truth table arguments to lower to CGGI";
}

MLIRContext *ctx = getContext();
bool ciphertextArg =
std::any_of(adaptor.getOperands().begin(), adaptor.getOperands().end(),
[&](const Value &val) {
return isa<lwe::LWECiphertextType>(val.getType()) ||
isa<secret::SecretType>(val.getType());
});

SmallVector<mlir::Value, 4> lutInputs;
for (Value val : adaptor.getOperands()) {
auto integerTy = dyn_cast<IntegerType>(val.getType());
// If any of the arguments to the truth table are ciphertexts, we must
// encode and trivially encrypt the plaintext integers arguments.
if (ciphertextArg && integerTy) {
auto encoding = lwe::UnspecifiedBitFieldEncodingAttr::get(
ctx, integerTy.getWidth());
auto ptxtTy = lwe::LWEPlaintextType::get(ctx, encoding);
auto ctxtTy =
lwe::LWECiphertextType::get(ctx, encoding, lwe::LWEParamsAttr());

lutInputs.push_back(rewriter.create<lwe::TrivialEncryptOp>(
op.getLoc(), ctxtTy,
rewriter.create<lwe::EncodeOp>(op.getLoc(), ptxtTy, val, encoding),
lwe::LWEParamsAttr()));
} else {
lutInputs.push_back(val);
}
}

rewriter.replaceOp(op, rewriter.create<cggi::Lut3Op>(
op.getLoc(), lutInputs[0], lutInputs[1],
lutInputs[2], op.getLookupTable()));

return success();
}
};

struct CombToCGGI : public impl::CombToCGGIBase<CombToCGGI> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto *module = getOperation();
SecretTypeConverter typeConverter(context);

RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalOp<ModuleOp>();

RewritePatternSet patterns(context);
patterns.add<ConvertTruthTableOp>(typeConverter, context);
target.addIllegalOp<TruthTableOp>();

patterns.add<SecretGenericOpTypeConversion>(typeConverter,
patterns.getContext());
target.addDynamicallyLegalOp<secret::GenericOp>(
[&](secret::GenericOp op) { return typeConverter.isLegal(op); });

addStructuralConversionPatterns(typeConverter, patterns, target);

if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
Expand Down
38 changes: 38 additions & 0 deletions tests/comb_to_cggi/add_one.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: heir-opt --comb-to-cggi -cse %s | FileCheck %s

// TODO(https://github.com/google/heir/issues/244): The Yosys optimizer pass
// still needs functionality to split multi-bit inputs and outputs, this test
// was generated by manually performing the splitting on the output of
// tests/yosys_optimizer/add_one.mlir.

module {
// This function computes add_one to an i8 input split into bits and returns
// the resulting i8 split into bits.
// CHECK-LABEL: add_one
// CHECK-NOT: secret
func.func @add_one(%arg00: !secret.secret<i1>, %arg01: !secret.secret<i1>, %arg02: !secret.secret<i1>, %arg03: !secret.secret<i1>, %arg04: !secret.secret<i1>, %arg05: !secret.secret<i1>, %arg06: !secret.secret<i1>, %arg07: !secret.secret<i1>) -> (!secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>) {
// CHECK: [[FALSE:%.+]] = arith.constant false
// CHECK: [[ENCFALSE:%.+]] = lwe.encode [[FALSE]]
// CHECK: [[LWEFALSE:%.+]] = lwe.trivial_encrypt [[ENCFALSE]]
// CHECK-COUNT-11: cggi.lut3
%false = arith.constant false
%0:8 = secret.generic
ins(%false, %arg00, %arg01, %arg02, %arg03, %arg04, %arg05, %arg06, %arg07 : i1, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>) {
^bb0(%FALSE : i1, %ARG00: i1, %ARG01: i1, %ARG02: i1, %ARG03: i1, %ARG04: i1, %ARG05: i1, %ARG06: i1, %ARG07: i1) :
%2 = comb.truth_table %ARG00, %ARG01, %FALSE -> 6 : ui8
%3 = comb.truth_table %ARG00, %FALSE, %FALSE -> 1 : ui8
%5 = comb.truth_table %ARG00, %ARG01, %ARG02 -> 120 : ui8
%6 = comb.truth_table %ARG00, %ARG01, %ARG02 -> 128 : ui8
%8 = comb.truth_table %6, %ARG03, %FALSE -> 6 : ui8
%10 = comb.truth_table %6, %ARG03, %ARG04 -> 120 : ui8
%11 = comb.truth_table %6, %ARG03, %ARG04 -> 128 : ui8
%13 = comb.truth_table %11, %ARG05, %FALSE -> 6 : ui8
%15 = comb.truth_table %11, %ARG05, %ARG06 -> 120 : ui8
%16 = comb.truth_table %11, %ARG05, %ARG06 -> 128 : ui8
%18 = comb.truth_table %16, %ARG07, %FALSE -> 6 : ui8
secret.yield %18, %15, %13, %10, %8, %5, %2, %3 : i1, i1, i1, i1, i1, i1, i1, i1
} -> (!secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>)
// CHECK: return
func.return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>, !secret.secret<i1>
}
}
15 changes: 15 additions & 0 deletions tests/comb_to_cggi/secret_generic.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: heir-opt --comb-to-cggi %s | FileCheck %s

module {
// CHECK-NOT: secret
// CHECK: @truth_table([[ARG:%.*]]: [[LWET:!lwe.lwe_ciphertext<.*>]]) -> [[LWET]]
func.func @truth_table(%arg0: !secret.secret<i1>) -> !secret.secret<i1> {
%0 = secret.generic
ins(%arg0 : !secret.secret<i1>) {
^bb0(%ARG0: i1) :
secret.yield %ARG0 : i1
} -> (!secret.secret<i1>)
// CHECK: return [[ARG]] : [[LWET]]
func.return %0 : !secret.secret<i1>
}
}
41 changes: 41 additions & 0 deletions tests/comb_to_cggi/truth_table.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: heir-opt --split-input-file --comb-to-cggi --cse %s | FileCheck %s

// CHECK-NOT: secret
// CHECK: @truth_table_all_secret([[ARG:%.*]]: [[LWET:!lwe.lwe_ciphertext<.*>]]) -> [[LWET]]
func.func @truth_table_all_secret(%arg0: !secret.secret<i1>) -> !secret.secret<i1> {
// CHECK: [[VAL:%.+]] = cggi.lut3([[ARG]], [[ARG]], [[ARG]])
%0 = secret.generic
ins(%arg0: !secret.secret<i1>) {
^bb0(%ARG0: i1) :
%1 = comb.truth_table %ARG0, %ARG0, %ARG0 -> 6 : ui8
secret.yield %1 : i1
} -> (!secret.secret<i1>)
// CHECK: return [[VAL]] : [[LWET]]
func.return %0 : !secret.secret<i1>
}

// -----

// CHECK-NOT: secret
// CHECK: @truth_table_partial_secret([[ARG:%.*]]: [[LWET:!lwe.lwe_ciphertext<.*>]]) -> [[LWET]]
func.func @truth_table_partial_secret(%arg0: !secret.secret<i1>) -> !secret.secret<i1> {
// CHECK: [[FALSE:%.+]] = arith.constant false
%false = arith.constant false
// CHECK: [[TRUE:%.+]] = arith.constant true
%true = arith.constant true
// CHECK: [[ENCFALSE:%.+]] = lwe.encode [[FALSE]]
// CHECK: [[LWEFALSE:%.+]] = lwe.trivial_encrypt [[ENCFALSE]]
// CHECK: [[ENCTRUE:%.+]] = lwe.encode [[TRUE]]
// CHECK: [[LWETRUE:%.+]] = lwe.trivial_encrypt [[ENCTRUE]]
// CHECK: [[VAL1:%.+]] = cggi.lut3([[LWEFALSE]], [[LWETRUE]], [[ARG]])
// CHECK: [[VAL2:%.+]] = cggi.lut3([[LWEFALSE]], [[LWETRUE]], [[VAL1]])
%0 = secret.generic
ins(%false, %true, %arg0: i1, i1, !secret.secret<i1>) {
^bb0(%FALSE: i1, %TRUE: i1, %ARG0: i1) :
%1 = comb.truth_table %FALSE, %TRUE, %ARG0 -> 6 : ui8
%2 = comb.truth_table %FALSE, %TRUE, %1 -> 2 : ui8
secret.yield %2 : i1
} -> (!secret.secret<i1>)
// CHECK: return [[VAL2]] : [[LWET]]
func.return %0 : !secret.secret<i1>
}

0 comments on commit a5b965c

Please sign in to comment.