Skip to content

Commit

Permalink
Merge pull request #1081 from WoutLegiest:not_packed
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696131357
  • Loading branch information
copybara-github committed Nov 13, 2024
2 parents 5dac435 + 8822094 commit ddfddb4
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 22 deletions.
1 change: 1 addition & 0 deletions lib/Dialect/CGGI/IR/BooleanGates.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ defvar OR_GATE = 2;
defvar NOR_GATE = 3;
defvar XOR_GATE = 4;
defvar XNOR_GATE = 5;
defvar NOT_GATE = 6;
4 changes: 3 additions & 1 deletion lib/Dialect/CGGI/IR/CGGIEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ def BOOL_GATE_OR: I32EnumAttrCase<"OR", OR_GATE>;
def BOOL_GATE_NOR: I32EnumAttrCase<"NOR", NOR_GATE>;
def BOOL_GATE_XOR: I32EnumAttrCase<"XOR", XOR_GATE>;
def BOOL_GATE_XNOR: I32EnumAttrCase<"XNOR", XNOR_GATE>;
def BOOL_GATE_NOT: I32EnumAttrCase<"NOT", NOT_GATE>;

def CGGI_BooleanGateEnumAttr : I32EnumAttr<"CGGIBoolGateEnum",
"An enum attribute representing a CGGI boolean gate using u8 int",
[ BOOL_GATE_AND, BOOL_GATE_NAND, BOOL_GATE_OR,
BOOL_GATE_NOR, BOOL_GATE_XOR, BOOL_GATE_XNOR
BOOL_GATE_NOR, BOOL_GATE_XOR, BOOL_GATE_XNOR,
BOOL_GATE_NOT
]>
{
let cppNamespace = "::mlir::heir::cggi";
Expand Down
4 changes: 3 additions & 1 deletion lib/Dialect/TfheRustBool/IR/TfheRustBoolEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ def BOOL_GATE_OR: I32EnumAttrCase<"OR", OR_GATE>;
def BOOL_GATE_NOR: I32EnumAttrCase<"NOR", NOR_GATE>;
def BOOL_GATE_XOR: I32EnumAttrCase<"XOR", XOR_GATE>;
def BOOL_GATE_XNOR: I32EnumAttrCase<"XNOR", XNOR_GATE>;
def BOOL_GATE_NOT: I32EnumAttrCase<"NOT", NOT_GATE>;

// Enum definition is done in CGGIEnums.td
def TfheRustBool_BooleanGateEnumAttr : I32EnumAttr<"TfheRustBoolGateEnum",
"An enum attribute representing a TFHE-rs boolean gate using u8 int",
[ BOOL_GATE_AND, BOOL_GATE_NAND, BOOL_GATE_OR,
BOOL_GATE_NOR, BOOL_GATE_XOR, BOOL_GATE_XNOR
BOOL_GATE_NOR, BOOL_GATE_XOR, BOOL_GATE_XNOR,
BOOL_GATE_NOT
]>
{
let cppNamespace = "::mlir::heir::tfhe_rust_bool";
Expand Down
28 changes: 17 additions & 11 deletions lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,20 @@ LogicalResult TfheRustBoolEmitter::printSksMethod(
}

emitAssignPrefix(result);
os << variableNames->getNameForValue(sks);

os << variableNames->getNameForValue(sks) << ".packed_gates(\n";
os << "&vec![";
// parse the not gate
if (!gateStr.compare("NOT")) {
os << ".packed_not(\n";
} else {
os << ".packed_gates( \n &vec![";

for (size_t i = 0; i < numberOfOperands; i++) {
os << "Gate::" << gateStr << ", ";
}
for (size_t i = 0; i < numberOfOperands; i++) {
os << "Gate::" << gateStr << ", ";
}

os << "],\n";
os << "],\n";
}

os << commaSeparatedValues(
nonSksOperands, [&, numberOfOperands](Value value) {
Expand All @@ -297,6 +302,7 @@ LogicalResult TfheRustBoolEmitter::printSksMethod(
} else {
os << ");\n";
}

return success();
}
// Check that this translation can only be used by non-tensor operands
Expand Down Expand Up @@ -525,11 +531,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) {
return success();
}

LogicalResult TfheRustBoolEmitter::printOperation(NotOp op) {
return printSksMethod(op.getResult(), op.getServerKey(), {op.getInput()},
"not");
}

LogicalResult TfheRustBoolEmitter::printOperation(AndOp op) {
return printSksMethod(op.getResult(), op.getServerKey(),
{op.getLhs(), op.getRhs()}, "and");
Expand Down Expand Up @@ -586,6 +587,11 @@ LogicalResult TfheRustBoolEmitter::printOperation(PackedOp op) {
return success();
}

LogicalResult TfheRustBoolEmitter::printOperation(NotOp op) {
return printSksMethod(op.getResult(), op.getServerKey(), {op.getInput()},
"not");
}

FailureOr<std::string> TfheRustBoolEmitter::convertType(Type type) {
// Note: these are probably not the right type names to use exactly, and
// they will need to chance to the right values once we try to compile it
Expand Down
12 changes: 12 additions & 0 deletions tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,15 @@ func.func @test_and(%bsks : !bsks, %input1 : !eb, %input2 : !eb) -> !eb {
%out = tfhe_rust_bool.and %bsks, %input1, %input2 : (!bsks, !eb, !eb) -> !eb
return %out : !eb
}

// CHECK-LABEL: pub fn test_not(
// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey,
// CHECK-NEXT: [[input1:v[0-9]+]]: &Ciphertext,
// CHECK-NEXT: ) -> Ciphertext {
// CHECK-NEXT: let [[v0:.*]] = [[bsks]].not([[input1]]);
// CHECK-NEXT: [[v0]]
// CHECK-NEXT: }
func.func @test_not(%bsks : !bsks, %input1 : !eb) -> !eb {
%out = tfhe_rust_bool.not %bsks, %input1 : (!bsks, !eb) -> !eb
return %out : !eb
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: heir-translate %s --emit-tfhe-rust-bool-packed | FileCheck %s

!bsks = !tfhe_rust_bool.server_key
!eb = !tfhe_rust_bool.eb

module{
// CHECK-LABEL: pub fn test_and(
// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey,
// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec<Ciphertext>,
// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec<Ciphertext>,
// CHECK-NEXT: ) -> Vec<Ciphertext> {
// CHECK-NEXT: let [[input1]]_ref = [[input1]].clone();
// CHECK-NEXT: let [[input1]]_ref: Vec<&Ciphertext> = [[input1]].iter().collect();
// CHECK-NEXT: let [[input2]]_ref = [[input2]].clone();
// CHECK-NEXT: let [[input2]]_ref: Vec<&Ciphertext> = [[input2]].iter().collect();
// CHECK-NEXT: let [[v0:.*]] = [[bsks]].packed_gates(
// CHECK-NEXT: &vec![Gate::AND, Gate::AND, Gate::AND, Gate::AND, ],
// CHECK-NEXT: &[[input1]]_ref, &[[input2]]_ref);
// CHECK-NEXT: [[v0]]
// CHECK-NEXT: }
func.func @test_and(%bsks : !bsks, %input1 : tensor<4x!eb>, %input2 : tensor<4x!eb>) -> tensor<4x!eb> {
%out = tfhe_rust_bool.and %bsks, %input1, %input2 : (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb>
return %out : tensor<4x!eb>
}

// CHECK-LABEL: pub fn test_not(
// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey,
// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec<Ciphertext>,
// CHECK-NEXT: ) -> Vec<Ciphertext> {
// CHECK-NEXT: let [[input1]]_ref = [[input1]].clone();
// CHECK-NEXT: let [[input1]]_ref: Vec<&Ciphertext> = [[input1]].iter().collect();
// CHECK-NEXT: let [[v0:.*]] = [[bsks]].packed_not(
// CHECK-NEXT: &[[input1]]_ref);
// CHECK-NEXT: [[v0]]
// CHECK-NEXT: }
func.func @test_not(%bsks : !bsks, %input1 : tensor<4x!eb>) -> tensor<4x!eb>{
%out = tfhe_rust_bool.not %bsks, %input1 : (!bsks, tensor<4x!eb>) -> tensor<4x!eb>
return %out : tensor<4x!eb>
}

}
9 changes: 6 additions & 3 deletions tests/Dialect/TfheRustBool/IR/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@ module {
return
}

// CHECK-LABEL: func @test_packed_and
func.func @test_packed_and(%bsks : !bsks, %lhs : tensor<4x!eb>, %rhs : tensor<4x!eb>) {
%out = tfhe_rust_bool.and %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb>
// CHECK-LABEL: func @test_not
func.func @test_not(%bsks : !bsks) {
%0 = arith.constant 1 : i1

%e1 = tfhe_rust_bool.create_trivial %bsks, %0 : (!bsks, i1) -> !tfhe_rust_bool.eb
%out = tfhe_rust_bool.not %bsks, %e1: (!bsks, !tfhe_rust_bool.eb) -> !tfhe_rust_bool.eb
return
}
}
16 changes: 16 additions & 0 deletions tests/Examples/tfhe_rust_bool/fpga/src/server_key_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub enum ServerKeyEnum {
pub trait ServerKeyTrait {
fn packed_gates(&self, gates: &Vec<Gate>, cts_left: &Vec<&Ciphertext>, cts_right: &Vec<&Ciphertext>) -> Vec<Ciphertext>;
fn not(&self, ct: &Ciphertext) -> Ciphertext;
fn packed_not(&self, cts: &Vec<&Ciphertext>) -> Vec<Ciphertext>;
fn trivial_encrypt(&self, value: bool) -> Ciphertext;
}

Expand All @@ -22,6 +23,10 @@ impl ServerKeyTrait for ServerKey {
return self.not(ct);
}

fn packed_not(&self, cts: &Vec<&Ciphertext>) -> Vec<Ciphertext> {
return self.packed_not(cts);
}

fn trivial_encrypt(&self, value: bool) -> Ciphertext {
return self.trivial_encrypt(value);
}
Expand All @@ -36,6 +41,10 @@ impl ServerKeyTrait for BelfortBooleanServerKey {
return self.not(ct);
}

fn packed_not(&self, cts: &Vec<&Ciphertext>) -> Vec<Ciphertext> {
return self.packed_not(cts);
}

fn trivial_encrypt(&self, value: bool) -> Ciphertext {
return self.trivial_encrypt(value);
}
Expand All @@ -57,6 +66,13 @@ impl ServerKeyTrait for ServerKeyEnum {
}
}

fn packed_not(&self, cts: &Vec<&Ciphertext>) -> Vec<Ciphertext> {
match self {
ServerKeyEnum::TypeSW(sk) => sk.packed_not(cts),
ServerKeyEnum::TypeFPGA(sk) => sk.packed_not(cts),
}
}

fn trivial_encrypt(&self, value: bool) -> Ciphertext {
match self {
ServerKeyEnum::TypeSW(sk) => sk.trivial_encrypt(value),
Expand Down
12 changes: 6 additions & 6 deletions tests/Examples/tfhe_rust_bool/fpga/test_fully_connected.mlir
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
// RUN: heir-opt --tosa-to-boolean-fpga-tfhe="abc-fast=true entry-function=fn_under_test" %s | heir-translate --emit-tfhe-rust-bool-packed > %S/src/fn_under_test.rs
// RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_fully_connected -- 2 | FileCheck %s

// This takes takes the input x and outputs 2 \cdot x + 3.
// This takes takes the input x and outputs a FC layer operation.
// CHECK: 00000111
module attributes {tf_saved_model.semantics} {
func.func @fn_under_test(%11: tensor<1x1xi8>) -> tensor<1x1xi32> {
%0 = "tosa.const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tosa.const"() {value = dense<[[2]]> : tensor<1x1xi8>} : () -> tensor<1x1xi8>
%2 = "tosa.fully_connected"(%11, %1, %0) {quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x1xi8>, tensor<1x1xi8>, tensor<1xi32>) -> tensor<1x1xi32>
return %2 : tensor<1x1xi32>
func.func @fn_under_test(%11: tensor<1x3xi8>) -> tensor<1x3xi32> {
%0 = "tosa.const"() {value = dense<[3, 1, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "tosa.const"() {value = dense<[[2, 7, 1], [8,2,8], [1,8,2]]> : tensor<3x3xi8>} : () -> tensor<3x3xi8>
%2 = "tosa.fully_connected"(%11, %1, %0) {quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x3xi8>, tensor<3x3xi8>, tensor<3xi32>) -> tensor<1x3xi32>
return %2 : tensor<1x3xi32>
}
}

0 comments on commit ddfddb4

Please sign in to comment.