From 8822094040be4ee66dceadacde907cd9363241d5 Mon Sep 17 00:00:00 2001 From: Wouter Legiest Date: Fri, 8 Nov 2024 22:06:59 +0000 Subject: [PATCH] NOT gate emittion in the packed API --- lib/Dialect/CGGI/IR/BooleanGates.td | 1 + lib/Dialect/CGGI/IR/CGGIEnums.td | 4 +- .../TfheRustBool/IR/TfheRustBoolEnums.td | 4 +- .../TfheRustBool/TfheRustBoolEmitter.cpp | 28 ++++++++----- .../Emitters/emit_tfhe_rust_bool.mlir | 12 ++++++ .../Emitters/emit_tfhe_rust_bool_packed.mlir | 41 +++++++++++++++++++ tests/Dialect/TfheRustBool/IR/ops.mlir | 9 ++-- .../fpga/src/server_key_enum.rs | 16 ++++++++ .../fpga/test_fully_connected.mlir | 12 +++--- 9 files changed, 105 insertions(+), 22 deletions(-) create mode 100644 tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool_packed.mlir diff --git a/lib/Dialect/CGGI/IR/BooleanGates.td b/lib/Dialect/CGGI/IR/BooleanGates.td index eb543480f..fca59d333 100644 --- a/lib/Dialect/CGGI/IR/BooleanGates.td +++ b/lib/Dialect/CGGI/IR/BooleanGates.td @@ -4,3 +4,4 @@ defvar OR_GATE = 2; defvar NOR_GATE = 3; defvar XOR_GATE = 4; defvar XNOR_GATE = 5; +defvar NOT_GATE = 6; diff --git a/lib/Dialect/CGGI/IR/CGGIEnums.td b/lib/Dialect/CGGI/IR/CGGIEnums.td index 85780b8bf..34c69462e 100644 --- a/lib/Dialect/CGGI/IR/CGGIEnums.td +++ b/lib/Dialect/CGGI/IR/CGGIEnums.td @@ -17,11 +17,13 @@ def BOOL_GATE_OR: I32EnumAttrCase<"OR", OR_GATE>; def BOOL_GATE_NOR: I32EnumAttrCase<"NOR", NOR_GATE>; def BOOL_GATE_XOR: I32EnumAttrCase<"XOR", XOR_GATE>; def BOOL_GATE_XNOR: I32EnumAttrCase<"XNOR", XNOR_GATE>; +def BOOL_GATE_NOT: I32EnumAttrCase<"NOT", NOT_GATE>; def CGGI_BooleanGateEnumAttr : I32EnumAttr<"CGGIBoolGateEnum", "An enum attribute representing a CGGI boolean gate using u8 int", [ BOOL_GATE_AND, BOOL_GATE_NAND, BOOL_GATE_OR, - BOOL_GATE_NOR, BOOL_GATE_XOR, BOOL_GATE_XNOR + BOOL_GATE_NOR, BOOL_GATE_XOR, BOOL_GATE_XNOR, + BOOL_GATE_NOT ]> { let cppNamespace = "::mlir::heir::cggi"; diff --git a/lib/Dialect/TfheRustBool/IR/TfheRustBoolEnums.td b/lib/Dialect/TfheRustBool/IR/TfheRustBoolEnums.td index 5b61a0b97..55011d613 100644 --- a/lib/Dialect/TfheRustBool/IR/TfheRustBoolEnums.td +++ b/lib/Dialect/TfheRustBool/IR/TfheRustBoolEnums.td @@ -14,12 +14,14 @@ def BOOL_GATE_OR: I32EnumAttrCase<"OR", OR_GATE>; def BOOL_GATE_NOR: I32EnumAttrCase<"NOR", NOR_GATE>; def BOOL_GATE_XOR: I32EnumAttrCase<"XOR", XOR_GATE>; def BOOL_GATE_XNOR: I32EnumAttrCase<"XNOR", XNOR_GATE>; +def BOOL_GATE_NOT: I32EnumAttrCase<"NOT", NOT_GATE>; // Enum definition is done in CGGIEnums.td def TfheRustBool_BooleanGateEnumAttr : I32EnumAttr<"TfheRustBoolGateEnum", "An enum attribute representing a TFHE-rs boolean gate using u8 int", [ BOOL_GATE_AND, BOOL_GATE_NAND, BOOL_GATE_OR, - BOOL_GATE_NOR, BOOL_GATE_XOR, BOOL_GATE_XNOR + BOOL_GATE_NOR, BOOL_GATE_XOR, BOOL_GATE_XNOR, + BOOL_GATE_NOT ]> { let cppNamespace = "::mlir::heir::tfhe_rust_bool"; diff --git a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp index 1a8685838..6dfdce431 100644 --- a/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp +++ b/lib/Target/TfheRustBool/TfheRustBoolEmitter.cpp @@ -268,15 +268,20 @@ LogicalResult TfheRustBoolEmitter::printSksMethod( } emitAssignPrefix(result); + os << variableNames->getNameForValue(sks); - os << variableNames->getNameForValue(sks) << ".packed_gates(\n"; - os << "&vec!["; + // parse the not gate + if (!gateStr.compare("NOT")) { + os << ".packed_not(\n"; + } else { + os << ".packed_gates( \n &vec!["; - for (size_t i = 0; i < numberOfOperands; i++) { - os << "Gate::" << gateStr << ", "; - } + for (size_t i = 0; i < numberOfOperands; i++) { + os << "Gate::" << gateStr << ", "; + } - os << "],\n"; + os << "],\n"; + } os << commaSeparatedValues( nonSksOperands, [&, numberOfOperands](Value value) { @@ -297,6 +302,7 @@ LogicalResult TfheRustBoolEmitter::printSksMethod( } else { os << ");\n"; } + return success(); } // Check that this translation can only be used by non-tensor operands @@ -525,11 +531,6 @@ LogicalResult TfheRustBoolEmitter::printOperation(tensor::FromElementsOp op) { return success(); } -LogicalResult TfheRustBoolEmitter::printOperation(NotOp op) { - return printSksMethod(op.getResult(), op.getServerKey(), {op.getInput()}, - "not"); -} - LogicalResult TfheRustBoolEmitter::printOperation(AndOp op) { return printSksMethod(op.getResult(), op.getServerKey(), {op.getLhs(), op.getRhs()}, "and"); @@ -586,6 +587,11 @@ LogicalResult TfheRustBoolEmitter::printOperation(PackedOp op) { return success(); } +LogicalResult TfheRustBoolEmitter::printOperation(NotOp op) { + return printSksMethod(op.getResult(), op.getServerKey(), {op.getInput()}, + "not"); +} + FailureOr TfheRustBoolEmitter::convertType(Type type) { // Note: these are probably not the right type names to use exactly, and // they will need to chance to the right values once we try to compile it diff --git a/tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool.mlir b/tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool.mlir index 1a96d3b25..66055429a 100644 --- a/tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool.mlir +++ b/tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool.mlir @@ -15,3 +15,15 @@ func.func @test_and(%bsks : !bsks, %input1 : !eb, %input2 : !eb) -> !eb { %out = tfhe_rust_bool.and %bsks, %input1, %input2 : (!bsks, !eb, !eb) -> !eb return %out : !eb } + +// CHECK-LABEL: pub fn test_not( +// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, +// CHECK-NEXT: [[input1:v[0-9]+]]: &Ciphertext, +// CHECK-NEXT: ) -> Ciphertext { +// CHECK-NEXT: let [[v0:.*]] = [[bsks]].not([[input1]]); +// CHECK-NEXT: [[v0]] +// CHECK-NEXT: } +func.func @test_not(%bsks : !bsks, %input1 : !eb) -> !eb { + %out = tfhe_rust_bool.not %bsks, %input1 : (!bsks, !eb) -> !eb + return %out : !eb +} diff --git a/tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool_packed.mlir b/tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool_packed.mlir new file mode 100644 index 000000000..ac9af75a9 --- /dev/null +++ b/tests/Dialect/TfheRustBool/Emitters/emit_tfhe_rust_bool_packed.mlir @@ -0,0 +1,41 @@ +// RUN: heir-translate %s --emit-tfhe-rust-bool-packed | FileCheck %s + +!bsks = !tfhe_rust_bool.server_key +!eb = !tfhe_rust_bool.eb + +module{ +// CHECK-LABEL: pub fn test_and( +// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, +// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, +// CHECK-NEXT: [[input2:v[0-9]+]]: &Vec, +// CHECK-NEXT: ) -> Vec { +// CHECK-NEXT: let [[input1]]_ref = [[input1]].clone(); +// CHECK-NEXT: let [[input1]]_ref: Vec<&Ciphertext> = [[input1]].iter().collect(); +// CHECK-NEXT: let [[input2]]_ref = [[input2]].clone(); +// CHECK-NEXT: let [[input2]]_ref: Vec<&Ciphertext> = [[input2]].iter().collect(); +// CHECK-NEXT: let [[v0:.*]] = [[bsks]].packed_gates( +// CHECK-NEXT: &vec![Gate::AND, Gate::AND, Gate::AND, Gate::AND, ], +// CHECK-NEXT: &[[input1]]_ref, &[[input2]]_ref); +// CHECK-NEXT: [[v0]] +// CHECK-NEXT: } +func.func @test_and(%bsks : !bsks, %input1 : tensor<4x!eb>, %input2 : tensor<4x!eb>) -> tensor<4x!eb> { + %out = tfhe_rust_bool.and %bsks, %input1, %input2 : (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb> + return %out : tensor<4x!eb> +} + +// CHECK-LABEL: pub fn test_not( +// CHECK-NEXT: [[bsks:v[0-9]+]]: &ServerKey, +// CHECK-NEXT: [[input1:v[0-9]+]]: &Vec, +// CHECK-NEXT: ) -> Vec { +// CHECK-NEXT: let [[input1]]_ref = [[input1]].clone(); +// CHECK-NEXT: let [[input1]]_ref: Vec<&Ciphertext> = [[input1]].iter().collect(); +// CHECK-NEXT: let [[v0:.*]] = [[bsks]].packed_not( +// CHECK-NEXT: &[[input1]]_ref); +// CHECK-NEXT: [[v0]] +// CHECK-NEXT: } +func.func @test_not(%bsks : !bsks, %input1 : tensor<4x!eb>) -> tensor<4x!eb>{ + %out = tfhe_rust_bool.not %bsks, %input1 : (!bsks, tensor<4x!eb>) -> tensor<4x!eb> + return %out : tensor<4x!eb> +} + +} diff --git a/tests/Dialect/TfheRustBool/IR/ops.mlir b/tests/Dialect/TfheRustBool/IR/ops.mlir index b5ca17d68..7c9f6ca1f 100644 --- a/tests/Dialect/TfheRustBool/IR/ops.mlir +++ b/tests/Dialect/TfheRustBool/IR/ops.mlir @@ -27,9 +27,12 @@ module { return } - // CHECK-LABEL: func @test_packed_and - func.func @test_packed_and(%bsks : !bsks, %lhs : tensor<4x!eb>, %rhs : tensor<4x!eb>) { - %out = tfhe_rust_bool.and %bsks, %lhs, %rhs: (!bsks, tensor<4x!eb>, tensor<4x!eb>) -> tensor<4x!eb> + // CHECK-LABEL: func @test_not + func.func @test_not(%bsks : !bsks) { + %0 = arith.constant 1 : i1 + + %e1 = tfhe_rust_bool.create_trivial %bsks, %0 : (!bsks, i1) -> !tfhe_rust_bool.eb + %out = tfhe_rust_bool.not %bsks, %e1: (!bsks, !tfhe_rust_bool.eb) -> !tfhe_rust_bool.eb return } } diff --git a/tests/Examples/tfhe_rust_bool/fpga/src/server_key_enum.rs b/tests/Examples/tfhe_rust_bool/fpga/src/server_key_enum.rs index e8fa3efa8..66156f78c 100644 --- a/tests/Examples/tfhe_rust_bool/fpga/src/server_key_enum.rs +++ b/tests/Examples/tfhe_rust_bool/fpga/src/server_key_enum.rs @@ -9,6 +9,7 @@ pub enum ServerKeyEnum { pub trait ServerKeyTrait { fn packed_gates(&self, gates: &Vec, cts_left: &Vec<&Ciphertext>, cts_right: &Vec<&Ciphertext>) -> Vec; fn not(&self, ct: &Ciphertext) -> Ciphertext; + fn packed_not(&self, cts: &Vec<&Ciphertext>) -> Vec; fn trivial_encrypt(&self, value: bool) -> Ciphertext; } @@ -22,6 +23,10 @@ impl ServerKeyTrait for ServerKey { return self.not(ct); } + fn packed_not(&self, cts: &Vec<&Ciphertext>) -> Vec { + return self.packed_not(cts); + } + fn trivial_encrypt(&self, value: bool) -> Ciphertext { return self.trivial_encrypt(value); } @@ -36,6 +41,10 @@ impl ServerKeyTrait for BelfortBooleanServerKey { return self.not(ct); } + fn packed_not(&self, cts: &Vec<&Ciphertext>) -> Vec { + return self.packed_not(cts); + } + fn trivial_encrypt(&self, value: bool) -> Ciphertext { return self.trivial_encrypt(value); } @@ -57,6 +66,13 @@ impl ServerKeyTrait for ServerKeyEnum { } } + fn packed_not(&self, cts: &Vec<&Ciphertext>) -> Vec { + match self { + ServerKeyEnum::TypeSW(sk) => sk.packed_not(cts), + ServerKeyEnum::TypeFPGA(sk) => sk.packed_not(cts), + } + } + fn trivial_encrypt(&self, value: bool) -> Ciphertext { match self { ServerKeyEnum::TypeSW(sk) => sk.trivial_encrypt(value), diff --git a/tests/Examples/tfhe_rust_bool/fpga/test_fully_connected.mlir b/tests/Examples/tfhe_rust_bool/fpga/test_fully_connected.mlir index 22c7bf47a..dd9341d1f 100644 --- a/tests/Examples/tfhe_rust_bool/fpga/test_fully_connected.mlir +++ b/tests/Examples/tfhe_rust_bool/fpga/test_fully_connected.mlir @@ -1,13 +1,13 @@ // RUN: heir-opt --tosa-to-boolean-fpga-tfhe="abc-fast=true entry-function=fn_under_test" %s | heir-translate --emit-tfhe-rust-bool-packed > %S/src/fn_under_test.rs // RUN: cargo run --release --manifest-path %S/Cargo.toml --bin main_fully_connected -- 2 | FileCheck %s -// This takes takes the input x and outputs 2 \cdot x + 3. +// This takes takes the input x and outputs a FC layer operation. // CHECK: 00000111 module attributes {tf_saved_model.semantics} { - func.func @fn_under_test(%11: tensor<1x1xi8>) -> tensor<1x1xi32> { - %0 = "tosa.const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> - %1 = "tosa.const"() {value = dense<[[2]]> : tensor<1x1xi8>} : () -> tensor<1x1xi8> - %2 = "tosa.fully_connected"(%11, %1, %0) {quantization_info = #tosa.conv_quant} : (tensor<1x1xi8>, tensor<1x1xi8>, tensor<1xi32>) -> tensor<1x1xi32> - return %2 : tensor<1x1xi32> + func.func @fn_under_test(%11: tensor<1x3xi8>) -> tensor<1x3xi32> { + %0 = "tosa.const"() {value = dense<[3, 1, 4]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tosa.const"() {value = dense<[[2, 7, 1], [8,2,8], [1,8,2]]> : tensor<3x3xi8>} : () -> tensor<3x3xi8> + %2 = "tosa.fully_connected"(%11, %1, %0) {quantization_info = #tosa.conv_quant} : (tensor<1x3xi8>, tensor<3x3xi8>, tensor<3xi32>) -> tensor<1x3xi32> + return %2 : tensor<1x3xi32> } }