Skip to content

Commit

Permalink
Tfhe-rs Ops with the scalar definition
Browse files Browse the repository at this point in the history
  • Loading branch information
WoutLegiest committed Jan 28, 2025
1 parent 1f3cf91 commit 203384d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 45 deletions.
36 changes: 4 additions & 32 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.h"

#include <cstdint>
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/LogicalResult.h>

#include "lib/Dialect/CGGI/IR/CGGIDialect.h"
#include "lib/Dialect/CGGI/IR/CGGIOps.h"
Expand Down Expand Up @@ -48,35 +49,6 @@ class ArithToCGGITypeConverter : public TypeConverter {
}
};

struct ConvertConstantOp : public OpConversionPattern<mlir::arith::ConstantOp> {
ConvertConstantOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::ConstantOp>(context) {}

using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(
mlir::arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (isa<IndexType>(op.getValue().getType())) {
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto intValue = cast<IntegerAttr>(op.getValue()).getValue().getSExtValue();
auto inputValue = mlir::IntegerAttr::get(op.getType(), intValue);

auto encoding = lwe::UnspecifiedBitFieldEncodingAttr::get(
op->getContext(), op.getValue().getType().getIntOrFloatBitWidth());
auto lweType = lwe::LWECiphertextType::get(op->getContext(), encoding,
lwe::LWEParamsAttr());

auto encrypt = b.create<cggi::CreateTrivialOp>(lweType, inputValue);

rewriter.replaceOp(op, encrypt);
return success();
}
};

struct ConvertTruncIOp : public OpConversionPattern<mlir::arith::TruncIOp> {
ConvertTruncIOp(mlir::MLIRContext *context)
: OpConversionPattern<mlir::arith::TruncIOp>(context) {}
Expand Down Expand Up @@ -301,8 +273,8 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
});

patterns.add<
ConvertConstantOp, ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp,
ConvertShRUIOp, ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertTruncIOp, ConvertExtUIOp, ConvertExtSIOp, ConvertShRUIOp,
ConvertArithBinOp<mlir::arith::AddIOp, cggi::AddOp>,
ConvertArithBinOp<mlir::arith::MulIOp, cggi::MulOp>,
ConvertArithBinOp<mlir::arith::SubIOp, cggi::SubOp>,
ConvertAny<memref::LoadOp>, ConvertAny<memref::AllocOp>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ struct ConvertLut3Op : public OpConversionPattern<cggi::Lut3Op> {
serverKey, adaptor.getB(), b.getIndexAttr(1));
auto outputType =
getTypeConverter()->convertType(shiftedB.getResult().getType());
auto summedBC =
b.create<tfhe_rust::AddOp>(outputType, serverKey, shiftedC, shiftedB);
auto summedBC = b.create<tfhe_rust::AddOp>(adaptor.getB().getType(),
serverKey, shiftedC, shiftedB);
auto summedABC = b.create<tfhe_rust::AddOp>(outputType, serverKey, summedBC,
adaptor.getA());

Expand Down
39 changes: 29 additions & 10 deletions lib/Target/TfheRustHL/TfheRustHLEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,12 @@ LogicalResult TfheRustHLEmitter::printOperation(CreateTrivialOp op) {

os << "FheUint" << getTfheRustBitWidth(op.getResult().getType())
<< "::try_encrypt_trivial("
<< variableNames->getNameForValue(op.getValue()) << ").unwrap();\n";
<< variableNames->getNameForValue(op.getValue());

if (op.getValue().getType().isSigned())
os << " as u" << getTfheRustBitWidth(op.getResult().getType());

os << ").unwrap();\n";
return success();
}

Expand Down Expand Up @@ -359,7 +364,12 @@ LogicalResult TfheRustHLEmitter::printOperation(arith::ConstantOp op) {
// By default, it emits an unsigned integer.
emitAssignPrefix(op.getResult());
if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
os << intAttr.getValue().abs() << "u64;\n";
os << intAttr.getValue().abs();
if (isa<IndexType>(op.getType())) {
os << "usize;\n";
} else {
os << convertType(op.getType()) << ";\n";
}
} else {
return op.emitError() << "Unknown constant type " << valueAttr.getType();
}
Expand All @@ -383,6 +393,15 @@ LogicalResult TfheRustHLEmitter::printBinaryOp(::mlir::Value result,
std::string_view op) {
emitAssignPrefix(result);

if (auto cteOp = dyn_cast<mlir::arith::ConstantOp>(rhs.getDefiningOp())) {
auto intValue =
cast<IntegerAttr>(cteOp.getValue()).getValue().getZExtValue();
os << checkOrigin(lhs) << variableNames->getNameForValue(lhs) << " " << op
<< " " << intValue << "u" << cteOp.getType().getIntOrFloatBitWidth()
<< ";\n";
return success();
}

os << checkOrigin(lhs) << variableNames->getNameForValue(lhs) << " " << op
<< " " << checkOrigin(rhs) << variableNames->getNameForValue(rhs) << ";\n";
return success();
Expand Down Expand Up @@ -430,8 +449,8 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::AllocOp op) {
if (failed(emitType(op.getMemref().getType().getElementType()))) {
return op.emitOpError() << "Failed to get memref element type";
}

os << "> = BTreeMap::new();\n";

return success();
}

Expand Down Expand Up @@ -463,12 +482,11 @@ LogicalResult TfheRustHLEmitter::printOperation(memref::LoadOp op) {
// We assume here that the indices are SSA values (not integer attributes).
if (isa<BlockArgument>(op.getMemref())) {
emitAssignPrefix(op.getResult());
os << "&" << variableNames->getNameForValue(op.getMemRef()) << "["
<< flattenIndexExpression(op.getMemRefType(), op.getIndices(),
[&](Value value) {
return variableNames->getNameForValue(value);
})
<< "];\n";
os << "&" << variableNames->getNameForValue(op.getMemRef());
for (auto value : op.getIndices()) {
os << "[" << variableNames->getNameForValue(value) << "]";
}
os << ";\n";
return success();
}

Expand Down Expand Up @@ -586,6 +604,7 @@ LogicalResult TfheRustHLEmitter::printOperation(tensor::InsertOp op) {
return std::string(prefix) + variableNames->getNameForValue(value) +
cloneStr;
}) << "];\n";

return success();
}

Expand Down Expand Up @@ -662,7 +681,7 @@ FailureOr<std::string> TfheRustHLEmitter::convertType(Type type) {
}
auto width = getRustIntegerType(type.getWidth());
if (failed(width)) return failure();
return (type.isUnsigned() ? std::string("u") : "") + "i" +
return (type.isSigned() ? std::string("i") : std::string("u")) +
std::to_string(width.value());
})
.Case<LookupTableType>(
Expand Down
2 changes: 2 additions & 0 deletions lib/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

#define DEBUG_TYPE "cggi-to-tfhe-rust"

namespace mlir {
namespace heir {

Expand Down
2 changes: 1 addition & 1 deletion tests/Examples/tfhe_rust_hl/cpu/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ if you overrode the default option when installing Cargo.

```bash
bazel query "filter('.mlir.test$', //tests/Examples/tfhe_rust_hl/cpu/...)" \
| xargs bazel test --noincompatible_strict_action_env -test_timeout=180 --sandbox_writable_path=$HOME/.cargo "$@"
| xargs bazel test --noincompatible_strict_action_env --test_timeout=180 --sandbox_writable_path=$HOME/.cargo "$@"
```

The `manual` tag is added to the targets in this directory to ensure that they
Expand Down

0 comments on commit 203384d

Please sign in to comment.