diff --git a/include/Conversion/CGGIToTfheRust/CGGIToTfheRust.td b/include/Conversion/CGGIToTfheRust/CGGIToTfheRust.td index 11220ebf0..b8b9739b1 100644 --- a/include/Conversion/CGGIToTfheRust/CGGIToTfheRust.td +++ b/include/Conversion/CGGIToTfheRust/CGGIToTfheRust.td @@ -6,6 +6,7 @@ include "mlir/Pass/PassBase.td" def CGGIToTfheRust : Pass<"cggi-to-tfhe-rust"> { let summary = "Lower `cggi` to `tfhe_rust` dialect."; let dependentDialects = [ + "mlir::arith::ArithDialect", "mlir::heir::cggi::CGGIDialect", "mlir::heir::lwe::LWEDialect", "mlir::heir::tfhe_rust::TfheRustDialect", diff --git a/include/Dialect/CGGI/IR/CGGIOps.td b/include/Dialect/CGGI/IR/CGGIOps.td index 1552ace33..66e374fda 100644 --- a/include/Dialect/CGGI/IR/CGGIOps.td +++ b/include/Dialect/CGGI/IR/CGGIOps.td @@ -27,7 +27,7 @@ class CGGI_BinaryGateOp let results = (outs LWECiphertext:$output); // Note: error: type of result #0, named 'output', is not buildable and a buildable type cannot be inferred // LWECiphertext is not buildable? - let assemblyFormat = "operands attr-dict `->` qualified(type($output))" ; + let assemblyFormat = "operands attr-dict `:` qualified(type($output))" ; } def CGGI_AndOp : CGGI_BinaryGateOp<"and"> { let summary = "Logical AND of two ciphertexts."; } diff --git a/include/Dialect/TfheRust/IR/TfheRustOps.td b/include/Dialect/TfheRust/IR/TfheRustOps.td index 68b965eff..77637a8c8 100644 --- a/include/Dialect/TfheRust/IR/TfheRustOps.td +++ b/include/Dialect/TfheRust/IR/TfheRustOps.td @@ -46,6 +46,18 @@ def AddOp : TfheRust_Op<"add", [ let results = (outs TfheRust_CiphertextType:$output); } +def SubOp : TfheRust_Op<"sub", [ + Pure, + AllTypesMatch<["lhs", "rhs", "output"]> +]> { + let arguments = (ins + TfheRust_ServerKey:$serverKey, + TfheRust_CiphertextType:$lhs, + TfheRust_CiphertextType:$rhs + ); + let results = (outs TfheRust_CiphertextType:$output); +} + def ApplyLookupTableOp : TfheRust_Op<"apply_lookup_table", [ Pure, AllTypesMatch<["input", "output"]> diff --git a/lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp index 5e6634f97..c7d4c921e 100644 --- a/lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Conversion/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -30,6 +30,11 @@ namespace mlir::heir { #define GEN_PASS_DEF_CGGITOTFHERUST #include "include/Conversion/CGGIToTfheRust/CGGIToTfheRust.h.inc" +constexpr int kBinaryGateLutWidth = 4; +constexpr int kAndLut = 8; +constexpr int kOrLut = 14; +constexpr int kXorLut = 6; + Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) { // Only supporting unsigned types because the LWE dialect does not have a // notion of signedness. @@ -200,6 +205,131 @@ struct ConvertLut3Op : public OpConversionPattern { } }; +struct ConvertLut2Op : public OpConversionPattern { + ConvertLut2Op(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::Lut2Op op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + FailureOr result = getContextualServerKey(op.getOperation()); + if (failed(result)) return result; + + Value serverKey = result.value(); + // A followup -cse pass should combine repeated LUT generation ops. + auto lut = b.create( + serverKey, adaptor.getLookupTable()); + // Construct input = b << 1 + a + auto shiftedB = b.create( + serverKey, adaptor.getB(), + b.create(b.getI8Type(), b.getI8IntegerAttr(1)) + .getResult()); + auto summedBA = + b.create(serverKey, shiftedB, adaptor.getA()); + + rewriter.replaceOp( + op, b.create(serverKey, summedBA, lut)); + return success(); + } +}; + +LogicalResult replaceBinaryGate(Operation *op, Value lhs, Value rhs, + ConversionPatternRewriter &rewriter, int lut) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + FailureOr result = getContextualServerKey(op); + if (failed(result)) return result; + + Value serverKey = result.value(); + // A followup -cse pass should combine repeated LUT generation ops. + auto lookupTable = b.getIntegerAttr( + b.getIntegerType(kBinaryGateLutWidth, /*isSigned=*/false), lut); + auto lutOp = + b.create(serverKey, lookupTable); + // Construct input = rhs << 1 + lhs + auto shiftedRhs = b.create( + serverKey, rhs, + b.create(b.getI8Type(), b.getI8IntegerAttr(1)) + .getResult()); + auto input = b.create(serverKey, shiftedRhs, lhs); + rewriter.replaceOp( + op, b.create(serverKey, input, lutOp)); + return success(); +} + +struct ConvertAndOp : public OpConversionPattern { + ConvertAndOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::AndOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return replaceBinaryGate(op.getOperation(), adaptor.getLhs(), + adaptor.getRhs(), rewriter, kAndLut); + } +}; + +struct ConvertOrOp : public OpConversionPattern { + ConvertOrOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::OrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return replaceBinaryGate(op.getOperation(), adaptor.getLhs(), + adaptor.getRhs(), rewriter, kOrLut); + } +}; + +struct ConvertXorOp : public OpConversionPattern { + ConvertXorOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::XorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return replaceBinaryGate(op.getOperation(), adaptor.getLhs(), + adaptor.getRhs(), rewriter, kXorLut); + } +}; + +struct ConvertNotOp : public OpConversionPattern { + ConvertNotOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + cggi::NotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + FailureOr result = getContextualServerKey(op); + if (failed(result)) return result; + Value serverKey = result.value(); + + auto width = widthFromEncodingAttr(op.getInput().getType().getEncoding()); + auto cleartextType = b.getIntegerType(width); + auto outputType = encrytpedUIntTypeFromWidth(getContext(), width); + // not(x) == trivial_encryption(1) - x + auto createTrivialOp = rewriter.create( + op.getLoc(), outputType, serverKey, + b.create(cleartextType, + b.getIntegerAttr(cleartextType, 1)) + .getResult()); + rewriter.replaceOp(op, b.create( + serverKey, createTrivialOp, adaptor.getInput())); + return success(); + } +}; + struct ConvertTrivialEncryptOp : public OpConversionPattern { ConvertTrivialEncryptOp(mlir::MLIRContext *context) @@ -275,8 +405,9 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase { // FIXME: still need to update callers to insert the new server key arg, if // needed and possible. - patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(op, target, std::move(patterns)))) { return signalPassFailure(); diff --git a/tests/cggi_to_tfhe_rust/binary_gates.mlir b/tests/cggi_to_tfhe_rust/binary_gates.mlir new file mode 100644 index 000000000..89c5e274f --- /dev/null +++ b/tests/cggi_to_tfhe_rust/binary_gates.mlir @@ -0,0 +1,35 @@ +// RUN: heir-opt --cggi-to-tfhe-rust -cse %s | FileCheck %s + +#encoding = #lwe.unspecified_bit_field_encoding +!ct_ty = !lwe.lwe_ciphertext +!pt_ty = !lwe.lwe_plaintext +// CHECK-LABEL: @binary_gates +// CHECK-SAME: %[[sks:.*]]: [[sks_ty:!tfhe_rust.server_key]], %[[arg1:.*]]: [[ct_ty:!tfhe_rust.eui3]], %[[arg2:.*]]: [[ct_ty]] +func.func @binary_gates(%arg1: !ct_ty, %arg2: !ct_ty) -> (!ct_ty) { + // CHECK: %[[v0:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 8 : ui4} + // CHECK: %[[shiftAmount:.*]] = arith.constant 1 : i8 + + // CHECK: %[[v1:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[arg2]], %[[shiftAmount]] + // CHECK: %[[v2:.*]] = tfhe_rust.add %[[sks]], %[[v1]], %[[arg1]] + // CHECK: %[[v3:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v2]], %[[v0]] + %0 = cggi.and %arg1, %arg2 : !ct_ty + + // CHECK: %[[v4:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 14 : ui4} + // CHECK: %[[v5:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v2]], %[[v4]] + // (reuses shifted inputs from the AND) + %1 = cggi.or %arg1, %arg2 : !ct_ty + + // CHECK: %[[notConst:.*]] = arith.constant 1 : i3 + // CHECK: %[[v6:.*]] = tfhe_rust.create_trivial %[[sks]], %[[notConst]] + // CHECK: %[[v7:.*]] = tfhe_rust.sub %[[sks]], %[[v6]], %[[v5]] + %2 = cggi.not %1 : !ct_ty + + // CHECK: %[[v8:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 6 : ui4} + // CHECK: %[[v9:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[v3]], %[[shiftAmount]] + // CHECK: %[[v10:.*]] = tfhe_rust.add %[[sks]], %[[v9]], %[[v7]] + // CHECK: %[[v11:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v10]], %[[v8]] + %3 = cggi.xor %2, %0 : !ct_ty + + // CHECK: return %[[v11]] + return %3 : !ct_ty +}