Skip to content

Commit

Permalink
Merge pull request #271 from j2kun:cggi-to-tfhe-rs-2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584148844
  • Loading branch information
copybara-github committed Nov 20, 2023
2 parents 6291bfd + fd71c10 commit 041b41a
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 3 deletions.
1 change: 1 addition & 0 deletions include/Conversion/CGGIToTfheRust/CGGIToTfheRust.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include "mlir/Pass/PassBase.td"
def CGGIToTfheRust : Pass<"cggi-to-tfhe-rust"> {
let summary = "Lower `cggi` to `tfhe_rust` dialect.";
let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::cggi::CGGIDialect",
"mlir::heir::lwe::LWEDialect",
"mlir::heir::tfhe_rust::TfheRustDialect",
Expand Down
2 changes: 1 addition & 1 deletion include/Dialect/CGGI/IR/CGGIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class CGGI_BinaryGateOp<string mnemonic>
let results = (outs LWECiphertext:$output);
// Note: error: type of result #0, named 'output', is not buildable and a buildable type cannot be inferred
// LWECiphertext is not buildable?
let assemblyFormat = "operands attr-dict `->` qualified(type($output))" ;
let assemblyFormat = "operands attr-dict `:` qualified(type($output))" ;
}

def CGGI_AndOp : CGGI_BinaryGateOp<"and"> { let summary = "Logical AND of two ciphertexts."; }
Expand Down
12 changes: 12 additions & 0 deletions include/Dialect/TfheRust/IR/TfheRustOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def AddOp : TfheRust_Op<"add", [
let results = (outs TfheRust_CiphertextType:$output);
}

def SubOp : TfheRust_Op<"sub", [
Pure,
AllTypesMatch<["lhs", "rhs", "output"]>
]> {
let arguments = (ins
TfheRust_ServerKey:$serverKey,
TfheRust_CiphertextType:$lhs,
TfheRust_CiphertextType:$rhs
);
let results = (outs TfheRust_CiphertextType:$output);
}

def ApplyLookupTableOp : TfheRust_Op<"apply_lookup_table", [
Pure,
AllTypesMatch<["input", "output"]>
Expand Down
135 changes: 133 additions & 2 deletions lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ namespace mlir::heir {
#define GEN_PASS_DEF_CGGITOTFHERUST
#include "include/Conversion/CGGIToTfheRust/CGGIToTfheRust.h.inc"

constexpr int kBinaryGateLutWidth = 4;
constexpr int kAndLut = 8;
constexpr int kOrLut = 14;
constexpr int kXorLut = 6;

Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) {
// Only supporting unsigned types because the LWE dialect does not have a
// notion of signedness.
Expand Down Expand Up @@ -200,6 +205,131 @@ struct ConvertLut3Op : public OpConversionPattern<cggi::Lut3Op> {
}
};

struct ConvertLut2Op : public OpConversionPattern<cggi::Lut2Op> {
ConvertLut2Op(mlir::MLIRContext *context)
: OpConversionPattern<cggi::Lut2Op>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::Lut2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
FailureOr<Value> result = getContextualServerKey(op.getOperation());
if (failed(result)) return result;

Value serverKey = result.value();
// A followup -cse pass should combine repeated LUT generation ops.
auto lut = b.create<tfhe_rust::GenerateLookupTableOp>(
serverKey, adaptor.getLookupTable());
// Construct input = b << 1 + a
auto shiftedB = b.create<tfhe_rust::ScalarLeftShiftOp>(
serverKey, adaptor.getB(),
b.create<arith::ConstantOp>(b.getI8Type(), b.getI8IntegerAttr(1))
.getResult());
auto summedBA =
b.create<tfhe_rust::AddOp>(serverKey, shiftedB, adaptor.getA());

rewriter.replaceOp(
op, b.create<tfhe_rust::ApplyLookupTableOp>(serverKey, summedBA, lut));
return success();
}
};

LogicalResult replaceBinaryGate(Operation *op, Value lhs, Value rhs,
ConversionPatternRewriter &rewriter, int lut) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
FailureOr<Value> result = getContextualServerKey(op);
if (failed(result)) return result;

Value serverKey = result.value();
// A followup -cse pass should combine repeated LUT generation ops.
auto lookupTable = b.getIntegerAttr(
b.getIntegerType(kBinaryGateLutWidth, /*isSigned=*/false), lut);
auto lutOp =
b.create<tfhe_rust::GenerateLookupTableOp>(serverKey, lookupTable);
// Construct input = rhs << 1 + lhs
auto shiftedRhs = b.create<tfhe_rust::ScalarLeftShiftOp>(
serverKey, rhs,
b.create<arith::ConstantOp>(b.getI8Type(), b.getI8IntegerAttr(1))
.getResult());
auto input = b.create<tfhe_rust::AddOp>(serverKey, shiftedRhs, lhs);
rewriter.replaceOp(
op, b.create<tfhe_rust::ApplyLookupTableOp>(serverKey, input, lutOp));
return success();
}

struct ConvertAndOp : public OpConversionPattern<cggi::AndOp> {
ConvertAndOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::AndOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::AndOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return replaceBinaryGate(op.getOperation(), adaptor.getLhs(),
adaptor.getRhs(), rewriter, kAndLut);
}
};

struct ConvertOrOp : public OpConversionPattern<cggi::OrOp> {
ConvertOrOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::OrOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::OrOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return replaceBinaryGate(op.getOperation(), adaptor.getLhs(),
adaptor.getRhs(), rewriter, kOrLut);
}
};

struct ConvertXorOp : public OpConversionPattern<cggi::XorOp> {
ConvertXorOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::XorOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::XorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return replaceBinaryGate(op.getOperation(), adaptor.getLhs(),
adaptor.getRhs(), rewriter, kXorLut);
}
};

struct ConvertNotOp : public OpConversionPattern<cggi::NotOp> {
ConvertNotOp(mlir::MLIRContext *context)
: OpConversionPattern<cggi::NotOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
cggi::NotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
FailureOr<Value> result = getContextualServerKey(op);
if (failed(result)) return result;
Value serverKey = result.value();

auto width = widthFromEncodingAttr(op.getInput().getType().getEncoding());
auto cleartextType = b.getIntegerType(width);
auto outputType = encrytpedUIntTypeFromWidth(getContext(), width);
// not(x) == trivial_encryption(1) - x
auto createTrivialOp = rewriter.create<tfhe_rust::CreateTrivialOp>(
op.getLoc(), outputType, serverKey,
b.create<arith::ConstantOp>(cleartextType,
b.getIntegerAttr(cleartextType, 1))
.getResult());
rewriter.replaceOp(op, b.create<tfhe_rust::SubOp>(
serverKey, createTrivialOp, adaptor.getInput()));
return success();
}
};

struct ConvertTrivialEncryptOp
: public OpConversionPattern<lwe::TrivialEncryptOp> {
ConvertTrivialEncryptOp(mlir::MLIRContext *context)
Expand Down Expand Up @@ -275,8 +405,9 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase<CGGIToTfheRust> {

// FIXME: still need to update callers to insert the new server key arg, if
// needed and possible.
patterns.add<AddServerKeyArg, ConvertLut3Op, ConvertEncodeOp,
ConvertTrivialEncryptOp>(typeConverter, context);
patterns.add<AddServerKeyArg, ConvertAndOp, ConvertEncodeOp, ConvertLut2Op,
ConvertLut3Op, ConvertNotOp, ConvertOrOp,
ConvertTrivialEncryptOp, ConvertXorOp>(typeConverter, context);

if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
return signalPassFailure();
Expand Down
35 changes: 35 additions & 0 deletions tests/cggi_to_tfhe_rust/binary_gates.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// RUN: heir-opt --cggi-to-tfhe-rust -cse %s | FileCheck %s

#encoding = #lwe.unspecified_bit_field_encoding<cleartext_bitwidth = 3>
!ct_ty = !lwe.lwe_ciphertext<encoding = #encoding>
!pt_ty = !lwe.lwe_plaintext<encoding = #encoding>
// CHECK-LABEL: @binary_gates
// CHECK-SAME: %[[sks:.*]]: [[sks_ty:!tfhe_rust.server_key]], %[[arg1:.*]]: [[ct_ty:!tfhe_rust.eui3]], %[[arg2:.*]]: [[ct_ty]]
func.func @binary_gates(%arg1: !ct_ty, %arg2: !ct_ty) -> (!ct_ty) {
// CHECK: %[[v0:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 8 : ui4}
// CHECK: %[[shiftAmount:.*]] = arith.constant 1 : i8

// CHECK: %[[v1:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[arg2]], %[[shiftAmount]]
// CHECK: %[[v2:.*]] = tfhe_rust.add %[[sks]], %[[v1]], %[[arg1]]
// CHECK: %[[v3:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v2]], %[[v0]]
%0 = cggi.and %arg1, %arg2 : !ct_ty

// CHECK: %[[v4:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 14 : ui4}
// CHECK: %[[v5:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v2]], %[[v4]]
// (reuses shifted inputs from the AND)
%1 = cggi.or %arg1, %arg2 : !ct_ty

// CHECK: %[[notConst:.*]] = arith.constant 1 : i3
// CHECK: %[[v6:.*]] = tfhe_rust.create_trivial %[[sks]], %[[notConst]]
// CHECK: %[[v7:.*]] = tfhe_rust.sub %[[sks]], %[[v6]], %[[v5]]
%2 = cggi.not %1 : !ct_ty

// CHECK: %[[v8:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 6 : ui4}
// CHECK: %[[v9:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[v3]], %[[shiftAmount]]
// CHECK: %[[v10:.*]] = tfhe_rust.add %[[sks]], %[[v9]], %[[v7]]
// CHECK: %[[v11:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v10]], %[[v8]]
%3 = cggi.xor %2, %0 : !ct_ty

// CHECK: return %[[v11]]
return %3 : !ct_ty
}

0 comments on commit 041b41a

Please sign in to comment.