Skip to content

Commit

Permalink
Working conversion from heir-tosa to the HL (tfhe-rs high-level API) emitter
Browse files Browse the repository at this point in the history
  • Loading branch information
WoutLegiest committed Jan 16, 2025
1 parent ce95967 commit fdcebe5
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 67 deletions.
24 changes: 22 additions & 2 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ struct ConvertExtUIOp : public OpConversionPattern<mlir::arith::ExtUIOp> {
}
};

// Lowers arith.extsi to a cggi.cast on the converted CGGI ciphertext type.
//
// NOTE(review): cggi::CastOp is also used for the arith.extui lowering in
// this file; confirm that CastOp sign-extends for signed inputs before
// relying on this pattern for negative values.
struct ConvertExtSIOp : public OpConversionPattern<mlir::arith::ExtSIOp> {
  // The inherited constructors already provide the context-taking ctor, so
  // no explicit constructor is needed.
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mlir::arith::ExtSIOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Map the result's IntegerType to the corresponding CGGI type.
    auto outType = convertArithToCGGIType(
        cast<IntegerType>(op.getResult().getType()), op->getContext());
    // Only one op is created, so replace directly instead of building an
    // ImplicitLocOpBuilder and re-passing op.getLoc().
    rewriter.replaceOpWithNewOp<cggi::CastOp>(op, outType, adaptor.getIn());
    return success();
  }
};

struct ConvertShRUIOp : public OpConversionPattern<mlir::arith::ShRUIOp> {
ConvertShRUIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ShRUIOp>(context) {}
Expand Down Expand Up @@ -192,8 +212,8 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

patterns.add<
ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertShRUIOp,
ConvertBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp,
ConvertShRUIOp, ConvertBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
Expand Down
52 changes: 8 additions & 44 deletions lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,48 +42,6 @@ constexpr int kAndLut = 8;
constexpr int kOrLut = 14;
constexpr int kXorLut = 6;

// Returns the tfhe-rs encrypted unsigned integer type corresponding to the
// given bit width, upgrading width 1 to UInt2 (the smallest width the
// tfhe-rs integer API supports). Aborts on unsupported widths.
//
// NOTE(review): the name carries a long-standing typo ("encrytped"); it is
// kept because callers elsewhere reference it.
static Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) {
  // Only supporting unsigned types because the LWE dialect does not have a
  // notion of signedness.
  switch (width) {
    case 1:
      // The minimum bit width of the integer tfhe_rust API is UInt2
      // https://docs.rs/tfhe/latest/tfhe/index.html#types
      // This may happen if there are no LUT or boolean gate operations that
      // require a minimum bit width (e.g. shuffling bits in a program that
      // multiplies by two).
      // Terminate the debug message with a newline so subsequent debug
      // output does not run onto the same line.
      LLVM_DEBUG(llvm::dbgs()
                 << "Upgrading ciphertext with bit width 1 to UInt2\n");
      [[fallthrough]];
    case 2:
      return tfhe_rust::EncryptedUInt2Type::get(ctx);
    case 3:
      return tfhe_rust::EncryptedUInt3Type::get(ctx);
    case 4:
      return tfhe_rust::EncryptedUInt4Type::get(ctx);
    case 8:
      return tfhe_rust::EncryptedUInt8Type::get(ctx);
    case 10:
      return tfhe_rust::EncryptedUInt10Type::get(ctx);
    case 12:
      return tfhe_rust::EncryptedUInt12Type::get(ctx);
    case 14:
      return tfhe_rust::EncryptedUInt14Type::get(ctx);
    case 16:
      return tfhe_rust::EncryptedUInt16Type::get(ctx);
    case 32:
      return tfhe_rust::EncryptedUInt32Type::get(ctx);
    case 64:
      return tfhe_rust::EncryptedUInt64Type::get(ctx);
    case 128:
      return tfhe_rust::EncryptedUInt128Type::get(ctx);
    case 256:
      return tfhe_rust::EncryptedUInt256Type::get(ctx);
    default:
      llvm_unreachable("Unsupported bitwidth");
  }
}

class CGGIToTfheRustTypeConverter : public TypeConverter {
public:
CGGIToTfheRustTypeConverter(MLIRContext *ctx) {
Expand Down Expand Up @@ -532,6 +490,12 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase<CGGIToTfheRust> {
hasServerKeyArg);
});

target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
bool hasServerKeyArg =
isa<tfhe_rust::ServerKeyType>(op.getOperand(0).getType());
return hasServerKeyArg;
});

target.addLegalOp<mlir::arith::ConstantOp>();

target.addDynamicallyLegalOp<
Expand All @@ -546,8 +510,8 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase<CGGIToTfheRust> {
// FIXME: still need to update callers to insert the new server key arg, if
// needed and possible.
patterns.add<
AddServerKeyArg, ConvertEncodeOp, ConvertLut2Op, ConvertLut3Op,
ConvertNotOp, ConvertTrivialEncryptOp, ConvertTrivialOp,
AddServerKeyArg, AddServerKeyArgCall, ConvertEncodeOp, ConvertLut2Op,
ConvertLut3Op, ConvertNotOp, ConvertTrivialEncryptOp, ConvertTrivialOp,
ConvertCGGITRBinOp<cggi::AddOp, tfhe_rust::AddOp>,
ConvertCGGITRBinOp<cggi::MulOp, tfhe_rust::MulOp>,
ConvertCGGITRBinOp<cggi::SubOp, tfhe_rust::SubOp>, ConvertAndOp,
Expand Down
16 changes: 8 additions & 8 deletions lib/Target/TfheRust/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ LogicalResult canEmitFuncForTfheRust(func::FuncOp &funcOp) {
return TypeSwitch<Operation *, WalkResult>(op)
// This list should match the list of implemented overloads of
// `printOperation`.
.Case<ModuleOp, func::FuncOp, func::ReturnOp, affine::AffineForOp,
affine::AffineYieldOp, affine::AffineLoadOp,
.Case<ModuleOp, func::FuncOp, func::ReturnOp, func::CallOp,
affine::AffineForOp, affine::AffineYieldOp, affine::AffineLoadOp,
affine::AffineStoreOp, arith::ConstantOp, arith::IndexCastOp,
arith::ShLIOp, arith::AndIOp, arith::ShRSIOp, arith::TruncIOp,
tensor::ExtractOp, tensor::FromElementsOp, memref::AllocOp,
memref::DeallocOp, memref::GetGlobalOp, memref::LoadOp,
memref::StoreOp, AddOp, SubOp, BitAndOp, CreateTrivialOp,
ApplyLookupTableOp, GenerateLookupTableOp, ScalarLeftShiftOp,
ScalarRightShiftOp, CastOp, MulOp,
::mlir::heir::tfhe_rust_bool::CreateTrivialOp,
tensor::ExtractOp, tensor::FromElementsOp, tensor::InsertOp,
memref::AllocOp, memref::DeallocOp, memref::DeallocOp,
memref::GetGlobalOp, memref::LoadOp, memref::StoreOp, AddOp,
SubOp, BitAndOp, CreateTrivialOp, ApplyLookupTableOp,
GenerateLookupTableOp, ScalarLeftShiftOp, ScalarRightShiftOp,
CastOp, MulOp, ::mlir::heir::tfhe_rust_bool::CreateTrivialOp,
::mlir::heir::tfhe_rust_bool::AndOp,
::mlir::heir::tfhe_rust_bool::PackedOp,
::mlir::heir::tfhe_rust_bool::NandOp,
Expand Down
66 changes: 59 additions & 7 deletions lib/Target/TfheRustHL/TfheRustHLEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ LogicalResult TfheRustHLEmitter::translate(Operation &op) {
// Builtin ops
.Case<ModuleOp>([&](auto op) { return printOperation(op); })
// Func ops
.Case<func::FuncOp, func::ReturnOp>(
.Case<func::FuncOp, func::CallOp, func::ReturnOp>(
[&](auto op) { return printOperation(op); })
// Affine ops
.Case<affine::AffineForOp, affine::AffineYieldOp,
Expand All @@ -107,13 +107,13 @@ LogicalResult TfheRustHLEmitter::translate(Operation &op) {
arith::ShLIOp, arith::TruncIOp, arith::AndIOp>(
[&](auto op) { return printOperation(op); })
// MemRef ops
.Case<memref::AllocOp, memref::LoadOp, memref::StoreOp>(
[&](auto op) { return printOperation(op); })
.Case<memref::AllocOp, memref::DeallocOp, memref::LoadOp,
memref::StoreOp>([&](auto op) { return printOperation(op); })
// TfheRust ops
.Case<AddOp, SubOp, MulOp, ScalarRightShiftOp, CastOp,
CreateTrivialOp>([&](auto op) { return printOperation(op); })
// Tensor ops
.Case<tensor::ExtractOp, tensor::FromElementsOp>(
.Case<tensor::ExtractOp, tensor::FromElementsOp, tensor::InsertOp>(
[&](auto op) { return printOperation(op); })
.Default([&](Operation &) {
return op.emitOpError("unable to find printer for op");
Expand Down Expand Up @@ -252,6 +252,25 @@ LogicalResult TfheRustHLEmitter::printOperation(func::ReturnOp op) {
return success();
}

// Emits a Rust call, e.g. `let v0 = callee(&arg0, &arg1);`.
//
// Server-key operands are dropped from the argument list: the tfhe-rs
// high-level API installs the server key globally rather than passing it to
// each function (see the ServerKeyType filtering below).
LogicalResult TfheRustHLEmitter::printOperation(func::CallOp op) {
  // Bind the result, if any, to a fresh variable. Guard against
  // zero-result calls, which would otherwise crash on getResult(0).
  // NOTE(review): multi-result calls would need a tuple destructure; only
  // the single-result case is handled here, matching the original.
  if (op->getNumResults() > 0) {
    os << "let " << variableNames->getNameForValue(op->getResult(0)) << " = ";
  }

  os << op.getCallee() << "(";
  // Track whether an argument has been printed instead of comparing each
  // operand against getOperands().back(): that value-equality test emitted a
  // dangling ", " when the last operand was the (skipped) server key, and
  // misfired when the same SSA value appeared more than once.
  bool first = true;
  for (Value arg : op->getOperands()) {
    if (isa<tfhe_rust::ServerKeyType>(arg.getType())) continue;
    if (!first) os << ", ";
    first = false;
    // Arguments are passed by reference; tfhe-rs ops borrow ciphertexts.
    os << "&" << variableNames->getNameForValue(arg);
  }

  os << ");\n";
  return success();
}

void TfheRustHLEmitter::emitAssignPrefix(Value result) {
os << "let " << variableNames->getNameForValue(result) << " = ";
}
Expand Down Expand Up @@ -286,7 +305,9 @@ LogicalResult TfheRustHLEmitter::printMethod(

// Emits a trivial encryption of a plaintext value, e.g.
// `let v0 = FheUint8::try_encrypt_trivial(v1).unwrap();`, choosing the
// FheUint width from the result ciphertext type.
LogicalResult TfheRustHLEmitter::printOperation(CreateTrivialOp op) {
  emitAssignPrefix(op.getResult());

  const auto width = getTfheRustBitWidth(op.getResult().getType());
  const auto valueName = variableNames->getNameForValue(op.getValue());
  os << "FheUint" << width << "::try_encrypt_trivial(" << valueName
     << ").unwrap();\n";
  return success();
}
Expand Down Expand Up @@ -334,9 +355,10 @@ LogicalResult TfheRustHLEmitter::printOperation(arith::ConstantOp op) {
return success();
}

// FIXME: By default, it emits an unsigned integer.
emitAssignPrefix(op.getResult());
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
os << intAttr.getValue() << "u64;\n";
os << intAttr.getValue().abs() << "u64;\n";
} else {
return op.emitError() << "Unknown constant type " << valueAttr.getType();
}
Expand Down Expand Up @@ -412,6 +434,12 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::AllocOp op) {
return success();
}

// Memrefs are emitted as BTreeMap<(usize, ...), Ciphertext>; deallocation
// clears the backing map, dropping the ciphertexts it owns.
LogicalResult TfheRustHLEmitter::printOperation(memref::DeallocOp op) {
  os << variableNames->getNameForValue(op.getMemref()) << ".clear();\n";
  return success();
}

// Store into a BTreeMap<(usize, ...), Ciphertext>
LogicalResult TfheRustHLEmitter::printOperation(memref::StoreOp op) {
// We assume here that the indices are SSA values (not integer attributes).
Expand Down Expand Up @@ -544,6 +572,25 @@ LogicalResult TfheRustHLEmitter::printOperation(tensor::FromElementsOp op) {
return success();
}

// tensor.insert lowering is not implemented for the HL emitter.
//
// The previous version streamed the literal text "Not implemented yet" into
// the generated Rust and returned success(), silently producing
// uncompilable output. Fail the translation with a diagnostic instead so
// the gap surfaces at emit time rather than at rustc time.
LogicalResult TfheRustHLEmitter::printOperation(tensor::InsertOp op) {
  return op.emitOpError()
         << "tensor.insert is not yet supported by the TfheRustHL emitter";
}

// Emits a bitwise AND of two ciphertexts.
//
// Rust's bitwise-AND operator is `&`; the previous `&&` is the logical AND,
// which is not defined for tfhe-rs FheUint ciphertexts and would make the
// generated code fail to compile.
LogicalResult TfheRustHLEmitter::printOperation(BitAndOp op) {
  return printBinaryOp(op.getResult(), op.getLhs(), op.getRhs(), "&");
}
Expand Down Expand Up @@ -585,7 +632,12 @@ FailureOr<std::string> TfheRustHLEmitter::convertType(Type type) {
// against a specific API version.

if (type.hasTrait<EncryptedInteger>()) {
return std::string("Ciphertext");
auto ctxtWidth = getTfheRustBitWidth(type);
if (ctxtWidth == DefaultTfheRustHLBitWidth) {
return std::string("Ciphertext");
}
return "tfhe::FheUint<tfhe::FheUint" + std::to_string(ctxtWidth) + "Id>";
;
}

return llvm::TypeSwitch<Type &, FailureOr<std::string>>(type)
Expand Down
3 changes: 3 additions & 0 deletions lib/Target/TfheRustHL/TfheRustHLEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TfheRustHLEmitter {
LogicalResult printOperation(::mlir::ModuleOp op);
LogicalResult printOperation(::mlir::func::FuncOp op);
LogicalResult printOperation(::mlir::func::ReturnOp op);
LogicalResult printOperation(::mlir::func::CallOp op);
LogicalResult printOperation(affine::AffineForOp op);
LogicalResult printOperation(affine::AffineYieldOp op);
LogicalResult printOperation(affine::AffineStoreOp op);
Expand All @@ -63,7 +64,9 @@ class TfheRustHLEmitter {
LogicalResult printOperation(arith::TruncIOp op);
LogicalResult printOperation(tensor::ExtractOp op);
LogicalResult printOperation(tensor::FromElementsOp op);
LogicalResult printOperation(tensor::InsertOp op);
LogicalResult printOperation(memref::AllocOp op);
LogicalResult printOperation(memref::DeallocOp op);
LogicalResult printOperation(memref::LoadOp op);
LogicalResult printOperation(memref::StoreOp op);
LogicalResult printOperation(AddOp op);
Expand Down
3 changes: 1 addition & 2 deletions lib/Target/TfheRustHL/TfheRustHLTemplates.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ namespace tfhe_rust {

// Rust prelude emitted at the top of every generated module: BTreeMap backs
// the memref lowering, the FheUint aliases cover the widths the emitter
// names directly, and tfhe::prelude pulls in the operator/encryption traits.
// Note: wider/narrower ciphertexts are referenced via fully qualified
// `tfhe::FheUint<...>` paths, so they need no `use` here.
constexpr std::string_view kModulePrelude = R"rust(
use std::collections::BTreeMap;
use tfhe::{FheUint4, FheUint8, FheUint16, FheUint32, FheUint64};
use tfhe::prelude::*;
)rust";

} // namespace tfhe_rust
Expand Down
1 change: 1 addition & 0 deletions lib/Utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cc_library(
"@heir//lib/Dialect/Mgmt/IR:Dialect",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
Expand Down
46 changes: 46 additions & 0 deletions lib/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
#include "lib/Dialect/Mgmt/IR/MgmtOps.h"
#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/TensorExt/IR/TensorExtOps.h"
#include "lib/Dialect/TfheRust/IR/TfheRustTypes.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand All @@ -33,6 +35,8 @@
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

#define DEBUG_TYPE "cggi-to-tfhe-rust"

namespace mlir {
namespace heir {

Expand Down Expand Up @@ -508,6 +512,48 @@ bool containsLweOrDialect(func::FuncOp func) {
return walkResult.wasInterrupted();
}

// Returns the tfhe-rs encrypted unsigned integer type corresponding to the
// given bit width, upgrading width 1 to UInt2 (the smallest width the
// tfhe-rs integer API supports). Aborts on unsupported widths.
//
// NOTE(review): the name carries a long-standing typo ("encrytped"); it is
// kept because callers elsewhere reference it.
inline Type encrytpedUIntTypeFromWidth(MLIRContext *ctx, int width) {
  // Only supporting unsigned types because the LWE dialect does not have a
  // notion of signedness.
  switch (width) {
    case 1:
      // The minimum bit width of the integer tfhe_rust API is UInt2
      // https://docs.rs/tfhe/latest/tfhe/index.html#types
      // This may happen if there are no LUT or boolean gate operations that
      // require a minimum bit width (e.g. shuffling bits in a program that
      // multiplies by two).
      // Terminate the debug message with a newline so subsequent debug
      // output does not run onto the same line.
      LLVM_DEBUG(llvm::dbgs()
                 << "Upgrading ciphertext with bit width 1 to UInt2\n");
      [[fallthrough]];
    case 2:
      return tfhe_rust::EncryptedUInt2Type::get(ctx);
    case 3:
      return tfhe_rust::EncryptedUInt3Type::get(ctx);
    case 4:
      return tfhe_rust::EncryptedUInt4Type::get(ctx);
    case 8:
      return tfhe_rust::EncryptedUInt8Type::get(ctx);
    case 10:
      return tfhe_rust::EncryptedUInt10Type::get(ctx);
    case 12:
      return tfhe_rust::EncryptedUInt12Type::get(ctx);
    case 14:
      return tfhe_rust::EncryptedUInt14Type::get(ctx);
    case 16:
      return tfhe_rust::EncryptedUInt16Type::get(ctx);
    case 32:
      return tfhe_rust::EncryptedUInt32Type::get(ctx);
    case 64:
      return tfhe_rust::EncryptedUInt64Type::get(ctx);
    case 128:
      return tfhe_rust::EncryptedUInt128Type::get(ctx);
    case 256:
      return tfhe_rust::EncryptedUInt256Type::get(ctx);
    default:
      llvm_unreachable("Unsupported bitwidth");
  }
}

} // namespace heir
} // namespace mlir

Expand Down
8 changes: 4 additions & 4 deletions tests/Transforms/tosa_to_boolean_tfhe/hello_world_clean.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ module attributes {tf_saved_model.semantics} {

func.func @main(%arg0: tensor<1x1xi8> {iree.identifier = "serving_default_dense_input:0", tf_saved_model.index_path = ["dense_input"]}) -> (tensor<1x1xi32> {iree.identifier = "StatefulPartitionedCall:0", tf_saved_model.index_path = ["dense_2"]}) attributes {tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tosa.const"() {value = dense<429> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tosa.const"() {value = dense<[[-39, 59, 39, 21, 28, -32, -34, -35, 15, 27, -59, -41, 18, -35, -7, 127]]> : tensor<1x16xi8>} : () -> tensor<1x16xi8>
%2 = "tosa.const"() {value = dense<[-729, 1954, 610, 0, 241, -471, -35, -867, 571, 581, 4260, 3943, 591, 0, -889, -5103]> : tensor<16xi32>} : () -> tensor<16xi32>
%1 = "tosa.const"() {value = dense<[[39, 59, 39, 21, 28, 32, 34, 35, 15, 27, 59, 41, 18, 35, 7, 127]]> : tensor<1x16xi8>} : () -> tensor<1x16xi8>
%2 = "tosa.const"() {value = dense<[729, 1954, 610, 0, 241, 471, 35, 867, 571, 581, 4260, 3943, 591, 0, 889, 5103]> : tensor<16xi32>} : () -> tensor<16xi32>
%3 = "tosa.const"() {value = dense<"0xF41AED091921F424E021EFBCF7F5FA1903DCD20206F9F402FFFAEFF1EFD327E1FB27DDEBDBE4051A17FC241215EF1EE410FE14DA1CF8F3F1EFE2F309E3E9EDE3E415070B041B1AFEEB01DE21E60BEC03230A22241E2703E60324FFC011F8FCF1110CF5E0F30717E5E8EDFADCE823FB07DDFBFD0014261117E7F111EA0226040425211D0ADB1DDC2001FAE3370BF11A16EF1CE703E01602032118092ED9E5140BEA1AFCD81300C4D8ECD9FE0D1920D8D6E21FE9D7CAE2DDC613E7043E000114C7DBE71515F506D61ADC0922FE080213EF191EE209FDF314DDDA20D90FE3F9F7EEE924E629000716E21E0D23D3DDF714FA0822262109080F0BE012F47FDC58E526"> : tensor<16x16xi8>} : () -> tensor<16x16xi8>
%4 = "tosa.const"() {value = dense<[0, 0, -5438, -5515, -1352, -1500, -4152, -84, 3396, 0, 1981, -5581, 0, -6964, 3407, -7217]> : tensor<16xi32>} : () -> tensor<16xi32>
%5 = "tosa.const"() {value = dense<[[-9], [-54], [57], [71], [104], [115], [98], [99], [64], [-26], [127], [25], [-82], [68], [95], [86]]> : tensor<16x1xi8>} : () -> tensor<16x1xi8>
%4 = "tosa.const"() {value = dense<[0, 0, 5438, 5515, 1352, 1500, 4152, 84, 3396, 0, 1981, 5581, 0, 6964, 3407, 7217]> : tensor<16xi32>} : () -> tensor<16xi32>
%5 = "tosa.const"() {value = dense<[[9], [54], [57], [71], [104], [115], [98], [99], [64], [26], [127], [25], [82], [68], [95], [86]]> : tensor<16x1xi8>} : () -> tensor<16x1xi8>
%6 = "tosa.fully_connected"(%arg0, %5, %4) {quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x1xi8>, tensor<16x1xi8>, tensor<16xi32>) -> tensor<1x16xi32>
%9 = "tosa.fully_connected"(%6, %3, %2) {quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x16xi32>, tensor<16x16xi8>, tensor<16xi32>) -> tensor<1x16xi32>
%12 = "tosa.fully_connected"(%9, %1, %0) {quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x16xi32>, tensor<1x16xi8>, tensor<1xi32>) -> tensor<1x1xi32>
Expand Down

0 comments on commit fdcebe5

Please sign in to comment.