Skip to content

Commit

Permalink
emit-tfhe-rust: support tensors of lwe ciphertexts
Browse files Browse the repository at this point in the history
Signed-off-by: Asra <asraa@google.com>
  • Loading branch information
asraa committed Dec 14, 2023
1 parent afbfe51 commit 9aa32dd
Show file tree
Hide file tree
Showing 12 changed files with 255 additions and 50 deletions.
1 change: 1 addition & 0 deletions include/Analysis/SelectVariableNames/SelectVariableNames.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class SelectVariableNames {
/// value was not assigned a name (suggesting the value was not in the IR
/// tree that this class was constructed with).
std::string getNameForValue(Value value) const {
assert(variableNames.contains(value));
return variableNames.lookup(value);
}

Expand Down
29 changes: 17 additions & 12 deletions include/Target/TfheRust/TfheRustEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@
#include "include/Analysis/SelectVariableNames/SelectVariableNames.h"
#include "include/Dialect/TfheRust/IR/TfheRustDialect.h"
#include "include/Dialect/TfheRust/IR/TfheRustOps.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand Down Expand Up @@ -49,14 +50,18 @@ class TfheRustEmitter {
LogicalResult printOperation(::mlir::func::FuncOp op);
LogicalResult printOperation(::mlir::func::ReturnOp op);
LogicalResult printOperation(AddOp op);
LogicalResult printOperation(CreateTrivialOp op);
LogicalResult printOperation(tensor::ExtractOp op);
LogicalResult printOperation(tensor::FromElementsOp op);
LogicalResult printOperation(ApplyLookupTableOp op);
LogicalResult printOperation(GenerateLookupTableOp op);
LogicalResult printOperation(ScalarLeftShiftOp op);

// Helpers for above
LogicalResult printSksMethod(::mlir::Value result, ::mlir::Value sks,
::mlir::ValueRange nonSksOperands,
std::string_view op);
std::string_view op,
SmallVector<std::string> operandTypes = {});

// Emit a TfheRust type
LogicalResult emitType(Type type);
Expand Down
1 change: 1 addition & 0 deletions lib/Analysis/SelectVariableNames/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TensorDialect",
],
)
15 changes: 8 additions & 7 deletions lib/Analysis/SelectVariableNames/SelectVariableNames.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@

#include <string>

#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

namespace mlir {
namespace heir {

SelectVariableNames::SelectVariableNames(Operation *op) {
int i = 0;
std::string prefix = "v";
op->walk([&](Operation *op) {
op->walk<WalkOrder::PreOrder>([&](Operation *op) {
return llvm::TypeSwitch<Operation &, WalkResult>(*op)
// Function arguments need names
.Case<func::FuncOp>([&](auto op) {
Expand Down
1 change: 1 addition & 0 deletions lib/Target/TfheRust/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TranslateLib",
],
)
118 changes: 89 additions & 29 deletions lib/Target/TfheRust/TfheRustEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,22 @@
#include "include/Dialect/TfheRust/IR/TfheRustOps.h"
#include "include/Dialect/TfheRust/IR/TfheRustTypes.h"
#include "lib/Target/TfheRust/TfheRustTemplates.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project

namespace mlir {
Expand All @@ -45,7 +46,7 @@ void registerToTfheRustTranslation() {
},
[](DialectRegistry &registry) {
registry.insert<func::FuncDialect, tfhe_rust::TfheRustDialect,
arith::ArithDialect>();
arith::ArithDialect, tensor::TensorDialect>();
});
}

Expand All @@ -68,7 +69,12 @@ LogicalResult TfheRustEmitter::translate(Operation &op) {
.Case<arith::ConstantOp>([&](auto op) { return printOperation(op); })
// TfheRust ops
.Case<AddOp, ApplyLookupTableOp, GenerateLookupTableOp,
ScalarLeftShiftOp>([&](auto op) { return printOperation(op); })
ScalarLeftShiftOp, CreateTrivialOp>(
[&](auto op) { return printOperation(op); })
// Tensor ops
.Case<tensor::ExtractOp, tensor::FromElementsOp>(
[&](auto op) { return printOperation(op); })

.Default([&](Operation &) {
return op.emitOpError("unable to find printer for op");
});
Expand Down Expand Up @@ -161,30 +167,39 @@ LogicalResult TfheRustEmitter::printOperation(func::FuncOp funcOp) {
}

LogicalResult TfheRustEmitter::printOperation(func::ReturnOp op) {
std::function<std::string(Value)> valueOrClonedValue = [&](Value value) {
auto cloneStr = "";
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
}
return variableNames->getNameForValue(value) + cloneStr;
};

if (op.getNumOperands() == 1) {
os << variableNames->getNameForValue(op.getOperands()[0]) << "\n";
os << valueOrClonedValue(op.getOperands()[0]) << "\n";
return success();
}

os << "(" << commaSeparatedValues(op.getOperands(), [&](Value value) {
return variableNames->getNameForValue(value);
}) << ")\n";
os << "(" << commaSeparatedValues(op.getOperands(), valueOrClonedValue)
<< ")\n";
return success();
}

void TfheRustEmitter::emitAssignPrefix(Value result) {
os << "let " << variableNames->getNameForValue(result) << " = ";
}

LogicalResult TfheRustEmitter::printSksMethod(::mlir::Value result,
::mlir::Value sks,
::mlir::ValueRange nonSksOperands,
std::string_view op) {
LogicalResult TfheRustEmitter::printSksMethod(
::mlir::Value result, ::mlir::Value sks, ::mlir::ValueRange nonSksOperands,
std::string_view op, SmallVector<std::string> operandTypes) {
emitAssignPrefix(result);

auto operandTypesIt = operandTypes.begin();
os << variableNames->getNameForValue(sks) << "." << op << "(";
os << commaSeparatedValues(nonSksOperands, [&](Value value) {
const auto *prefix = value.getType().hasTrait<PassByReference>() ? "&" : "";
return prefix + variableNames->getNameForValue(value);
return prefix + variableNames->getNameForValue(value) +
(!operandTypes.empty() ? " as " + *operandTypesIt++ : "");
});
os << ");\n";
return success();
Expand All @@ -203,12 +218,12 @@ LogicalResult TfheRustEmitter::printOperation(ApplyLookupTableOp op) {

LogicalResult TfheRustEmitter::printOperation(GenerateLookupTableOp op) {
auto sks = op.getServerKey();
APInt truthTable = op.getTruthTable().getValue();
uint64_t truthTable = op.getTruthTable().getUInt();
auto result = op.getResult();

emitAssignPrefix(result);
os << variableNames->getNameForValue(sks) << ".generate_lookup_table(";
os << "|x| (" << truthTable << " >> x) & 1";
os << "|x| (" << std::to_string(truthTable) << " >> x) & 1";
os << ");\n";
return success();
}
Expand All @@ -219,9 +234,23 @@ LogicalResult TfheRustEmitter::printOperation(ScalarLeftShiftOp op) {
"scalar_left_shift");
}

LogicalResult TfheRustEmitter::printOperation(CreateTrivialOp op) {
return printSksMethod(op.getResult(), op.getServerKey(), {op.getValue()},
"create_trivial", {"u64"});
}

LogicalResult TfheRustEmitter::printOperation(arith::ConstantOp op) {
emitAssignPrefix(op.getResult());
auto valueAttr = op.getValue();
if (isa<IntegerType>(op.getType()) &&
op.getType().getIntOrFloatBitWidth() == 1) {
os << "let " << variableNames->getNameForValue(op.getResult())
<< " : bool = ";
os << (cast<IntegerAttr>(valueAttr).getValue().isZero() ? "false" : "true")
<< ";\n";
return success();
}

emitAssignPrefix(op.getResult());
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
os << intAttr.getValue() << ";\n";
} else {
Expand All @@ -230,10 +259,41 @@ LogicalResult TfheRustEmitter::printOperation(arith::ConstantOp op) {
return success();
}

LogicalResult TfheRustEmitter::printOperation(tensor::ExtractOp op) {
// We assume here that the indices are SSA values (not integer attributes).
emitAssignPrefix(op.getResult());
os << "&" << variableNames->getNameForValue(op.getTensor()) << "["
<< commaSeparatedValues(
op.getIndices(),
[&](Value value) { return variableNames->getNameForValue(value); })
<< "];\n";
return success();
}

LogicalResult TfheRustEmitter::printOperation(tensor::FromElementsOp op) {
emitAssignPrefix(op.getResult());
os << "vec![" << commaSeparatedValues(op.getOperands(), [&](Value value) {
// Check if block argument, if so, clone.
auto cloneStr = "";
if (isa<BlockArgument>(value)) {
cloneStr = ".clone()";
}
return variableNames->getNameForValue(value) + cloneStr;
}) << "];\n";
return success();
}

FailureOr<std::string> TfheRustEmitter::convertType(Type type) {
// Note: these are probably not the right type names to use exactly, and they
// will need to chance to the right values once we try to compile it against
// a specific API version.
if (auto shapedType = dyn_cast<ShapedType>(type)) {
// A lambda in a type switch statement can't return multiple types.
// FIXME: why can't both types be FailureOr<std::string>?
auto elementTy = convertType(shapedType.getElementType());
if (failed(elementTy)) return failure();
return std::string("Vec<" + elementTy.value() + ">");
}
return llvm::TypeSwitch<Type &, FailureOr<std::string>>(type)
.Case<EncryptedUInt3Type>(
[&](auto type) { return std::string("Ciphertext"); })
Expand Down
2 changes: 1 addition & 1 deletion tests/tfhe_rust/emit_tfhe_rust.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func.func @test_apply_lookup_table2(%sks : !sks, %lut: !lut, %input : !eui3) ->
// CHECK-LABEL: pub fn test_return_multiple_values(
// CHECK-NEXT: [[input:v[0-9]+]]: &Ciphertext,
// CHECK-NEXT: ) -> (Ciphertext, Ciphertext) {
// CHECK-NEXT: ([[input]], [[input]])
// CHECK-NEXT: ([[input]].clone(), [[input]].clone())
// CHECK-NEXT: }
func.func @test_return_multiple_values(%input : !eui3) -> (!eui3, !eui3) {
return %input, %input : !eui3, !eui3
Expand Down
1 change: 1 addition & 0 deletions tests/tfhe_rust/end_to_end/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ glob_lit_tests(
data = [
"Cargo.toml",
"src/main.rs",
"src/main_add_one.rs",
"@heir//tests:test_utilities",
],
default_tags = [
Expand Down
4 changes: 4 additions & 0 deletions tests/tfhe_rust/end_to_end/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ tfhe = { version = "0.4.1", features = ["boolean", "shortint", "x86_64-unix"] }
[[bin]]
name = "main"
path = "src/main.rs"

[[bin]]
name = "main_add_one"
path = "src/main_add_one.rs"
52 changes: 52 additions & 0 deletions tests/tfhe_rust/end_to_end/src/main_add_one.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use clap::Parser;
use tfhe::shortint::*;
use tfhe::shortint::parameters::get_parameters_from_message_and_carry;

mod fn_under_test;

// TODO(https://github.com/google/heir/issues/235): improve generality
#[derive(Parser, Debug)]
struct Args {
#[arg(id = "message_bits", long)]
message_bits: usize,

#[arg(id = "carry_bits", long, default_value = "2")]
carry_bits: usize,

/// arguments to forward to function under test
#[arg(id = "input_1", index = 1)]
input1: u8,
}

// Encrypt a u8
pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec<Ciphertext> {
(0..8)
.map(|shift| {
let bit = (value >> shift) & 1;
client_key.encrypt(if bit != 0 { 1 } else { 0 })
})
.collect()
}

// Decrypt a u8
pub fn decrypt(ciphertexts: &[Ciphertext], client_key: &ClientKey) -> u8 {
let mut accum = 0u8;
for (i, ct) in ciphertexts.iter().enumerate() {
let bit = client_key.decrypt(ct);
accum |= (bit as u8) << i;
}
accum
}

fn main() {
let flags = Args::parse();
let parameters = get_parameters_from_message_and_carry((1 << flags.message_bits) - 1, flags.carry_bits);
let (client_key, server_key) = tfhe::shortint::gen_keys(parameters);

let ct_1 = encrypt(flags.input1.into(), &client_key);

let result = fn_under_test::fn_under_test(&server_key, &ct_1);
let output = decrypt(&result, &client_key);

println!("{:?}", output);
}
Loading

0 comments on commit 9aa32dd

Please sign in to comment.