From 3d8e5e9bceac2e6073bff2f42dc4f92b7df59b12 Mon Sep 17 00:00:00 2001 From: WoutLegiest Date: Wed, 25 Dec 2024 01:30:54 +0000 Subject: [PATCH 1/2] Working Quart to Tfhe rs + Change tfhe-rs and cggi shift ops --- .../Conversions/ArithToCGGI/ArithToCGGI.cpp | 23 +- .../ArithToCGGIQuart/ArithToCGGIQuart.cpp | 221 ++++++++++++++---- .../CGGIToTfheRust/CGGIToTfheRust.cpp | 44 ++-- lib/Dialect/CGGI/IR/CGGIOps.td | 8 +- lib/Dialect/TfheRust/IR/TfheRustOps.td | 12 +- lib/Dialect/TfheRust/IR/TfheRustTypes.td | 1 + lib/Target/TfheRust/TfheRustEmitter.cpp | 20 +- lib/Target/TfheRustHL/TfheRustHLEmitter.cpp | 8 +- .../ArithToCGGI/arith-to-cggi.mlir | 2 + .../ArithToCGGIQuart/quarter_wide.mlir | 12 +- .../Conversions/cggi_to_tfhe_rust/arith.mlir | 3 +- .../cggi_to_tfhe_rust/binary_gates.mlir | 6 +- .../TfheRust/Emitters/emit_levelled_ops.mlir | 5 +- .../TfheRust/Emitters/emit_tfhe_rust.mlir | 6 +- tests/Dialect/TfheRust/IR/ops.mlir | 3 +- .../TfheRust/Transforms/canonicalize.mlir | 10 +- .../forward_add_one.mlir | 32 +-- .../loop_unroll/full_loop_unroll.mlir | 3 +- 18 files changed, 269 insertions(+), 150 deletions(-) diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp index 8f728ee0e..d1509799c 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp @@ -139,12 +139,10 @@ struct ConvertShRUIOp : public OpConversionPattern { .getSExtValue(); auto inputValue = - mlir::IntegerAttr::get(rewriter.getI8Type(), (int8_t)shiftAmount); - auto cteOp = rewriter.create( - op.getLoc(), rewriter.getI8Type(), inputValue); + mlir::IntegerAttr::get(rewriter.getIndexType(), (int8_t)shiftAmount); - auto shiftOp = - b.create(outputType, adaptor.getLhs(), cteOp); + auto shiftOp = b.create( + outputType, adaptor.getLhs(), inputValue); rewriter.replaceOp(op, shiftOp); return success(); @@ -157,14 +155,12 @@ struct ConvertShRUIOp : public OpConversionPattern { auto shiftAmount = cast(cteShiftSizeOp.getValue()).getValue().getSExtValue(); - auto inputValue = mlir::IntegerAttr::get(rewriter.getI8Type(), shiftAmount); - auto cteOp = rewriter.create( - op.getLoc(), rewriter.getI8Type(), inputValue); + auto inputValue = + mlir::IntegerAttr::get(rewriter.getIndexType(), shiftAmount); - auto shiftOp = - b.create(outputType, adaptor.getLhs(), cteOp); + auto shiftOp = b.create( + outputType, adaptor.getLhs(), inputValue); rewriter.replaceOp(op, shiftOp); - rewriter.replaceOp(op.getLhs().getDefiningOp(), cteOp); return success(); } @@ -184,10 +180,7 @@ struct ArithToCGGI : public impl::ArithToCGGIBase { target.addDynamicallyLegalOp( [](mlir::arith::ConstantOp op) { // Allow use of constant if it is used to denote the size of a shift - bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) { - return isa(user); - }); - return (isa(op.getValue().getType()) || (usedByShift)); + return (isa(op.getValue().getType())); }); target.addDynamicallyLegalOp< diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp index 343bde13b..ae7adc819 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp @@ -1,9 +1,5 @@ #include "lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.h" -#include - -#include - #include "lib/Dialect/CGGI/IR/CGGIDialect.h" #include "lib/Dialect/CGGI/IR/CGGIOps.h" #include "lib/Dialect/LWE/IR/LWEOps.h" @@ -15,7 +11,9 @@ #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project namespace mlir::heir::arith { @@ -94,7 +92,7 @@ class ArithToCGGIQuartTypeConverter : public TypeConverter { }; static Value createTrivialOpMaxWidth(ImplicitLocOpBuilder b, int value) { - auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth >> 1); + auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth); auto intAttr = b.getIntegerAttr(maxWideIntType, value); auto encoding = @@ -153,19 +151,16 @@ static SmallVector extractLastDimHalves( static Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, int64_t value) { - unsigned elementBitWidth = 0; - if (auto lweTy = dyn_cast(type)) - elementBitWidth = - cast(lweTy.getEncoding()) - .getCleartextBitwidth(); - else - elementBitWidth = maxIntWidth; + // unsigned elementBitWidth = 0; + // if (auto lweTy = dyn_cast(type)) + // elementBitWidth = + // cast(lweTy.getEncoding()) + // .getCleartextBitwidth(); + // else + // elementBitWidth = maxIntWidth; - auto apValue = APInt(elementBitWidth, value); - - auto maxWideIntType = - IntegerType::get(builder.getContext(), maxIntWidth >> 1); - auto intAttr = builder.getIntegerAttr(maxWideIntType, value); + auto intAttr = builder.getIntegerAttr( + IntegerType::get(builder.getContext(), maxIntWidth), value); return builder.create(loc, type, intAttr); } @@ -249,6 +244,40 @@ struct ConvertQuartConstantOp } }; +struct ConvertQuartTruncIOp + : public OpConversionPattern { + ConvertQuartTruncIOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::TruncIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto newResultTy = getTypeConverter()->convertType( + op.getResult().getType()); + auto newInTy = + getTypeConverter()->convertType(op.getIn().getType()); + + SmallVector offsets(newResultTy.getShape().size(), + rewriter.getIndexAttr(0)); + offsets.back() = rewriter.getIndexAttr(newInTy.getShape().back() - + newResultTy.getShape().back()); + SmallVector sizes(newResultTy.getShape().size()); + sizes.back() = rewriter.getIndexAttr(1); + SmallVector strides(newResultTy.getShape().size(), + rewriter.getIndexAttr(1)); + + auto resOp = rewriter.create( + op->getLoc(), adaptor.getIn(), offsets, sizes, strides); + rewriter.replaceOp(op, resOp); + + return success(); + } +}; + template struct ConvertQuartExt final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -274,23 +303,21 @@ struct ConvertQuartExt final : OpConversionPattern { auto resultChunks = newResultTy.getShape().back(); auto inChunks = newInTy.getShape().back(); - if (resultChunks > inChunks) { - auto paddingFactor = resultChunks - inChunks; + // Through definition of ExtOp, paddingFactor is always positive + auto paddingFactor = resultChunks - inChunks; - SmallVector low, high; - low.push_back(rewriter.getIndexAttr(0)); - high.push_back(rewriter.getIndexAttr(paddingFactor)); + SmallVector low, high; + low.push_back(rewriter.getIndexAttr(0)); + high.push_back(rewriter.getIndexAttr(paddingFactor)); - auto padValue = createTrivialOpMaxWidth(b, 0); + auto padValue = createTrivialOpMaxWidth(b, 0); - auto resultVec = b.create(newResultTy, adaptor.getIn(), - low, high, padValue, - /*nofold=*/true); + auto resultVec = b.create(newResultTy, adaptor.getIn(), low, + high, padValue, + /*nofold=*/true); - rewriter.replaceOp(op, resultVec); - return success(); - } - return failure(); + rewriter.replaceOp(op, resultVec); + return success(); } }; @@ -318,14 +345,15 @@ struct ConvertQuartAddI final : OpConversionPattern { // Actual type of the underlying elements; we use half the width. // Create Constant - auto intAttr = IntegerAttr::get(rewriter.getI8Type(), maxIntWidth >> 1); + auto shiftAttr = + IntegerAttr::get(rewriter.getIndexType(), maxIntWidth >> 1); auto elemType = convertArithToCGGIType( IntegerType::get(op->getContext(), maxIntWidth), op->getContext()); auto realTy = convertArithToCGGIType( IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext()); - auto constantOp = b.create(intAttr); + // auto constantOp = b.create(intAttr); SmallVector carries; SmallVector outputs; @@ -338,7 +366,8 @@ struct ConvertQuartAddI final : OpConversionPattern { // Now all the outputs are 16b elements, wants presentation of 4x8b if (i != splitLhs.size() - 1) { - auto carry = b.create(elemType, lowSum, constantOp); + auto carry = + b.create(elemType, lowSum, shiftAttr); carries.push_back(carry); } @@ -356,6 +385,103 @@ struct ConvertQuartAddI final : OpConversionPattern { } }; +// Implemented using the Karatsuba algorithm +// https://en.wikipedia.org/wiki/Karatsuba_algorithm#Algorithm +struct ConvertQuartMulI final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::MulIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + + auto newTy = + getTypeConverter()->convertType(op.getType()); + if (!newTy) + return rewriter.notifyMatchFailure( + loc, llvm::formatv("unsupported type: {0}", op.getType())); + if (newTy.getShape().back() != 4) + return rewriter.notifyMatchFailure( + loc, llvm::formatv("Mul only support 4 split elements. Shape: {0}", + newTy)); + + auto elemTy = convertArithToCGGIType( + IntegerType::get(op->getContext(), maxIntWidth), op->getContext()); + auto realTy = convertArithToCGGIType( + IntegerType::get(op->getContext(), maxIntWidth >> 1), op->getContext()); + + // Create Constant + auto shiftAttr = + rewriter.getIntegerAttr(b.getIndexType(), maxIntWidth >> 1); + + SmallVector splitLhs = + extractLastDimHalves(rewriter, loc, adaptor.getLhs()); + SmallVector splitRhs = + extractLastDimHalves(rewriter, loc, adaptor.getRhs()); + + // TODO: Implement the real Karatsuba algorithm for 4x4 multiplication. + // First part of Karatsuba algorithm + auto z00 = b.create(splitLhs[0], splitRhs[0]); + auto z02 = b.create(splitLhs[1], splitRhs[1]); + auto z01_p1 = b.create(splitLhs[0], splitLhs[1]); + auto z01_p2 = b.create(splitRhs[0], splitRhs[1]); + auto z01_m = b.create(z01_p1, z01_p2); + auto z01_s = b.create(z01_m, z00); + auto z01 = b.create(z01_s, z02); + + // Second part I of Karatsuba algorithm + auto z1a0 = b.create(splitLhs[0], splitRhs[2]); + auto z1a2 = b.create(splitLhs[1], splitRhs[3]); + auto z1a1_p1 = b.create(splitLhs[0], splitLhs[1]); + auto z1a1_p2 = b.create(splitRhs[2], splitRhs[3]); + auto z1a1_m = b.create(z1a1_p1, z1a1_p2); + auto z1a1_s = b.create(z1a1_m, z1a0); + auto z1a1 = b.create(z1a1_s, z1a2); + + // Second part II of Karatsuba algorithm + auto z1b0 = b.create(splitLhs[2], splitRhs[0]); + auto z1b2 = b.create(splitLhs[3], splitRhs[1]); + auto z1b1_p1 = b.create(splitLhs[2], splitLhs[3]); + auto z1b1_p2 = b.create(splitRhs[0], splitRhs[1]); + auto z1b1_m = b.create(z1b1_p1, z1b1_p2); + auto z1b1_s = b.create(z1b1_m, z1b0); + auto z1b1 = b.create(z1b1_s, z1b2); + + auto out2Kara = b.create(z1a0, z1b0); + auto out2Carry = b.create(out2Kara, z02); + auto out3Carry = b.create(z1a1, z1b1); + + // Output are now all 16b elements, wants presentation of 4x8b + auto output0Lsb = b.create(realTy, z00); + auto output0LsbHigh = b.create(elemTy, output0Lsb); + auto output0Msb = + b.create(elemTy, z00, shiftAttr); + + auto output1Lsb = b.create(realTy, z01); + auto output1LsbHigh = b.create(elemTy, output1Lsb); + auto output1Msb = + b.create(elemTy, z01, shiftAttr); + + auto output2Lsb = b.create(realTy, out2Carry); + auto output2LsbHigh = b.create(elemTy, output2Lsb); + auto output2Msb = + b.create(elemTy, out2Carry, shiftAttr); + + auto output3Lsb = b.create(realTy, out3Carry); + auto output3LsbHigh = b.create(elemTy, output3Lsb); + + auto output1 = b.create(output1LsbHigh, output0Msb); + auto output2 = b.create(output2LsbHigh, output1Msb); + auto output3 = b.create(output3LsbHigh, output2Msb); + + Value resultVec = constructResultTensor( + rewriter, loc, newTy, {output0LsbHigh, output1, output2, output3}); + rewriter.replaceOp(op, resultVec); + return success(); + } +}; + struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase { void runOnOperation() override { MLIRContext *context = &getContext(); @@ -386,28 +512,29 @@ struct ArithToCGGIQuart : public impl::ArithToCGGIQuartBase { target.addDynamicallyLegalOp( [](mlir::arith::ConstantOp op) { - // Allow use of constant if it is used to denote the size of a shift - bool usedByShift = llvm::any_of(op->getUsers(), [&](Operation *user) { - return isa(user); - }); - return (isa(op.getValue().getType()) || (usedByShift)); + return isa(op.getValue().getType()); }); - patterns.add< - ConvertQuartConstantOp, ConvertQuartExt, - ConvertQuartExt, ConvertQuartAddI, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny, - ConvertAny, ConvertAny>( - typeConverter, context); + patterns + .add, + ConvertQuartExt, ConvertQuartAddI, + ConvertQuartMulI, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny, ConvertAny, + ConvertAny>(typeConverter, context); addStructuralConversionPatterns(typeConverter, patterns, target); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { return signalPassFailure(); } + + // Remove the uncessary tensor ops between each converted arith operation. + OpPassManager pipeline("builtin.module"); + pipeline.addPass(createCSEPass()); + (void)runPipeline(pipeline, getOperation()); } }; diff --git a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp index a43d3a237..9705e35bc 100644 --- a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -215,13 +215,9 @@ struct ConvertLut3Op : public OpConversionPattern { serverKey, adaptor.getLookupTable()); // Construct input = c << 2 + b << 1 + a auto shiftedC = b.create( - serverKey, adaptor.getC(), - b.create(b.getI8Type(), b.getI8IntegerAttr(2)) - .getResult()); + serverKey, adaptor.getC(), b.getIndexAttr(2)); auto shiftedB = b.create( - serverKey, adaptor.getB(), - b.create(b.getI8Type(), b.getI8IntegerAttr(1)) - .getResult()); + serverKey, adaptor.getB(), b.getIndexAttr(1)); auto summedBC = b.create(serverKey, shiftedC, shiftedB); auto summedABC = b.create(serverKey, summedBC, adaptor.getA()); @@ -251,9 +247,7 @@ struct ConvertLut2Op : public OpConversionPattern { serverKey, adaptor.getLookupTable()); // Construct input = b << 1 + a auto shiftedB = b.create( - serverKey, adaptor.getB(), - b.create(b.getI8Type(), b.getI8IntegerAttr(1)) - .getResult()); + serverKey, adaptor.getB(), b.getIndexAttr(1)); auto summedBA = b.create(serverKey, shiftedB, adaptor.getA()); @@ -277,10 +271,8 @@ static LogicalResult replaceBinaryGate(Operation *op, Value lhs, Value rhs, auto lutOp = b.create(serverKey, lookupTable); // Construct input = rhs << 1 + lhs - auto shiftedRhs = b.create( - serverKey, rhs, - b.create(b.getI8Type(), b.getI8IntegerAttr(1)) - .getResult()); + auto shiftedRhs = + b.create(serverKey, rhs, b.getIndexAttr(1)); auto input = b.create(serverKey, shiftedRhs, lhs); rewriter.replaceOp( op, b.create(serverKey, input, lutOp)); @@ -348,14 +340,14 @@ struct ConvertXorOp : public OpConversionPattern { } }; -struct ConvertShROp : public OpConversionPattern { +struct ConvertShROp : public OpConversionPattern { ConvertShROp(mlir::MLIRContext *context) - : OpConversionPattern(context) {} + : OpConversionPattern(context) {} using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - cggi::ShiftRightOp op, OpAdaptor adaptor, + cggi::ScalarShiftRightOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); FailureOr result = getContextualServerKey(op); @@ -479,10 +471,17 @@ struct ConvertTrivialOp : public OpConversionPattern { auto constantWidth = op.getValue().getValue().getBitWidth(); auto cteOp = rewriter.create( - op.getLoc(), rewriter.getIntegerType(constantWidth), inputValue); + op.getLoc(), op.getValue().getType(), inputValue); auto outputType = encrytpedUIntTypeFromWidth(getContext(), constantWidth); + if (isa(op.getResult().getType())) { + auto elemOutputType = + encrytpedUIntTypeFromWidth(getContext(), constantWidth); + auto shape = cast(op.getResult().getType()).getShape(); + outputType = RankedTensorType::get(shape, elemOutputType); + } + auto createTrivialOp = rewriter.create( op.getLoc(), outputType, serverKey, cteOp); rewriter.replaceOp(op, createTrivialOp); @@ -538,11 +537,11 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase { target.addDynamicallyLegalOp< memref::AllocOp, memref::DeallocOp, memref::StoreOp, memref::LoadOp, memref::SubViewOp, memref::CopyOp, affine::AffineLoadOp, - affine::AffineStoreOp, tensor::FromElementsOp, tensor::ExtractOp>( - [&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()) && - typeConverter.isLegal(op->getResultTypes()); - }); + tensor::InsertOp, tensor::InsertSliceOp, affine::AffineStoreOp, + tensor::FromElementsOp, tensor::ExtractOp>([&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); // FIXME: still need to update callers to insert the new server key arg, if // needed and possible. @@ -556,6 +555,7 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase { ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, + ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny, ConvertAny>( typeConverter, context); diff --git a/lib/Dialect/CGGI/IR/CGGIOps.td b/lib/Dialect/CGGI/IR/CGGIOps.td index c9c7e38d8..f1ccc04b9 100644 --- a/lib/Dialect/CGGI/IR/CGGIOps.td +++ b/lib/Dialect/CGGI/IR/CGGIOps.td @@ -295,18 +295,18 @@ def CGGI_SubOp : CGGI_Op<"sub", [ } -def CGGI_ShiftRightOp : CGGI_Op<"shr", [ +def CGGI_ScalarShiftRightOp : CGGI_Op<"sshr", [ Pure, ]> { - let arguments = (ins LWECiphertextLike:$lhs, AnyI8:$shiftAmount); + let arguments = (ins LWECiphertextLike:$lhs, IndexAttr:$shiftAmount); let results = (outs LWECiphertextLike:$output); let summary = "Arithmetic shift to the right of a ciphertext by an integer. Note this operations to mirror the TFHE-rs implmementation."; } -def CGGI_ShiftLeftOp : CGGI_Op<"shl", [ +def CGGI_ScalarShiftLeftOp : CGGI_Op<"sshl", [ Pure ]> { - let arguments = (ins LWECiphertextLike:$lhs, AnyI8:$shiftAmount); + let arguments = (ins LWECiphertextLike:$lhs, IndexAttr:$shiftAmount); let results = (outs LWECiphertextLike:$output); let summary = "Arithmetic shift to left of a ciphertext by an integer. Note this operations to mirror the TFHE-rs implmementation."; } diff --git a/lib/Dialect/TfheRust/IR/TfheRustOps.td b/lib/Dialect/TfheRust/IR/TfheRustOps.td index 43c271462..9c28875bf 100644 --- a/lib/Dialect/TfheRust/IR/TfheRustOps.td +++ b/lib/Dialect/TfheRust/IR/TfheRustOps.td @@ -24,15 +24,15 @@ class TfheRust_BinaryOp ]> { let arguments = (ins TfheRust_ServerKey:$serverKey, - TfheRust_CiphertextType:$lhs, - TfheRust_CiphertextType:$rhs + TfheRust_CiphertextLikeType:$lhs, + TfheRust_CiphertextLikeType:$rhs ); - let results = (outs TfheRust_CiphertextType:$output); + let results = (outs TfheRust_CiphertextLikeType:$output); } def TfheRust_CreateTrivialOp : TfheRust_Op<"create_trivial", [Pure]> { let arguments = (ins TfheRust_ServerKey:$serverKey, AnyInteger:$value); - let results = (outs TfheRust_CiphertextType:$output); + let results = (outs TfheRust_CiphertextLikeType:$output); let hasCanonicalizer = 1; } @@ -49,7 +49,7 @@ def TfheRust_ScalarLeftShiftOp : TfheRust_Op<"scalar_left_shift", [ let arguments = (ins TfheRust_ServerKey:$serverKey, TfheRust_CiphertextType:$ciphertext, - AnyI8:$shiftAmount + IndexAttr:$shiftAmount ); let results = (outs TfheRust_CiphertextType:$output); } @@ -61,7 +61,7 @@ def TfheRust_ScalarRightShiftOp : TfheRust_Op<"scalar_right_shift", [ let arguments = (ins TfheRust_ServerKey:$serverKey, TfheRust_CiphertextType:$ciphertext, - AnyI8:$shiftAmount + IndexAttr:$shiftAmount ); let results = (outs TfheRust_CiphertextType:$output); } diff --git a/lib/Dialect/TfheRust/IR/TfheRustTypes.td b/lib/Dialect/TfheRust/IR/TfheRustTypes.td index 6ab2f6a6b..53bec97dd 100644 --- a/lib/Dialect/TfheRust/IR/TfheRustTypes.td +++ b/lib/Dialect/TfheRust/IR/TfheRustTypes.td @@ -65,6 +65,7 @@ def TfheRust_CiphertextType : TfheRust_EncryptedInt256, ]>; +def TfheRust_CiphertextLikeType : TypeOrContainer; def TfheRust_ServerKey : TfheRust_Type<"ServerKey", "server_key", [PassByReference]> { let summary = "The short int server key required to perform homomorphic operations."; diff --git a/lib/Target/TfheRust/TfheRustEmitter.cpp b/lib/Target/TfheRust/TfheRustEmitter.cpp index 6fd1495f0..bb346bb1e 100644 --- a/lib/Target/TfheRust/TfheRustEmitter.cpp +++ b/lib/Target/TfheRust/TfheRustEmitter.cpp @@ -492,11 +492,7 @@ std::string TfheRustEmitter::operationType(Operation *op) { "\")"; }) .Case([&](ScalarLeftShiftOp op) { - auto constantShift = - cast(op.getShiftAmount().getDefiningOp()); - return "LSH(" + - std::to_string( - cast(constantShift.getValue()).getInt()) + + return "LSH(" + std::to_string(op.getShiftAmount().getSExtValue()) + ")"; }) .Case([&](Operation *) { return "ADD"; }); @@ -518,9 +514,17 @@ LogicalResult TfheRustEmitter::printOperation(affine::AffineForOp forOp) { } LogicalResult TfheRustEmitter::printOperation(ScalarLeftShiftOp op) { - return printSksMethod(op.getResult(), op.getServerKey(), - {op.getCiphertext(), op.getShiftAmount()}, - "scalar_left_shift", {"", "u8"}); + emitAssignPrefix(op.getResult()); + os << variableNames->getNameForValue(op.getServerKey()) + << ".scalar_left_shift("; + + auto valueStr = variableNames->getNameForValue(op.getCiphertext()); + std::string prefix = + op.getCiphertext().getType().hasTrait() ? "&" : ""; + auto cipherString = prefix + valueStr; + + os << cipherString << ", " << op.getShiftAmount() << " as u8);\n"; + return success(); } LogicalResult TfheRustEmitter::printOperation(CreateTrivialOp op) { diff --git a/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp b/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp index 9c9ea7353..de8e3d67b 100644 --- a/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp +++ b/lib/Target/TfheRustHL/TfheRustHLEmitter.cpp @@ -561,8 +561,12 @@ LogicalResult TfheRustHLEmitter::printOperation(SubOp op) { } LogicalResult TfheRustHLEmitter::printOperation(ScalarRightShiftOp op) { - return printBinaryOp(op.getResult(), op.getCiphertext(), op.getShiftAmount(), - ">>"); + emitAssignPrefix(op.getResult()); + + os << checkOrigin(op.getCiphertext()) + << variableNames->getNameForValue(op.getCiphertext()) << " >> " + << op.getShiftAmount() << "u8;\n"; + return success(); } LogicalResult TfheRustHLEmitter::printOperation(CastOp op) { diff --git a/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir b/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir index dfc88b26a..01fb3fe10 100644 --- a/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir +++ b/tests/Dialect/Arith/Conversions/ArithToCGGI/arith-to-cggi.mlir @@ -66,6 +66,8 @@ func.func @test_affine(%arg0: memref<1x1xi32>) -> memref<1x1xi32> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1xi32> %25 = arith.muli %0, %c33_i8 : i32 %26 = arith.addi %c429_i32, %25 : i32 + %c2 = arith.constant 2 : i32 + %27 = arith.shrui %26, %c2 : i32 affine.store %26, %alloc[0, 0] : memref<1x1xi32> return %alloc : memref<1x1xi32> } diff --git a/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir b/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir index 363f50da9..f72e9243a 100644 --- a/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir +++ b/tests/Dialect/Arith/Conversions/ArithToCGGIQuart/quarter_wide.mlir @@ -1,10 +1,10 @@ // RUN: heir-opt --arith-to-cggi-quart %s | FileCheck %s // CHECK: return %[[RET:.*]] tensor<4x!lwe.lwe_ciphertext> -func.func @test_simple_split2(%arg0: i32, %arg1: i16) -> i32 { - %2 = arith.constant 31 : i16 - %5 = arith.addi %arg1, %2 : i16 - %6 = arith.extui %5 : i16 to i32 - %7 = arith.addi %arg0, %6 : i32 - return %6 : i32 +func.func @test_simple_split2(%arg0: i32, %arg1: i32) -> i32 { + %2 = arith.constant 31 : i8 + %1 = arith.extui %2 : i8 to i32 + %5 = arith.addi %arg1, %1 : i32 + %7 = arith.muli %arg0, %5 : i32 + return %7 : i32 } diff --git a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir index 226ce55e5..13c8cbe78 100644 --- a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir +++ b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/arith.mlir @@ -14,6 +14,7 @@ func.func @test_affine(%arg0: memref<1x1x!ct_ty>) -> memref<1x1x!ct_ty> { %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x1x!ct_ty> %3 = cggi.mul %2, %1 : !ct_ty %4 = cggi.add %3, %0 : !ct_ty - affine.store %4, %alloc[0, 0] : memref<1x1x!ct_ty> + %5 = cggi.sshr %4 {shiftAmount = 2 : index} : (!ct_ty) -> !ct_ty + affine.store %5, %alloc[0, 0] : memref<1x1x!ct_ty> return %alloc : memref<1x1x!ct_ty> } diff --git a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/binary_gates.mlir b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/binary_gates.mlir index 89c5e274f..b16c93112 100644 --- a/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/binary_gates.mlir +++ b/tests/Dialect/CGGI/Conversions/cggi_to_tfhe_rust/binary_gates.mlir @@ -7,9 +7,7 @@ // CHECK-SAME: %[[sks:.*]]: [[sks_ty:!tfhe_rust.server_key]], %[[arg1:.*]]: [[ct_ty:!tfhe_rust.eui3]], %[[arg2:.*]]: [[ct_ty]] func.func @binary_gates(%arg1: !ct_ty, %arg2: !ct_ty) -> (!ct_ty) { // CHECK: %[[v0:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 8 : ui4} - // CHECK: %[[shiftAmount:.*]] = arith.constant 1 : i8 - - // CHECK: %[[v1:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[arg2]], %[[shiftAmount]] + // CHECK: %[[v1:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[arg2]] {shiftAmount = 1 : index} // CHECK: %[[v2:.*]] = tfhe_rust.add %[[sks]], %[[v1]], %[[arg1]] // CHECK: %[[v3:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v2]], %[[v0]] %0 = cggi.and %arg1, %arg2 : !ct_ty @@ -25,7 +23,7 @@ func.func @binary_gates(%arg1: !ct_ty, %arg2: !ct_ty) -> (!ct_ty) { %2 = cggi.not %1 : !ct_ty // CHECK: %[[v8:.*]] = tfhe_rust.generate_lookup_table %[[sks]] {truthTable = 6 : ui4} - // CHECK: %[[v9:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[v3]], %[[shiftAmount]] + // CHECK: %[[v9:.*]] = tfhe_rust.scalar_left_shift %[[sks]], %[[v3]] {shiftAmount = 1 : index} // CHECK: %[[v10:.*]] = tfhe_rust.add %[[sks]], %[[v9]], %[[v7]] // CHECK: %[[v11:.*]] = tfhe_rust.apply_lookup_table %[[sks]], %[[v10]], %[[v8]] %3 = cggi.xor %2, %0 : !ct_ty diff --git a/tests/Dialect/TfheRust/Emitters/emit_levelled_ops.mlir b/tests/Dialect/TfheRust/Emitters/emit_levelled_ops.mlir index 62093f860..ca3895a3b 100644 --- a/tests/Dialect/TfheRust/Emitters/emit_levelled_ops.mlir +++ b/tests/Dialect/TfheRust/Emitters/emit_levelled_ops.mlir @@ -14,11 +14,10 @@ // CHECK: temp_nodes[ // CHECK-NEXT: } func.func @test_levelled_op(%sks : !sks, %lut: !lut, %input1 : !eui3, %input2 : !eui3) -> !eui3 { - %c1 = arith.constant 1 : i8 %v0 = tfhe_rust.apply_lookup_table %sks, %input1, %lut : (!sks, !eui3, !lut) -> !eui3 %v1 = tfhe_rust.apply_lookup_table %sks, %input2, %lut : (!sks, !eui3, !lut) -> !eui3 %v2 = tfhe_rust.add %sks, %v0, %v1 : (!sks, !eui3, !eui3) -> !eui3 - %v3 = tfhe_rust.scalar_left_shift %sks, %v2, %c1 : (!sks, !eui3, i8) -> !eui3 + %v3 = tfhe_rust.scalar_left_shift %sks, %v2 {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3 %v4 = tfhe_rust.apply_lookup_table %sks, %v3, %lut : (!sks, !eui3, !lut) -> !eui3 return %v4 : !eui3 } @@ -44,7 +43,7 @@ func.func @test_levelled_op_break(%sks : !sks, %lut: !lut, %input1 : !eui3, %inp %v1 = tfhe_rust.apply_lookup_table %sks, %input2, %lut : (!sks, !eui3, !lut) -> !eui3 %v2 = tfhe_rust.add %sks, %v0, %v1 : (!sks, !eui3, !eui3) -> !eui3 %c1 = arith.constant 1 : i8 - %v3 = tfhe_rust.scalar_left_shift %sks, %v2, %c1 : (!sks, !eui3, i8) -> !eui3 + %v3 = tfhe_rust.scalar_left_shift %sks, %v2 {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3 %v4 = tfhe_rust.apply_lookup_table %sks, %v3, %lut : (!sks, !eui3, !lut) -> !eui3 return %v4 : !eui3 } diff --git a/tests/Dialect/TfheRust/Emitters/emit_tfhe_rust.mlir b/tests/Dialect/TfheRust/Emitters/emit_tfhe_rust.mlir index 52b786c03..a9793af95 100644 --- a/tests/Dialect/TfheRust/Emitters/emit_tfhe_rust.mlir +++ b/tests/Dialect/TfheRust/Emitters/emit_tfhe_rust.mlir @@ -38,16 +38,14 @@ func.func @test_apply_lookup_table(%sks : !sks, %lut: !lut, %input : !eui3) -> ! // CHECK-NEXT: ) -> Ciphertext { // CHECK: let [[v1:.*]] = [[sks]].apply_lookup_table(&[[input]], &[[lut]]); // CHECK: let [[v2:.*]] = [[sks]].unchecked_add(&[[input]], &[[v1]]); -// CHECK: let [[c1:.*]] = 1; -// CHECK: let [[v3:.*]] = [[sks]].scalar_left_shift(&[[v2]], [[c1]] as u8); +// CHECK: let [[v3:.*]] = [[sks]].scalar_left_shift(&[[v2]], [[c1:.*]] as u8); // CHECK: let [[v4:.*]] = [[sks]].apply_lookup_table(&[[v3]], &[[lut]]); // CHECK-NEXT: [[v4]] // CHECK-NEXT: } func.func @test_apply_lookup_table2(%sks : !sks, %lut: !lut, %input : !eui3) -> !eui3 { %v1 = tfhe_rust.apply_lookup_table %sks, %input, %lut : (!sks, !eui3, !lut) -> !eui3 %v2 = tfhe_rust.add %sks, %input, %v1 : (!sks, !eui3, !eui3) -> !eui3 - %c1 = arith.constant 1 : i8 - %v3 = tfhe_rust.scalar_left_shift %sks, %v2, %c1 : (!sks, !eui3, i8) -> !eui3 + %v3 = tfhe_rust.scalar_left_shift %sks, %v2 {shiftAmount = 1 : index} : (!sks, !eui3) -> !eui3 %v4 = tfhe_rust.apply_lookup_table %sks, %v3, %lut : (!sks, !eui3, !lut) -> !eui3 return %v4 : !eui3 } diff --git a/tests/Dialect/TfheRust/IR/ops.mlir b/tests/Dialect/TfheRust/IR/ops.mlir index a7284621b..9283e6ca2 100644 --- a/tests/Dialect/TfheRust/IR/ops.mlir +++ b/tests/Dialect/TfheRust/IR/ops.mlir @@ -38,8 +38,7 @@ module { %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 - %shiftAmount = arith.constant 1 : i8 - %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2 {shiftAmount = 1 : index} : (!sks, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 diff --git a/tests/Dialect/TfheRust/Transforms/canonicalize.mlir b/tests/Dialect/TfheRust/Transforms/canonicalize.mlir index 27941ea52..d9434b8de 100644 --- a/tests/Dialect/TfheRust/Transforms/canonicalize.mlir +++ b/tests/Dialect/TfheRust/Transforms/canonicalize.mlir @@ -6,14 +6,11 @@ module { // CHECK-LABEL: func @test_move_create_trivial func.func @test_move_create_trivial(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 { // CHECK: arith.constant - // CHECK-NEXT: arith.constant // CHECK-NEXT: tfhe_rust.create_trivial // CHECK-NEXT: tfhe_rust.create_trivial %0 = arith.constant 1 : i3 - %1 = arith.constant 2 : i3 %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 - %shiftAmount = arith.constant 1 : i8 - %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2 {shiftAmount = 1 : index} : (!sks, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -23,19 +20,16 @@ module { // CHECK-LABEL: func @test_move_out_of_loop func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> memref<10x!tfhe_rust.eui3> { // CHECK: arith.constant - // CHECK-NEXT: arith.constant // CHECK-NEXT: tfhe_rust.create_trivial // CHECK-NEXT: tfhe_rust.create_trivial // CHECK-NEXT: memref.alloc // CHECK-NEXT: affine.for %0 = arith.constant 1 : i3 - %1 = arith.constant 2 : i3 %memref = memref.alloca() : memref<10x!tfhe_rust.eui3> affine.for %i = 0 to 10 { %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 - %shiftAmount = arith.constant 1 : i8 - %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2 {shiftAmount = 1 : index} : (!sks, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 diff --git a/tests/Transforms/forward_store_to_load/forward_add_one.mlir b/tests/Transforms/forward_store_to_load/forward_add_one.mlir index 0b0912f5b..3bfaca09f 100644 --- a/tests/Transforms/forward_store_to_load/forward_add_one.mlir +++ b/tests/Transforms/forward_store_to_load/forward_add_one.mlir @@ -32,8 +32,8 @@ module { %2 = tfhe_rust.create_trivial %arg0, %false : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %3 = tfhe_rust.create_trivial %arg0, %0 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %4 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 8 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table - %5 = tfhe_rust.scalar_left_shift %arg0, %2, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %6 = tfhe_rust.scalar_left_shift %arg0, %3, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %5 = tfhe_rust.scalar_left_shift %arg0, %2 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %6 = tfhe_rust.scalar_left_shift %arg0, %3 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %7 = tfhe_rust.add %arg0, %5, %6 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %8 = tfhe_rust.add %arg0, %7, %1 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %9 = tfhe_rust.apply_lookup_table %arg0, %8, %4 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -41,8 +41,8 @@ module { %11 = memref.load %arg1[%c1] : memref<8x!tfhe_rust.eui3> %12 = tfhe_rust.create_trivial %arg0, %10 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %13 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 150 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table - %14 = tfhe_rust.scalar_left_shift %arg0, %12, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %15 = tfhe_rust.scalar_left_shift %arg0, %11, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %14 = tfhe_rust.scalar_left_shift %arg0, %12 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %15 = tfhe_rust.scalar_left_shift %arg0, %11 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %16 = tfhe_rust.add %arg0, %14, %15 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %17 = tfhe_rust.add %arg0, %16, %9 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %18 = tfhe_rust.apply_lookup_table %arg0, %17, %13 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -52,32 +52,32 @@ module { %22 = memref.load %arg1[%c2] : memref<8x!tfhe_rust.eui3> %23 = tfhe_rust.create_trivial %arg0, %21 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %24 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 43 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table - %25 = tfhe_rust.scalar_left_shift %arg0, %23, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %26 = tfhe_rust.scalar_left_shift %arg0, %22, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %25 = tfhe_rust.scalar_left_shift %arg0, %23 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %26 = tfhe_rust.scalar_left_shift %arg0, %22 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %27 = tfhe_rust.add %arg0, %25, %26 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %28 = tfhe_rust.add %arg0, %27, %20 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %29 = tfhe_rust.apply_lookup_table %arg0, %28, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 %30 = memref.load %alloc[%c3] : memref<8xi1> %31 = memref.load %arg1[%c3] : memref<8x!tfhe_rust.eui3> %32 = tfhe_rust.create_trivial %arg0, %30 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 - %33 = tfhe_rust.scalar_left_shift %arg0, %32, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %34 = tfhe_rust.scalar_left_shift %arg0, %31, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %33 = tfhe_rust.scalar_left_shift %arg0, %32 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %34 = tfhe_rust.scalar_left_shift %arg0, %31 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %35 = tfhe_rust.add %arg0, %33, %34 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %36 = tfhe_rust.add %arg0, %35, %29 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %37 = tfhe_rust.apply_lookup_table %arg0, %36, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 %38 = memref.load %alloc[%c4] : memref<8xi1> %39 = memref.load %arg1[%c4] : memref<8x!tfhe_rust.eui3> %40 = tfhe_rust.create_trivial %arg0, %38 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 - %41 = tfhe_rust.scalar_left_shift %arg0, %40, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %42 = tfhe_rust.scalar_left_shift %arg0, %39, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %41 = tfhe_rust.scalar_left_shift %arg0, %40 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %42 = tfhe_rust.scalar_left_shift %arg0, %39 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %43 = tfhe_rust.add %arg0, %41, %42 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %44 = tfhe_rust.add %arg0, %43, %37 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %45 = tfhe_rust.apply_lookup_table %arg0, %44, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 %46 = memref.load %alloc[%c5] : memref<8xi1> %47 = memref.load %arg1[%c5] : memref<8x!tfhe_rust.eui3> %48 = tfhe_rust.create_trivial %arg0, %46 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 - %49 = tfhe_rust.scalar_left_shift %arg0, %48, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %50 = tfhe_rust.scalar_left_shift %arg0, %47, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %49 = tfhe_rust.scalar_left_shift %arg0, %48 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %50 = tfhe_rust.scalar_left_shift %arg0, %47 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %51 = tfhe_rust.add %arg0, %49, %50 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %52 = tfhe_rust.add %arg0, %51, %45 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %53 = tfhe_rust.apply_lookup_table %arg0, %52, %24 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -85,8 +85,8 @@ module { %55 = memref.load %arg1[%c6] : memref<8x!tfhe_rust.eui3> %56 = tfhe_rust.create_trivial %arg0, %54 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 %57 = tfhe_rust.generate_lookup_table %arg0 {truthTable = 105 : ui8} : (!tfhe_rust.server_key) -> !tfhe_rust.lookup_table - %58 = tfhe_rust.scalar_left_shift %arg0, %56, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %59 = tfhe_rust.scalar_left_shift %arg0, %55, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %58 = tfhe_rust.scalar_left_shift %arg0, %56 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %59 = tfhe_rust.scalar_left_shift %arg0, %55 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %60 = tfhe_rust.add %arg0, %58, %59 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %61 = tfhe_rust.add %arg0, %60, %53 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %62 = tfhe_rust.apply_lookup_table %arg0, %61, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 @@ -94,8 +94,8 @@ module { %64 = memref.load %alloc[%c7] : memref<8xi1> %65 = memref.load %arg1[%c7] : memref<8x!tfhe_rust.eui3> %66 = tfhe_rust.create_trivial %arg0, %64 : (!tfhe_rust.server_key, i1) -> !tfhe_rust.eui3 - %67 = tfhe_rust.scalar_left_shift %arg0, %66, %c2_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 - %68 = tfhe_rust.scalar_left_shift %arg0, %65, %c1_i8 : (!tfhe_rust.server_key, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %67 = tfhe_rust.scalar_left_shift %arg0, %66 {shiftAmount = 2 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 + %68 = tfhe_rust.scalar_left_shift %arg0, %65 {shiftAmount = 1 : index} : (!tfhe_rust.server_key, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %69 = tfhe_rust.add %arg0, %67, %68 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %70 = tfhe_rust.add %arg0, %69, %63 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %71 = tfhe_rust.apply_lookup_table %arg0, %70, %57 : (!tfhe_rust.server_key, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 diff --git a/tests/Transforms/loop_unroll/full_loop_unroll.mlir b/tests/Transforms/loop_unroll/full_loop_unroll.mlir index 4170b2322..ad4eaa5f3 100644 --- a/tests/Transforms/loop_unroll/full_loop_unroll.mlir +++ b/tests/Transforms/loop_unroll/full_loop_unroll.mlir @@ -11,8 +11,7 @@ func.func @test_move_out_of_loop(%sks : !sks, %lut: !tfhe_rust.lookup_table) -> affine.for %i = 0 to 10 { %e2 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 - %shiftAmount = arith.constant 1 : i8 - %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2, %shiftAmount : (!sks, !tfhe_rust.eui3, i8) -> !tfhe_rust.eui3 + %e2Shifted = tfhe_rust.scalar_left_shift %sks, %e2 {shiftAmount = 1 : index} : (!sks, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %e1 = tfhe_rust.create_trivial %sks, %0 : (!sks, i3) -> !tfhe_rust.eui3 %eCombined = tfhe_rust.add %sks, %e1, %e2Shifted : (!sks, !tfhe_rust.eui3, !tfhe_rust.eui3) -> !tfhe_rust.eui3 %out = tfhe_rust.apply_lookup_table %sks, %eCombined, %lut : (!sks, !tfhe_rust.eui3, !tfhe_rust.lookup_table) -> !tfhe_rust.eui3 From 8dbee17a7b06921e9fbf4a5cbdcb6713dde6dbb9 Mon Sep 17 00:00:00 2001 From: WoutLegiest Date: Wed, 25 Dec 2024 01:30:54 +0000 Subject: [PATCH 2/2] Working conversion heir-tosa to HL emitter --- .../Conversions/ArithToCGGI/ArithToCGGI.cpp | 24 ++++++- .../CGGIToTfheRust/CGGIToTfheRust.cpp | 52 +++------------ lib/Target/TfheRust/Utils.cpp | 16 ++--- lib/Target/TfheRustHL/TfheRustHLEmitter.cpp | 66 +++++++++++++++++-- lib/Target/TfheRustHL/TfheRustHLEmitter.h | 3 + lib/Target/TfheRustHL/TfheRustHLTemplates.h | 3 +- lib/Utils/BUILD | 1 + lib/Utils/ConversionUtils.h | 46 +++++++++++++ .../hello_world_clean.mlir | 8 +-- 9 files changed, 152 insertions(+), 67 deletions(-) diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp index d1509799c..87c5d6323 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp @@ -118,6 +118,26 @@ struct ConvertExtUIOp : public OpConversionPattern { } }; +struct ConvertExtSIOp : public OpConversionPattern { + ConvertExtSIOp(mlir::MLIRContext *context) + : OpConversionPattern(context) {} + + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto outType = convertArithToCGGIType( + cast(op.getResult().getType()), op->getContext()); + auto castOp = b.create(op.getLoc(), outType, adaptor.getIn()); + + rewriter.replaceOp(op, castOp); + return success(); + } +}; + struct ConvertShRUIOp : public OpConversionPattern { ConvertShRUIOp(mlir::MLIRContext *context) : OpConversionPattern(context) {} @@ -192,8 +212,8 @@ struct ArithToCGGI : public impl::ArithToCGGIBase { }); patterns.add< - ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertShRUIOp, - ConvertBinOp, + ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, + ConvertShRUIOp, ConvertBinOp, ConvertBinOp, ConvertBinOp, ConvertAny, ConvertAny, diff --git a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp index 9705e35bc..7e812c805 100644 --- a/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp +++ b/lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp @@ -42,48 +42,6 @@ constexpr int kAndLut = 8; constexpr int kOrLut = 14; constexpr int kXorLut = 6; -static Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) { - // Only supporting unsigned types because the LWE dialect does not have a - // notion of signedness. - switch (width) { - case 1: - // The minimum bit width of the integer tfhe_rust API is UInt2 - // https://docs.rs/tfhe/latest/tfhe/index.html#types - // This may happen if there are no LUT or boolean gate operations that - // require a minimum bit width (e.g. shuffling bits in a program that - // multiplies by two). - LLVM_DEBUG(llvm::dbgs() - << "Upgrading ciphertext with bit width 1 to UInt2"); - [[fallthrough]]; - case 2: - return tfhe_rust::EncryptedUInt2Type::get(ctx); - case 3: - return tfhe_rust::EncryptedUInt3Type::get(ctx); - case 4: - return tfhe_rust::EncryptedUInt4Type::get(ctx); - case 8: - return tfhe_rust::EncryptedUInt8Type::get(ctx); - case 10: - return tfhe_rust::EncryptedUInt10Type::get(ctx); - case 12: - return tfhe_rust::EncryptedUInt12Type::get(ctx); - case 14: - return tfhe_rust::EncryptedUInt14Type::get(ctx); - case 16: - return tfhe_rust::EncryptedUInt16Type::get(ctx); - case 32: - return tfhe_rust::EncryptedUInt32Type::get(ctx); - case 64: - return tfhe_rust::EncryptedUInt64Type::get(ctx); - case 128: - return tfhe_rust::EncryptedUInt128Type::get(ctx); - case 256: - return tfhe_rust::EncryptedUInt256Type::get(ctx); - default: - llvm_unreachable("Unsupported bitwidth"); - } -} - class CGGIToTfheRustTypeConverter : public TypeConverter { public: CGGIToTfheRustTypeConverter(MLIRContext *ctx) { @@ -532,6 +490,12 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase { hasServerKeyArg); }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + bool hasServerKeyArg = + isa(op.getOperand(0).getType()); + return hasServerKeyArg; + }); + target.addLegalOp(); target.addDynamicallyLegalOp< @@ -546,8 +510,8 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase { // FIXME: still need to update callers to insert the new server key arg, if // needed and possible. patterns.add< - AddServerKeyArg, ConvertEncodeOp, ConvertLut2Op, ConvertLut3Op, - ConvertNotOp, ConvertTrivialEncryptOp, ConvertTrivialOp, + AddServerKeyArg, AddServerKeyArgCall, ConvertEncodeOp, ConvertLut2Op, + ConvertLut3Op, ConvertNotOp, ConvertTrivialEncryptOp, ConvertTrivialOp, ConvertCGGITRBinOp, ConvertCGGITRBinOp, ConvertCGGITRBinOp, ConvertAndOp, diff --git a/lib/Target/TfheRust/Utils.cpp b/lib/Target/TfheRust/Utils.cpp index 897d8c0c0..d1d4baee0 100644 --- a/lib/Target/TfheRust/Utils.cpp +++ b/lib/Target/TfheRust/Utils.cpp @@ -23,16 +23,16 @@ LogicalResult canEmitFuncForTfheRust(func::FuncOp &funcOp) { return TypeSwitch(op) // This list should match the list of implemented overloads of // `printOperation`. - .Case([&](auto op) { return printOperation(op); }) // Func ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) // Affine ops .Case( [&](auto op) { return printOperation(op); }) // MemRef ops - .Case( - [&](auto op) { return printOperation(op); }) + .Case([&](auto op) { return printOperation(op); }) // TfheRust ops .Case([&](auto op) { return printOperation(op); }) // Tensor ops - .Case( + .Case( [&](auto op) { return printOperation(op); }) .Default([&](Operation &) { return op.emitOpError("unable to find printer for op"); @@ -252,6 +252,25 @@ LogicalResult TfheRustHLEmitter::printOperation(func::ReturnOp op) { return success(); } +LogicalResult TfheRustHLEmitter::printOperation(func::CallOp op) { + os << "let " << variableNames->getNameForValue(op->getResult(0)) << " = "; + + os << op.getCallee() << "("; + for (Value arg : op->getOperands()) { + if (!isa(arg.getType())) { + auto argName = variableNames->getNameForValue(arg); + if (op.getOperands().back() == arg) { + os << "&" << argName; + } else { + os << "&" << argName << ", "; + } + } + } + + os << "); \n"; + return success(); +} + void TfheRustHLEmitter::emitAssignPrefix(Value result) { os << "let " << variableNames->getNameForValue(result) << " = "; } @@ -286,7 +305,9 @@ LogicalResult TfheRustHLEmitter::printMethod( LogicalResult TfheRustHLEmitter::printOperation(CreateTrivialOp op) { emitAssignPrefix(op.getResult()); - os << "FheUint" << DefaultTfheRustHLBitWidth << "::try_encrypt_trivial(" + + os << "FheUint" << getTfheRustBitWidth(op.getResult().getType()) + << "::try_encrypt_trivial(" << variableNames->getNameForValue(op.getValue()) << ").unwrap();\n"; return success(); } @@ -334,9 +355,10 @@ LogicalResult TfheRustHLEmitter::printOperation(arith::ConstantOp op) { return success(); } + // FIXME: By default, it emits an unsigned integer. emitAssignPrefix(op.getResult()); if (auto intAttr = dyn_cast(valueAttr)) { - os << intAttr.getValue() << "u64;\n"; + os << intAttr.getValue().abs() << "u64;\n"; } else { return op.emitError() << "Unknown constant type " << valueAttr.getType(); } @@ -412,6 +434,12 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::AllocOp op) { return success(); } +// Use a BTreeMap<(usize, ...), Ciphertext>. +LogicalResult TfheRustHLEmitter::printOperation(memref::DeallocOp op) { + os << variableNames->getNameForValue(op.getMemref()) << ".clear();\n"; + return success(); +} + // Store into a BTreeMap<(usize, ...), Ciphertext> LogicalResult TfheRustHLEmitter::printOperation(memref::StoreOp op) { // We assume here that the indices are SSA values (not integer attributes). @@ -544,6 +572,25 @@ LogicalResult TfheRustHLEmitter::printOperation(tensor::FromElementsOp op) { return success(); } +// Need to produce a +LogicalResult TfheRustHLEmitter::printOperation(tensor::InsertOp op) { + // emitAssignPrefix(op.getResult()); + // os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) { + // // Check if block argument, if so, clone. + // const auto *cloneStr = isa(value) ? ".clone()" : ""; + // // Get the name of defining operation its dialect + // auto tfheOp = + // value.getDefiningOp()->getDialect()->getNamespace() == + // "tfhe_rust_bool"; + // const auto *prefix = tfheOp ? "&" : ""; + // return std::string(prefix) + variableNames->getNameForValue(value) + + // cloneStr; + // }) << "];\n"; + os << "Not implemented yet\n"; + + return success(); +} + LogicalResult TfheRustHLEmitter::printOperation(BitAndOp op) { return printBinaryOp(op.getResult(), op.getLhs(), op.getRhs(), "&&"); } @@ -585,7 +632,12 @@ FailureOr TfheRustHLEmitter::convertType(Type type) { // against a specific API version. if (type.hasTrait()) { - return std::string("Ciphertext"); + auto ctxtWidth = getTfheRustBitWidth(type); + if (ctxtWidth == DefaultTfheRustHLBitWidth) { + return std::string("Ciphertext"); + } + return "tfhe::FheUint"; + ; } return llvm::TypeSwitch>(type) diff --git a/lib/Target/TfheRustHL/TfheRustHLEmitter.h b/lib/Target/TfheRustHL/TfheRustHLEmitter.h index 19c226acd..95a491a29 100644 --- a/lib/Target/TfheRustHL/TfheRustHLEmitter.h +++ b/lib/Target/TfheRustHL/TfheRustHLEmitter.h @@ -51,6 +51,7 @@ class TfheRustHLEmitter { LogicalResult printOperation(::mlir::ModuleOp op); LogicalResult printOperation(::mlir::func::FuncOp op); LogicalResult printOperation(::mlir::func::ReturnOp op); + LogicalResult printOperation(::mlir::func::CallOp op); LogicalResult printOperation(affine::AffineForOp op); LogicalResult printOperation(affine::AffineYieldOp op); LogicalResult printOperation(affine::AffineStoreOp op); @@ -63,7 +64,9 @@ class TfheRustHLEmitter { LogicalResult printOperation(arith::TruncIOp op); LogicalResult printOperation(tensor::ExtractOp op); LogicalResult printOperation(tensor::FromElementsOp op); + LogicalResult printOperation(tensor::InsertOp op); LogicalResult printOperation(memref::AllocOp op); + LogicalResult printOperation(memref::DeallocOp op); LogicalResult printOperation(memref::LoadOp op); LogicalResult printOperation(memref::StoreOp op); LogicalResult printOperation(AddOp op); diff --git a/lib/Target/TfheRustHL/TfheRustHLTemplates.h b/lib/Target/TfheRustHL/TfheRustHLTemplates.h index 23a0a3468..8fafa606f 100644 --- a/lib/Target/TfheRustHL/TfheRustHLTemplates.h +++ b/lib/Target/TfheRustHL/TfheRustHLTemplates.h @@ -9,9 +9,8 @@ namespace tfhe_rust { constexpr std::string_view kModulePrelude = R"rust( use std::collections::BTreeMap; -use tfhe::{FheUint8, FheUint16, FheUint32, FheUint64}; +use tfhe::{FheUint4, FheUint8, FheUint16, FheUint32, FheUint64}; use tfhe::prelude::*; -use tfhe::ServerKey; )rust"; } // namespace tfhe_rust diff --git a/lib/Utils/BUILD b/lib/Utils/BUILD index d74d96bb8..3ee805250 100644 --- a/lib/Utils/BUILD +++ b/lib/Utils/BUILD @@ -23,6 +23,7 @@ cc_library( "@heir//lib/Dialect/Mgmt/IR:Dialect", "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Dialect/TensorExt/IR:Dialect", + "@heir//lib/Dialect/TfheRust/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/lib/Utils/ConversionUtils.h b/lib/Utils/ConversionUtils.h index 3fdf9f7af..b5ea9dd00 100644 --- a/lib/Utils/ConversionUtils.h +++ b/lib/Utils/ConversionUtils.h @@ -13,8 +13,10 @@ #include "lib/Dialect/Mgmt/IR/MgmtOps.h" #include "lib/Dialect/Secret/IR/SecretOps.h" #include "lib/Dialect/TensorExt/IR/TensorExtOps.h" +#include "lib/Dialect/TfheRust/IR/TfheRustTypes.h" #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project @@ -33,6 +35,8 @@ #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project +#define DEBUG_TYPE "cggi-to-tfhe-rust" + namespace mlir { namespace heir { @@ -508,6 +512,48 @@ bool containsLweOrDialect(func::FuncOp func) { return walkResult.wasInterrupted(); } +inline Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) { + // Only supporting unsigned types because the LWE dialect does not have a + // notion of signedness. + switch (width) { + case 1: + // The minimum bit width of the integer tfhe_rust API is UInt2 + // https://docs.rs/tfhe/latest/tfhe/index.html#types + // This may happen if there are no LUT or boolean gate operations that + // require a minimum bit width (e.g. shuffling bits in a program that + // multiplies by two). + LLVM_DEBUG(llvm::dbgs() + << "Upgrading ciphertext with bit width 1 to UInt2"); + [[fallthrough]]; + case 2: + return tfhe_rust::EncryptedUInt2Type::get(ctx); + case 3: + return tfhe_rust::EncryptedUInt3Type::get(ctx); + case 4: + return tfhe_rust::EncryptedUInt4Type::get(ctx); + case 8: + return tfhe_rust::EncryptedUInt8Type::get(ctx); + case 10: + return tfhe_rust::EncryptedUInt10Type::get(ctx); + case 12: + return tfhe_rust::EncryptedUInt12Type::get(ctx); + case 14: + return tfhe_rust::EncryptedUInt14Type::get(ctx); + case 16: + return tfhe_rust::EncryptedUInt16Type::get(ctx); + case 32: + return tfhe_rust::EncryptedUInt32Type::get(ctx); + case 64: + return tfhe_rust::EncryptedUInt64Type::get(ctx); + case 128: + return tfhe_rust::EncryptedUInt128Type::get(ctx); + case 256: + return tfhe_rust::EncryptedUInt256Type::get(ctx); + default: + llvm_unreachable("Unsupported bitwidth"); + } +} + } // namespace heir } // namespace mlir diff --git a/tests/Transforms/tosa_to_boolean_tfhe/hello_world_clean.mlir b/tests/Transforms/tosa_to_boolean_tfhe/hello_world_clean.mlir index 9cef1e534..765b29782 100644 --- a/tests/Transforms/tosa_to_boolean_tfhe/hello_world_clean.mlir +++ b/tests/Transforms/tosa_to_boolean_tfhe/hello_world_clean.mlir @@ -5,11 +5,11 @@ module attributes {tf_saved_model.semantics} { func.func @main(%arg0: tensor<1x1xi8> {iree.identifier = "serving_default_dense_input:0", tf_saved_model.index_path = ["dense_input"]}) -> (tensor<1x1xi32> {iree.identifier = "StatefulPartitionedCall:0", tf_saved_model.index_path = ["dense_2"]}) attributes {tf_saved_model.exported_names = ["serving_default"]} { %0 = "tosa.const"() {value = dense<429> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tosa.const"() {value = dense<[[-39, 59, 39, 21, 28, -32, -34, -35, 15, 27, -59, -41, 18, -35, -7, 127]]> : tensor<1x16xi8>} : () -> tensor<1x16xi8> - %2 = "tosa.const"() {value = dense<[-729, 1954, 610, 0, 241, -471, -35, -867, 571, 581, 4260, 3943, 591, 0, -889, -5103]> : tensor<16xi32>} : () -> tensor<16xi32> + %1 = "tosa.const"() {value = dense<[[39, 59, 39, 21, 28, 32, 34, 35, 15, 27, 59, 41, 18, 35, 7, 127]]> : tensor<1x16xi8>} : () -> tensor<1x16xi8> + %2 = "tosa.const"() {value = dense<[729, 1954, 610, 0, 241, 471, 35, 867, 571, 581, 4260, 3943, 591, 0, 889, 5103]> : tensor<16xi32>} : () -> tensor<16xi32> %3 = "tosa.const"() {value = dense<"0xF41AED091921F424E021EFBCF7F5FA1903DCD20206F9F402FFFAEFF1EFD327E1FB27DDEBDBE4051A17FC241215EF1EE410FE14DA1CF8F3F1EFE2F309E3E9EDE3E415070B041B1AFEEB01DE21E60BEC03230A22241E2703E60324FFC011F8FCF1110CF5E0F30717E5E8EDFADCE823FB07DDFBFD0014261117E7F111EA0226040425211D0ADB1DDC2001FAE3370BF11A16EF1CE703E01602032118092ED9E5140BEA1AFCD81300C4D8ECD9FE0D1920D8D6E21FE9D7CAE2DDC613E7043E000114C7DBE71515F506D61ADC0922FE080213EF191EE209FDF314DDDA20D90FE3F9F7EEE924E629000716E21E0D23D3DDF714FA0822262109080F0BE012F47FDC58E526"> : tensor<16x16xi8>} : () -> tensor<16x16xi8> - %4 = "tosa.const"() {value = dense<[0, 0, -5438, -5515, -1352, -1500, -4152, -84, 3396, 0, 1981, -5581, 0, -6964, 3407, -7217]> : tensor<16xi32>} : () -> tensor<16xi32> - %5 = "tosa.const"() {value = dense<[[-9], [-54], [57], [71], [104], [115], [98], [99], [64], [-26], [127], [25], [-82], [68], [95], [86]]> : tensor<16x1xi8>} : () -> tensor<16x1xi8> + %4 = "tosa.const"() {value = dense<[0, 0, 5438, 5515, 1352, 1500, 4152, 84, 3396, 0, 1981, 5581, 0, 6964, 3407, 7217]> : tensor<16xi32>} : () -> tensor<16xi32> + %5 = "tosa.const"() {value = dense<[[9], [54], [57], [71], [104], [115], [98], [99], [64], [26], [127], [25], [82], [68], [95], [86]]> : tensor<16x1xi8>} : () -> tensor<16x1xi8> %6 = "tosa.fully_connected"(%arg0, %5, %4) {quantization_info = #tosa.conv_quant} : (tensor<1x1xi8>, tensor<16x1xi8>, tensor<16xi32>) -> tensor<1x16xi32> %9 = "tosa.fully_connected"(%6, %3, %2) {quantization_info = #tosa.conv_quant} : (tensor<1x16xi32>, tensor<16x16xi8>, tensor<16xi32>) -> tensor<1x16xi32> %12 = "tosa.fully_connected"(%9, %1, %0) {quantization_info = #tosa.conv_quant} : (tensor<1x16xi32>, tensor<1x16xi8>, tensor<1xi32>) -> tensor<1x1xi32>