diff --git a/include/Dialect/Secret/IR/SecretOps.td b/include/Dialect/Secret/IR/SecretOps.td index 6d7ccedcff..f366d29635 100644 --- a/include/Dialect/Secret/IR/SecretOps.td +++ b/include/Dialect/Secret/IR/SecretOps.td @@ -25,7 +25,7 @@ def Secret_ConcealOp : Secret_Op<"conceal", [Pure]> { Examples: - ``` + ```mlir %Y = secret.conceal %value : i32 -> !secret.secret ``` }]; @@ -55,7 +55,7 @@ def Secret_RevealOp : Secret_Op<"reveal", [Pure]> { Examples: - ``` + ```mlir %Y = secret.reveal %secret_value : !secret.secret -> i32 ``` }]; @@ -104,7 +104,7 @@ def Secret_GenericOp : Secret_Op<"generic", [ Add two secret integers together - ``` + ```mlir %Z = secret.generic ins(%X, %Y : !secret.secret, !secret.secret) { ^bb0(%x: i32, %y: i32) : %z = arith.addi %x, %y: i32 @@ -115,7 +115,7 @@ def Secret_GenericOp : Secret_Op<"generic", [ Add a secret value with a plaintext value. I.e., not all arguments to the op need be secret. - ``` + ```mlir %Z = secret.generic ins(%X, %Y : i32, !secret.secret) { ^bb0(%x: i32, %y: i32) : %z = arith.addi %x, %y: i32 @@ -154,5 +154,35 @@ def Secret_GenericOp : Secret_Op<"generic", [ let hasVerifier = 1; } +def CastOp : Secret_Op<"cast", [Pure]> { + let summary = "A placeholder cast from one secret type to another"; + let description = [{ + A `cast` operation represents a type cast from one secret type to another, + that is used to enable the intermixing of various equivalent secret types + before a lower-level FHE scheme has been chosen. + + For example, `secret.cast` can be used to convert a `secret` to a + `secret>` as a compatibility layer between boolean and + non-boolean parts of a program. The pass that later lowers the IR to + specific FHE schemes would need to replace these casts with appropriate + scheme-specific operations, and it is left to those later passes to + determine which casts are considered valid. + + Example: + + ```mlir + %result = secret.cast %0 : !secret.secret to !secret.secret> + %result2 = secret.cast %0 : !secret.secret to !secret.secret> + ``` + }]; + + let arguments = (ins Secret:$input); + let results = (outs Secret:$output); + let assemblyFormat = [{ + $input attr-dict `:` qualified(type($input)) `to` qualified(type($output)) + }]; + let hasFolder = 1; +} + #endif // HEIR_INCLUDE_DIALECT_SECRET_IR_SECRETOPS_TD_ diff --git a/include/Target/Verilog/BUILD b/include/Target/Verilog/BUILD index e8654d7e0c..39be9e0f8f 100644 --- a/include/Target/Verilog/BUILD +++ b/include/Target/Verilog/BUILD @@ -12,6 +12,7 @@ cc_library( name = "verilog_emitter", hdrs = ["VerilogEmitter.h"], deps = [ + "@heir//lib/Dialect/Secret/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", diff --git a/include/Target/Verilog/VerilogEmitter.h b/include/Target/Verilog/VerilogEmitter.h index ee689860ce..75461940aa 100644 --- a/include/Target/Verilog/VerilogEmitter.h +++ b/include/Target/Verilog/VerilogEmitter.h @@ -1,8 +1,16 @@ #ifndef HEIR_INCLUDE_TARGET_VERILOG_VERILOGEMITTER_H_ #define HEIR_INCLUDE_TARGET_VERILOG_VERILOGEMITTER_H_ -#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include +#include +#include +#include + +#include "include/Dialect/Secret/IR/SecretOps.h" +#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project +#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project +#include "llvm/include/llvm/ADT/ilist.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -10,7 +18,12 @@ #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/Operation.h" // from @llvm-project +#include "mlir/include/mlir/IR/Region.h" // from @llvm-project +#include "mlir/include/mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/Support/IndentedOstream.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { @@ -22,11 +35,26 @@ void registerToVerilogTranslation(); mlir::LogicalResult translateToVerilog(mlir::Operation *op, llvm::raw_ostream &os); +/// Translates the given operation to Verilog with a fixed input name for the +/// resulting verilog module. Raises an error if the input IR contains secret +/// ops. +mlir::LogicalResult translateToVerilog( + mlir::Operation *op, llvm::raw_ostream &os, + std::optional moduleName); + +/// Translates the given operation to Verilog with a fixed input name for the +/// resulting verilog module. If allowSecretOps is false, raises an error if +/// the input IR contains secret ops. +mlir::LogicalResult translateToVerilog( + mlir::Operation *op, llvm::raw_ostream &os, + std::optional moduleName, bool allowSecretOps); + class VerilogEmitter { public: VerilogEmitter(raw_ostream &os); - LogicalResult translate(mlir::Operation &operation); + LogicalResult translate(mlir::Operation &operation, + std::optional moduleName); private: /// Output stream to emit to. @@ -38,8 +66,26 @@ class VerilogEmitter { // Globally unique identifiers for values int64_t value_count_; + // A helper to generalize the work of emitting a func.func and a + // secret.generic + LogicalResult printFunctionLikeOp(Operation *op, + llvm::StringRef verilogModuleName, + ArrayRef arguments, + TypeRange resultTypes, + Region::BlockListType::iterator blocksBegin, + Region::BlockListType::iterator blocksEnd); + + // A helper to generalize the work of emitting a func.return and a + // secret.yield + LogicalResult printReturnLikeOp(Value returnValue); + // Functions for printing individual ops - LogicalResult printOperation(mlir::ModuleOp op); + LogicalResult printOperation(mlir::ModuleOp op, + std::optional moduleName); + LogicalResult printOperation(mlir::func::FuncOp op, + std::optional moduleName); + LogicalResult printOperation(mlir::heir::secret::GenericOp op, + std::optional moduleName); LogicalResult printOperation(mlir::UnrealizedConversionCastOp op); LogicalResult printOperation(mlir::arith::AddIOp op); LogicalResult printOperation(mlir::arith::AndIOp op); @@ -59,8 +105,8 @@ class VerilogEmitter { LogicalResult printOperation(mlir::affine::AffineLoadOp op); LogicalResult printOperation(mlir::affine::AffineStoreOp op); LogicalResult printOperation(mlir::func::CallOp op); - LogicalResult printOperation(mlir::func::FuncOp op); LogicalResult printOperation(mlir::func::ReturnOp op); + LogicalResult printOperation(mlir::heir::secret::YieldOp op); LogicalResult printOperation(mlir::math::CountLeadingZerosOp op); LogicalResult printOperation(mlir::memref::LoadOp op); @@ -69,7 +115,7 @@ class VerilogEmitter { mlir::Value rhs, std::string_view op); // Emit a Verilog type of the form `wire [width-1:0]` - LogicalResult emitType(Location loc, Type type); + LogicalResult emitType(Type type); // Emit a Verilog array shape specifier of the form `[width]` LogicalResult emitArrayShapeSuffix(Type type); diff --git a/include/Transforms/YosysOptimizer/YosysOptimizer.h b/include/Transforms/YosysOptimizer/YosysOptimizer.h index 9b460d8f7d..ca9987622d 100644 --- a/include/Transforms/YosysOptimizer/YosysOptimizer.h +++ b/include/Transforms/YosysOptimizer/YosysOptimizer.h @@ -21,8 +21,8 @@ struct YosysOptimizerPipelineOptions // registerYosysOptimizerPipeline registers a Yosys pipeline pass using // runfiles, the location of Yosys techlib files, and abcPath, the location of // the abc binary. -void registerYosysOptimizerPipeline(std::string yosysFilesPath, - std::string abcPath); +void registerYosysOptimizerPipeline(const std::string &yosysFilesPath, + const std::string &abcPath); } // namespace heir } // namespace mlir diff --git a/include/Transforms/YosysOptimizer/YosysOptimizer.td b/include/Transforms/YosysOptimizer/YosysOptimizer.td index 51749c00c1..c00341c4a5 100644 --- a/include/Transforms/YosysOptimizer/YosysOptimizer.td +++ b/include/Transforms/YosysOptimizer/YosysOptimizer.td @@ -18,6 +18,7 @@ def YosysOptimizer : Pass<"yosys-optimizer"> { let dependentDialects = [ "mlir::arith::ArithDialect", "mlir::heir::comb::CombDialect", + "mlir::heir::secret::SecretDialect", "mlir::tensor::TensorDialect" ]; } diff --git a/lib/Dialect/Secret/IR/SecretOps.cpp b/lib/Dialect/Secret/IR/SecretOps.cpp index f66c30241d..16088331b7 100644 --- a/lib/Dialect/Secret/IR/SecretOps.cpp +++ b/lib/Dialect/Secret/IR/SecretOps.cpp @@ -234,6 +234,28 @@ void GenericOp::build(OpBuilder &builder, OperationState &result, bodyBuilder(builder, result.location, bodyBlock.getArguments()); } +OpFoldResult CastOp::fold(CastOp::FoldAdaptor adaptor) { + Value input = getInput(); + Value output = getOutput(); + + // self cast is a no-op + if (input.getType() == output.getType()) { + return input; + } + + // Fold a cast-and-cast-back to a no-op. + // + // %1 = secret.cast %0 : !secret.secret to !secret.secret + // %2 = secret.cast %1 : !secret.secret to !secret.secret + // + // folds to use %0 directly in place of %2. + auto inputOp = input.getDefiningOp(); + if (!inputOp || output.getType() != inputOp.getInput().getType()) + return OpFoldResult(); + + return inputOp.getInput(); +} + } // namespace secret } // namespace heir } // namespace mlir diff --git a/lib/Target/Verilog/BUILD b/lib/Target/Verilog/BUILD index e0ba12a0b5..20c6f7babc 100644 --- a/lib/Target/Verilog/BUILD +++ b/lib/Target/Verilog/BUILD @@ -12,6 +12,7 @@ cc_library( deps = [ "@heir//include/Target/Verilog:verilog_emitter", "@heir//lib/Conversion/MemrefToArith:Utils", + "@heir//lib/Dialect/Secret/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineAnalysis", "@llvm-project//mlir:AffineDialect", diff --git a/lib/Target/Verilog/VerilogEmitter.cpp b/lib/Target/Verilog/VerilogEmitter.cpp index f72b453bb4..3129028ebc 100644 --- a/lib/Target/Verilog/VerilogEmitter.cpp +++ b/lib/Target/Verilog/VerilogEmitter.cpp @@ -1,17 +1,40 @@ #include "include/Target/Verilog/VerilogEmitter.h" +#include +#include +#include +#include +#include +#include + #include "include/Conversion/MemrefToArith/Utils.h" +#include "include/Dialect/Secret/IR/SecretDialect.h" +#include "include/Dialect/Secret/IR/SecretOps.h" +#include "include/Dialect/Secret/IR/SecretTypes.h" +#include "llvm/include/llvm/ADT/SmallString.h" // from @llvm-project +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project +#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project +#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project +#include "llvm/include/llvm/ADT/ilist.h" // from @llvm-project +#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Affine/Analysis/AffineAnalysis.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Math/IR/Math.h" // from @llvm-project #include "mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project +#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-translate/Translation.h" // from @llvm-project @@ -28,7 +51,6 @@ static constexpr std::string_view kOutputName = "_out_"; bool shouldMapToSigned(IntegerType::SignednessSemantics val) { switch (val) { case IntegerType::Signless: - return true; case IntegerType::Signed: return true; case IntegerType::Unsigned: @@ -106,25 +128,49 @@ void registerToVerilogTranslation() { [](DialectRegistry ®istry) { registry.insert(); + secret::SecretDialect, math::MathDialect>(); }); } -LogicalResult translateToVerilog(Operation *op, llvm::raw_ostream &os) { +LogicalResult translateToVerilog(Operation *op, llvm::raw_ostream &os, + std::optional moduleName) { + return translateToVerilog(op, os, moduleName, /*allowSecretOps=*/false); +} + +LogicalResult translateToVerilog(Operation *op, llvm::raw_ostream &os, + std::optional moduleName, + bool allowSecretOps) { + if (!allowSecretOps) { + auto result = op->walk([&](Operation *op) -> WalkResult { + if (isa(op->getDialect())) { + op->emitError("allowSecretOps is false, but encountered a secret op."); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (result.wasInterrupted()) return failure(); + } + VerilogEmitter emitter(os); - LogicalResult result = emitter.translate(*op); + LogicalResult result = emitter.translate(*op, moduleName); return result; } +LogicalResult translateToVerilog(Operation *op, llvm::raw_ostream &os) { + return translateToVerilog(op, os, std::nullopt); +} + VerilogEmitter::VerilogEmitter(raw_ostream &os) : os_(os), value_count_(0) {} -LogicalResult VerilogEmitter::translate(Operation &op) { +LogicalResult VerilogEmitter::translate( + Operation &op, std::optional moduleName) { LogicalResult status = llvm::TypeSwitch(op) - // Builtin ops. - .Case([&](auto op) { return printOperation(op); }) + // Ops that use moduleName + .Case( + [&](auto op) { return printOperation(op, moduleName); }) // Func ops. - .Case( + .Case( [&](auto op) { return printOperation(op); }) // Arithmetic ops. .Case([&](auto op) { @@ -168,9 +214,7 @@ LogicalResult VerilogEmitter::translate(Operation &op) { }) .Case([&](auto op) { return printOperation(op); }) // Affine ops. - .Case( - [&](auto op) { return printOperation(op); }) - .Case( + .Case( [&](auto op) { return printOperation(op); }) .Case( [&](auto op) { return printOperation(op); }) @@ -185,11 +229,12 @@ LogicalResult VerilogEmitter::translate(Operation &op) { return success(); } -LogicalResult VerilogEmitter::printOperation(ModuleOp moduleOp) { +LogicalResult VerilogEmitter::printOperation( + ModuleOp moduleOp, std::optional moduleName) { // We have no use in separating things by modules, so just descend // to the underlying ops and continue. for (Operation &op : moduleOp) { - if (failed(translate(op))) { + if (failed(translate(op, moduleName))) { return failure(); } } @@ -197,7 +242,11 @@ LogicalResult VerilogEmitter::printOperation(ModuleOp moduleOp) { return success(); } -LogicalResult VerilogEmitter::printOperation(func::FuncOp funcOp) { +LogicalResult VerilogEmitter::printFunctionLikeOp( + Operation *op, llvm::StringRef verilogModuleName, + ArrayRef arguments, TypeRange resultTypes, + Region::BlockListType::iterator blocksBegin, + Region::BlockListType::iterator blocksEnd) { /* * A func op translates as follows, noting the internal variable wires * need to be defined at the beginning of the module. @@ -214,26 +263,27 @@ LogicalResult VerilogEmitter::printOperation(func::FuncOp funcOp) { * ... * endmodule */ - os_ << "module " << funcOp.getName() << "(\n"; + os_ << "module " << verilogModuleName << "(\n"; os_.indent(); - for (auto arg : funcOp.getArguments()) { + for (auto arg : arguments) { // e.g., `input wire [31:0] arg0,` os_ << "input "; - if (failed(emitType(arg.getLoc(), arg.getType()))) { + if (failed(emitType(arg.getType()))) { + op->emitError() << "failed to emit type" << arg.getType(); return failure(); } os_ << " " << getOrCreateName(arg) << ",\n"; } // output arg declaration - auto result_types = funcOp.getFunctionType().getResults(); - if (result_types.size() != 1) { - emitError(funcOp.getLoc(), + if (resultTypes.size() != 1) { + emitError(op->getLoc(), "Only functions with a single return type are supported"); return failure(); } os_ << "output "; - if (failed(emitType(funcOp.getLoc(), result_types.front()))) { + if (failed(emitType(resultTypes.front()))) { + op->emitError() << "failed to emit type" << resultTypes.front(); return failure(); } os_ << " " << kOutputName; @@ -248,17 +298,17 @@ LogicalResult VerilogEmitter::printOperation(func::FuncOp funcOp) { // Wire declarations. // Look for any op outputs, which are interleaved throughout the function // body. Collect any globals used. - llvm::SmallVector get_globals; + llvm::SmallVector getGlobals; WalkResult result = - funcOp.walk([&](Operation *op) -> WalkResult { + op->walk([&](Operation *op) -> WalkResult { if (auto globalOp = dyn_cast(op)) { - get_globals.push_back(globalOp); + getGlobals.push_back(globalOp); } if (auto indexCastOp = dyn_cast(op)) { - // IndexCastOp's are a layer of indirection in the arithmetic dialect - // that is unneeded in Verilog. A wire declaration is not needed. - // Simply remove the indirection by adding a map from the index-casted - // result value to the input integer value. + // IndexCastOp's are a layer of indirection in the arithmetic + // dialect that is unneeded in Verilog. A wire declaration is not + // needed. Simply remove the indirection by adding a map from the + // index-casted result value to the input integer value. auto retVal = indexCastOp.getResult(); if (!value_to_wire_name_.contains(retVal)) { value_to_wire_name_.insert(std::make_pair( @@ -284,13 +334,14 @@ LogicalResult VerilogEmitter::printOperation(func::FuncOp funcOp) { } for (OpResult result : op->getResults()) { if (failed(emitWireDeclaration(result))) { - return WalkResult( - op->emitError("unable to declare result variable for op")); + return WalkResult(op->emitError() + << "unable to declare result variable of type " + << result.getType()); } } // Also generate intermediate result values the CTLZ computation. if (auto ctlzOp = dyn_cast(op)) { - auto ctx = op->getContext(); + auto *ctx = op->getContext(); auto ctlzStruct = ctlzStructForResult(getOrCreateName(ctlzOp.getResult())); llvm::SmallVector, 4> tempWires = { @@ -299,8 +350,7 @@ LogicalResult VerilogEmitter::printOperation(func::FuncOp funcOp) { {ctlzStruct.temp8, 8}, {ctlzStruct.temp4, 4}}; for (auto tempWire : tempWires) { - if (failed(emitType(op->getLoc(), - IntegerType::get(ctx, tempWire.second)))) { + if (failed(emitType(IntegerType::get(ctx, tempWire.second)))) { return failure(); } os_ << " " << tempWire.first << ";\n"; @@ -310,12 +360,12 @@ LogicalResult VerilogEmitter::printOperation(func::FuncOp funcOp) { }); if (result.wasInterrupted()) return failure(); - auto module = funcOp->getParentOfType(); + auto module = op->getParentOfType(); assert(module); // Assign global values while we have access to the top-level module. - if (!get_globals.empty()) { - for (memref::GetGlobalOp getGlobalOp : get_globals) { + if (!getGlobals.empty()) { + for (memref::GetGlobalOp getGlobalOp : getGlobals) { auto global = cast( module.lookupSymbol(getGlobalOp.getNameAttr())); auto cstAttr = @@ -331,32 +381,48 @@ LogicalResult VerilogEmitter::printOperation(func::FuncOp funcOp) { } os_ << "\n"; - // ops - for (Block &block : funcOp.getBlocks()) { - for (Operation &op : block.getOperations()) { - if (failed(translate(op))) { + while (blocksBegin != blocksEnd) { + for (Operation &op : blocksBegin->getOperations()) { + if (failed(translate(op, std::nullopt))) { return failure(); } } + blocksBegin++; } os_.unindent(); os_ << "endmodule\n"; return success(); } -LogicalResult VerilogEmitter::printOperation(func::ReturnOp op) { +LogicalResult VerilogEmitter::printOperation( + func::FuncOp funcOp, std::optional moduleName) { + auto *blocks = &funcOp.getBlocks(); + return printFunctionLikeOp( + funcOp.getOperation(), moduleName.value_or(funcOp.getName()), + funcOp.getArguments(), funcOp.getFunctionType().getResults(), + blocks->begin(), blocks->end()); +} + +LogicalResult VerilogEmitter::printReturnLikeOp(Value returnValue) { // Return is an assignment to the output wire // e.g., assign out = x1200; + os_ << "assign " << kOutputName << " = " << getName(returnValue) << ";\n"; + return success(); +} +LogicalResult VerilogEmitter::printOperation(func::ReturnOp op) { // Only support one return value. - auto retval = op.getOperands()[0]; - os_ << "assign " << kOutputName << " = " << getName(retval) << ";\n"; - return success(); + return printReturnLikeOp(op.getOperands()[0]); +} + +LogicalResult VerilogEmitter::printOperation(secret::YieldOp op) { + // Only support one return value. + return printReturnLikeOp(op.getOperands()[0]); } LogicalResult VerilogEmitter::printOperation(func::CallOp op) { // e.g., submodule submod_call(xInput0, xInput1, xOutput); - auto opName = getOrCreateName(op.getResult(0)) + "_call"; + std::string opName = (getOrCreateName(op.getResult(0)) + "_call").str(); // Verilog only supports functions with a single return value. if (op.getResults().size() != 1) { @@ -388,9 +454,10 @@ LogicalResult VerilogEmitter::printOperation(arith::AndIOp op) { LogicalResult VerilogEmitter::printOperation(arith::CmpIOp op) { switch (op.getPredicate()) { - // For eq and ne, verilog has multiple operators. == and === are equivalent, - // except for the special values X (unknown default initial state) and Z - // (high impedance state), which are irrelevant for our purposes. Ditto for + // For eq and ne, verilog has multiple operators. == and === are + // equivalent, except for the special values X (unknown default initial + // state) and Z (high impedance state), which are irrelevant for our + // purposes. Ditto for // != and !==. case arith::CmpIPredicate::eq: return printBinaryOp(op.getResult(), op.getLhs(), op.getRhs(), "=="); @@ -651,11 +718,33 @@ LogicalResult VerilogEmitter::printOperation(math::CountLeadingZerosOp op) { return success(); } -LogicalResult VerilogEmitter::emitType(Location loc, Type type) { +LogicalResult VerilogEmitter::printOperation( + mlir::heir::secret::GenericOp op, + std::optional moduleName) { + llvm::StringRef name; + if (moduleName.has_value()) { + name = moduleName.value(); + } else { + // I wanted something more unique here, but an op.getLoc() doesn't print + // as a valid verilog identifier. Maybe with enough string massaging it + // could. + name = "generic_body"; + } + llvm::SmallVector resultTypes; + for (auto ty : op.getResultTypes()) + resultTypes.push_back(cast(ty).getValueType()); + auto *blocks = &op.getRegion().getBlocks(); + return printFunctionLikeOp(op.getOperation(), name, + op.getRegion().getBlocks().front().getArguments(), + resultTypes, blocks->begin(), blocks->end()); +} + +LogicalResult VerilogEmitter::emitType(Type type) { if (auto iType = dyn_cast(type)) { int32_t width = iType.getWidth(); return (os_ << wireDeclaration(iType, width)), success(); - } else if (auto memRefType = dyn_cast(type)) { + } + if (auto memRefType = dyn_cast(type)) { auto elementType = memRefType.getElementType(); if (auto iType = dyn_cast(elementType)) { int32_t flattenedWidth = memRefType.getNumElements() * iType.getWidth(); @@ -670,7 +759,10 @@ void VerilogEmitter::emitAssignPrefix(Value result) { } LogicalResult VerilogEmitter::emitWireDeclaration(OpResult result) { - if (failed(emitType(result.getLoc(), result.getType()))) { + Type ty = result.getType(); + if (ty.isa()) + ty = cast(ty).getValueType(); + if (failed(emitType(ty))) { return failure(); } os_ << " " << getOrCreateName(result) << ";\n"; diff --git a/lib/Transforms/YosysOptimizer/BUILD b/lib/Transforms/YosysOptimizer/BUILD index e4ffebe051..b9557ee3d9 100644 --- a/lib/Transforms/YosysOptimizer/BUILD +++ b/lib/Transforms/YosysOptimizer/BUILD @@ -76,9 +76,9 @@ cc_library( ":LUTImporter", ":RTLILImporter", "@at_clifford_yosys//:kernel", - "@at_clifford_yosys//:version", "@heir//include/Transforms/YosysOptimizer:pass_inc_gen", "@heir//lib/Dialect/Comb/IR:Dialect", + "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Target/Verilog:VerilogEmitter", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", diff --git a/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp b/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp index 8d92d03963..28bb1b0502 100644 --- a/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp +++ b/lib/Transforms/YosysOptimizer/YosysOptimizer.cpp @@ -1,6 +1,6 @@ #include "include/Transforms/YosysOptimizer/YosysOptimizer.h" -#include +#include #include #include #include @@ -10,19 +10,29 @@ #include #include "include/Dialect/Comb/IR/CombDialect.h" +#include "include/Dialect/Secret/IR/SecretOps.h" +#include "include/Dialect/Secret/IR/SecretTypes.h" #include "include/Target/Verilog/VerilogEmitter.h" #include "lib/Transforms/YosysOptimizer/LUTImporter.h" #include "lib/Transforms/YosysOptimizer/RTLILImporter.h" +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/include/mlir/IR/Location.h" // from @llvm-project +#include "mlir/include/mlir/IR/Types.h" // from @llvm-project +#include "mlir/include/mlir/IR/Value.h" // from @llvm-project +#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project @@ -35,6 +45,7 @@ namespace mlir { namespace heir { +using std::string; #define GEN_PASS_DEF_YOSYSOPTIMIZER #include "include/Transforms/YosysOptimizer/YosysOptimizer.h.inc" @@ -79,63 +90,231 @@ struct YosysOptimizer : public impl::YosysOptimizerBase { bool abcFast; }; -// Globally optimize an MLIR module. -void YosysOptimizer::runOnOperation() { - getOperation()->walk([&](func::FuncOp op) { - // Translate function to Verilog. Translation will fail if the func - // contains unsupported operations. - // TODO(https://github.com/google/heir/issues/111): Directly convert MLIR to - // Yosys' AST instead of using Verilog. - char *filename = std::tmpnam(NULL); - std::error_code EC; - llvm::raw_fd_ostream of(filename, EC); - if (failed(translateToVerilog(op, of)) || EC) { - return WalkResult::interrupt(); +tensor::FromElementsOp convertIntegerValue(Value value, Type convertedType, + OpBuilder &b, Location loc) { + IntegerType argType = value.getType().cast(); + int width = argType.getWidth(); + SmallVector extractedBits; + extractedBits.reserve(width); + + for (int i = 0; i < width; i++) { + // These arith ops correspond to extracting the i-th bit + // from the input + auto shiftAmount = + b.create(loc, argType, b.getIntegerAttr(argType, i)); + auto bitMask = b.create( + loc, argType, b.getIntegerAttr(argType, 1 << i)); + auto andOp = b.create(loc, value, bitMask); + auto shifted = b.create(loc, andOp, shiftAmount); + extractedBits.push_back( + b.create(loc, b.getI1Type(), shifted)); + } + + return b.create(loc, convertedType, + ValueRange{extractedBits}); +} + +/// Convert a secret.generic's operands secret.secret +/// to secret.secret>. +LogicalResult convertOpOperands(secret::GenericOp op, func::FuncOp func, + SmallVector &typeConvertedArgs) { + for (OpOperand &opOperand : op->getOpOperands()) { + Type convertedType = + func.getFunctionType().getInputs()[opOperand.getOperandNumber()]; + + if (!opOperand.get().getType().isa()) { + // The type is not secret, but still must be booleanized + OpBuilder builder(op); + auto fromElementsOp = convertIntegerValue(opOperand.get(), convertedType, + builder, op.getLoc()); + typeConvertedArgs.push_back(fromElementsOp.getResult()); + continue; + } + + secret::SecretType originalType = + opOperand.get().getType().cast(); + if (!originalType.getValueType().isa()) { + op.emitError() << "Unsupported input type to secret.generic: " + << originalType.getValueType(); + return failure(); } + + // Insert a conversion from the original type to the converted type + OpBuilder builder(op); + typeConvertedArgs.push_back(builder.create( + op.getLoc(), secret::SecretType::get(convertedType), opOperand.get())); + } + + return success(); +} + +/// Convert a secret.generic's results from secret.secret> +/// to secret.secret. +LogicalResult convertOpResults(secret::GenericOp op, + DenseSet &castOps, + SmallVector &typeConvertedResults) { + for (Value opResult : op.getResults()) { + // The secret.yield verifier ensures generic can only return secret types. + assert(opResult.getType().isa()); + RankedTensorType convertedType = opResult.getType() + .cast() + .getValueType() + .cast(); + if (!convertedType.getElementType().isa() || + convertedType.getRank() != 1) { + op.emitError() << "While booleanizing secret.generic, found converted " + "type that cannot be reassembled: " + << convertedType; + return failure(); + } + + IntegerType elementType = + convertedType.getElementType().cast(); + if (elementType.getWidth() != 1) { + op.emitError() << "Converted element type must be i1"; + return failure(); + } + + IntegerType reassembledType = + IntegerType::get(op.getContext(), elementType.getWidth() * + convertedType.getNumElements()); + + // Insert a reassembly of the original integer type from its booleanized + // tensor version. + OpBuilder builder(op); + builder.setInsertionPointAfter(op); + auto castOp = builder.create( + op.getLoc(), secret::SecretType::get(reassembledType), opResult); + castOps.insert(castOp); + typeConvertedResults.push_back(castOp.getOutput()); + } + + return success(); +} + +LogicalResult runOnGenericOp(MLIRContext *context, secret::GenericOp op, + const std::string &yosysFilesPath, + const std::string &abcPath, bool abcFast) { + std::string moduleName = "generic_body"; + + // Translate function to Verilog. Translation will fail if the func contains + // unsupported operations. + // TODO(https://github.com/google/heir/issues/111): Directly convert MLIR to + // Yosys' AST instead of using Verilog. + // + // After that is done, it might make sense to rewrite this as a + // RewritePattern, which only runs if the body does not contain any comb ops, + // and generalize this to support converting a secret.generic as well as a + // func.func. It's necessary to wait for the migration because the Yosys API + // used here maintains global state that apparently does not play nicely with + // the instantiation of multiple rewrite patterns. + char *filename = std::tmpnam(nullptr); + std::error_code ec; + llvm::raw_fd_ostream of(filename, ec); + if (failed(translateToVerilog(op, of, moduleName, + /*allowSecretOps=*/true)) || + ec) { + op.emitError() << "Failed to translate to verilog"; of.close(); + return failure(); + } + of.close(); + + // Invoke Yosys to translate to a combinational circuit and optimize. + Yosys::yosys_setup(); + Yosys::log_error_stderr = true; + LLVM_DEBUG(Yosys::log_streams.push_back(&std::cout)); + Yosys::run_pass(llvm::formatv(kYosysTemplate.data(), filename, moduleName, + yosysFilesPath, abcPath, + abcFast ? "-fast" : "")); + + // Translate Yosys result back to MLIR and insert into the func + LLVM_DEBUG(Yosys::run_pass("dump;")); + std::stringstream cellOrder; + Yosys::log_streams.push_back(&cellOrder); + Yosys::run_pass("torder -stop * P*;"); + Yosys::log_streams.clear(); + auto topologicalOrder = getTopologicalOrder(cellOrder); + LUTImporter lutImporter = LUTImporter(context); + Yosys::RTLIL::Design *design = Yosys::yosys_get_design(); + func::FuncOp func = + lutImporter.importModule(design->top_module(), topologicalOrder); + Yosys::yosys_shutdown(); + + // The pass changes the yielded value types, e.g., from an i8 to a + // tensor<8xi1>. So the containing secret.generic needs to be updated and + // conversions implemented on either side to convert the ints to tensors + // and back again. + // + // convertOpOperands goes from i8 -> tensor.tensor<8xi1> + // converOpResults from tensor.tensor<8xi1> -> i8 + SmallVector typeConvertedArgs; + typeConvertedArgs.reserve(op->getNumOperands()); + if (failed(convertOpOperands(op, func, typeConvertedArgs))) { + return failure(); + } + + int resultIndex = 0; + for (Type ty : func.getFunctionType().getResults()) + op->getResult(resultIndex++).setType(secret::SecretType::get(ty)); + + // Replace the func.return with a secret.yield + op.getRegion().takeBody(func.getBody()); + op.getOperation()->setOperands(typeConvertedArgs); + + Block &block = op.getRegion().getBlocks().front(); + func::ReturnOp returnOp = cast(block.getTerminator()); + OpBuilder bodyBuilder(&block, block.end()); + bodyBuilder.create(returnOp.getLoc(), + returnOp.getOperands()); + returnOp.erase(); + func.erase(); - // Invoke Yosys to translate to a combinational circuit and optimize. - Yosys::yosys_setup(); - Yosys::log_error_stderr = true; - LLVM_DEBUG(Yosys::log_streams.push_back(&std::cout)); - Yosys::run_pass(llvm::formatv(kYosysTemplate.data(), filename, - op.getSymName(), yosysFilesPath, abcPath, - abcFast ? "-fast" : "")); - - // Translate to MLIR and insert into the func - LLVM_DEBUG(Yosys::run_pass("dump;")); - std::stringstream cellOrder; - Yosys::log_streams.push_back(&cellOrder); - Yosys::run_pass("torder -stop * P*;"); - Yosys::log_streams.clear(); - auto topologicalOrder = getTopologicalOrder(cellOrder); - - // Insert the optimized MLIR. - LUTImporter lutImporter = LUTImporter(&getContext()); - Yosys::RTLIL::Design *design = Yosys::yosys_get_design(); - func::FuncOp func = - lutImporter.importModule(design->top_module(), topologicalOrder); - - LLVM_DEBUG(llvm::dbgs() - << "Converted & optimized func via yosys. Input func:\n" - << op << "\n\nOutput func:\n" - << func << "\n"); - op.setFunctionType(func.getFunctionType()); - op.getBody().takeBody(func.getBody()); + DenseSet castOps; + SmallVector typeConvertedResults; + castOps.reserve(op->getNumResults()); + typeConvertedResults.reserve(op->getNumResults()); + if (failed(convertOpResults(op, castOps, typeConvertedResults))) { + return failure(); + } + + LLVM_DEBUG(llvm::dbgs() << "Generic results: " << typeConvertedResults.size() + << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Original results: " << op.getResults().size() + << "\n"); + + op.getResults().replaceUsesWithIf( + typeConvertedResults, [&](OpOperand &operand) { + return !castOps.contains(operand.getOwner()); + }); + return success(); +} +// Optimize the body of a secret.generic op. +// FIXME: consider utilizing +// https://mlir.llvm.org/docs/PassManagement/#dynamic-pass-pipelines +void YosysOptimizer::runOnOperation() { + auto result = getOperation()->walk([&](secret::GenericOp op) { + if (failed(runOnGenericOp(&getContext(), op, yosysFilesPath, abcPath, + abcFast))) { + return WalkResult::interrupt(); + } return WalkResult::advance(); }); - Yosys::yosys_shutdown(); + + if (result.wasInterrupted()) { + signalPassFailure(); + } } -std::unique_ptr createYosysOptimizer(std::string yosysFilesPath, - std::string abcPath, - bool abcFast) { +std::unique_ptr createYosysOptimizer( + const std::string &yosysFilesPath, const std::string &abcPath, + bool abcFast) { return std::make_unique(yosysFilesPath, abcPath, abcFast); } -void registerYosysOptimizerPipeline(std::string yosysFilesPath, - std::string abcPath) { +void registerYosysOptimizerPipeline(const std::string &yosysFilesPath, + const std::string &abcPath) { PassPipelineRegistration( "yosys-optimizer", "The yosys optimizer pipeline.", [yosysFilesPath, abcPath](OpPassManager &pm, diff --git a/tests/verilog/emit_verilog_errors.mlir b/tests/verilog/emit_verilog_errors.mlir new file mode 100644 index 0000000000..50adcfd3be --- /dev/null +++ b/tests/verilog/emit_verilog_errors.mlir @@ -0,0 +1,16 @@ +// RUN: heir-translate --emit-verilog --verify-diagnostics %s + +module { + func.func @add_one(%in: !secret.secret) -> (!secret.secret) { + %one = arith.constant 1 : i8 + %1 = secret.generic + ins(%in, %one: !secret.secret, i8) { + ^bb0(%IN: i8, %ONE: i8) : + %2 = arith.addi %IN, %ONE : i8 + // The error is on yield because MLIR walks the IR in preorder traversal. + // expected-error@+1 {{allowSecretOps is false, but encountered a secret op.}} + secret.yield %2 : i8 + } -> (!secret.secret) + return %1 : !secret.secret + } +} diff --git a/tests/yosys_optimizer/add_one.mlir b/tests/yosys_optimizer/add_one.mlir index 7d260c3266..42e64ee4b6 100644 --- a/tests/yosys_optimizer/add_one.mlir +++ b/tests/yosys_optimizer/add_one.mlir @@ -1,14 +1,24 @@ // RUN: heir-opt --yosys-optimizer %s | FileCheck %s -// CHECK: module module { - func.func @add_one(%in: i8) -> (i8) { - // CHECK: comb.truth_table - %0 = arith.constant 1 : i8 - // CHECK-NOT arith.addi - %1 = arith.addi %in, %0 : i8 - // CHECK: tensor.from_elements - // CHECK-NEXT: return - return %1 : i8 + // CHECK-LABEL: @add_one + func.func @add_one(%in: !secret.secret) -> (!secret.secret) { + %one = arith.constant 1 : i8 + // Generic to convert the i8 to a tensor + // CHECK: secret.cast + // CHECK-SAME: !secret.secret to !secret.secret> + + // CHECK: secret.generic + %1 = secret.generic + ins(%in, %one: !secret.secret, i8) { + ^bb0(%IN: i8, %ONE: i8) : + // CHECK-NOT: arith.addi + %2 = arith.addi %IN, %ONE : i8 + secret.yield %2 : i8 + } -> (!secret.secret) + + // CHECK: secret.cast + // CHECK-SAME: !secret.secret> to !secret.secret + return %1 : !secret.secret } } diff --git a/tests/yosys_optimizer/arith_ops.mlir b/tests/yosys_optimizer/arith_ops.mlir index b00050ddf7..10c7739ab9 100644 --- a/tests/yosys_optimizer/arith_ops.mlir +++ b/tests/yosys_optimizer/arith_ops.mlir @@ -1,14 +1,36 @@ // RUN: heir-opt --yosys-optimizer %s | FileCheck %s -// CHECK: module -module { - func.func @ops(i32, i32, i32, i32) -> (i32) { - ^bb0(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32): - // CHECK: comb.truth_table - %0 = arith.subi %arg0, %arg1: i32 - %1 = arith.muli %arg2, %arg3 : i32 - %2 = arith.andi %1, %arg3 : i32 - // CHECK: return - return %2 : i32 - } +// CHECK-LABEL: @ops +func.func @ops( + %arg0: !secret.secret, + %arg1: !secret.secret, + %arg2: !secret.secret, + %arg3: !secret.secret) -> (!secret.secret) { + %1 = secret.generic ins(%arg0, %arg1, %arg2, %arg3: + !secret.secret, !secret.secret, !secret.secret, !secret.secret) { + ^bb0(%a0: i3, %a1: i3, %a2: i3, %a3: i3): + %0 = arith.subi %a0, %a1: i3 + %1 = arith.muli %a2, %a3: i3 + %2 = arith.andi %1, %a3: i3 + secret.yield %2 : i3 + } -> (!secret.secret) + return %1 : !secret.secret + // CHECK: secret.cast + // CHECK-SAME: !secret.secret to !secret.secret> + // CHECK: secret.cast + // CHECK-SAME: !secret.secret to !secret.secret> + // CHECK: secret.cast + // CHECK-SAME: !secret.secret to !secret.secret> + // CHECK: secret.cast + // CHECK-SAME: !secret.secret to !secret.secret> + + // Main computation + // CHECK: secret.generic + // CHECK-COUNT-7: comb.truth_table + // CHECK: secret.yield + // CHECK-SAME: tensor<3xi1> + + // CHECK: secret.cast + // CHECK-SAME: !secret.secret> to !secret.secret + // CHECK: return } diff --git a/tests/yosys_optimizer/chunk_connections.mlir b/tests/yosys_optimizer/chunk_connections.mlir index 65f72e3598..9c857eaedb 100644 --- a/tests/yosys_optimizer/chunk_connections.mlir +++ b/tests/yosys_optimizer/chunk_connections.mlir @@ -1,11 +1,19 @@ // RUN: heir-opt -yosys-optimizer %s | FileCheck %s - // CHECK-LABEL: @for - func.func @for_loop(%arg0: i8, %arg1: i8) -> i32 { - // CHECK-NOT: arith.extsi - // CHECK-NOT: arith.subi - // CHECK-NOT: arith.muli - // CHECK-NOT: arith.addi +// CHECK-LABEL: @for_loop +func.func @for_loop(%ARG0: !secret.secret, %ARG1: !secret.secret) -> !secret.secret { + // convert two ARGs + // CHECK: secret.cast + // CHECK: secret.cast + + // CHECK: secret.generic + // CHECK-NOT: arith.extsi + // CHECK-NOT: arith.subi + // CHECK-NOT: arith.muli + // CHECK-NOT: arith.addi + %1 = secret.generic + ins(%ARG0, %ARG1: !secret.secret, !secret.secret) { + ^bb0(%arg0: i8, %arg1: i8) : %c-128_i16 = arith.constant -128 : i16 %c0_i32 = arith.constant 0 : i32 %0 = arith.extsi %arg0 : i8 to i16 @@ -14,6 +22,10 @@ %3 = arith.extsi %arg1 : i8 to i32 %4 = arith.muli %2, %3 : i32 %5 = arith.addi %c0_i32, %4 : i32 - // CHECK: return - return %5 : i32 - } + secret.yield %5 : i32 + } -> (!secret.secret) + + // CHECK: secret.cast + // CHECK: return + return %1 : !secret.secret +} diff --git a/tests/yosys_optimizer/micro_speech_for.mlir b/tests/yosys_optimizer/micro_speech_for.mlir index 5ddb0bec7d..4418aa9df3 100644 --- a/tests/yosys_optimizer/micro_speech_for.mlir +++ b/tests/yosys_optimizer/micro_speech_for.mlir @@ -1,40 +1,51 @@ // RUN: heir-opt --yosys-optimizer='abc-fast=true' %s | FileCheck %s -// CHECK: module -module { - func.func @for_25_20_8(%98: i32, %99: i32, %100: i8) -> (i8) { - // The only arith op we expect is arith.constant - // CHECK-NOT: arith.{{^constant}} - // CHECK: comb.truth_table - %c1_i64 = arith.constant 1 : i64 - %c1073741824_i64 = arith.constant 1073741824 : i64 - %c0_i32 = arith.constant 0 : i32 - %c-1073741824_i64 = arith.constant -1073741824 : i64 - %c31_i32 = arith.constant 31 : i32 - %c-128_i32 = arith.constant -128 : i32 - %c127_i32 = arith.constant 127 : i32 - %101 = arith.extui %100 : i8 to i32 - %102 = arith.extsi %98 : i32 to i64 - %103 = arith.extsi %99 : i32 to i64 - %104 = arith.muli %102, %103 : i64 - %105 = arith.extui %100 : i8 to i64 - %106 = arith.shli %c1_i64, %105 : i64 - %107 = arith.shrui %106, %c1_i64 : i64 - %108 = arith.addi %104, %107 : i64 - %109 = arith.cmpi sge, %98, %c0_i32 : i32 - %110 = arith.select %109, %c1073741824_i64, %c-1073741824_i64 : i64 - %111 = arith.addi %110, %108 : i64 - %112 = arith.cmpi sgt, %101, %c31_i32 : i32 - %113 = arith.select %112, %111, %108 : i64 - %114 = arith.shrsi %113, %105 : i64 - %115 = arith.trunci %114 : i64 to i32 - %116 = arith.addi %115, %c-128_i32 : i32 - %117 = arith.cmpi slt, %116, %c-128_i32 : i32 - %118 = arith.select %117, %c-128_i32, %116 : i32 - %119 = arith.cmpi sgt, %116, %c127_i32 : i32 - %120 = arith.select %119, %c127_i32, %118 : i32 - %121 = arith.trunci %120 : i32 to i8 - // CHECK: return - func.return %121 : i8 - } +// CHECK-LABEL: @for_25_20_8 +func.func @for_25_20_8( + %c98: !secret.secret, %c99: !secret.secret, %c100: !secret.secret) -> (!secret.secret) { + // convert three args + // CHECK: secret.cast + // CHECK: secret.cast + // CHECK: secret.cast + + // The only arith op we expect is arith.constant + // CHECK-NOT: arith.{{^constant}} + // CHECK: comb.truth_table + %1 = secret.generic + ins(%c98, %c99, %c100: !secret.secret, !secret.secret, !secret.secret) { + ^bb0(%98: i32, %99: i32, %100: i8) : + %c1_i64 = arith.constant 1 : i64 + %c1073741824_i64 = arith.constant 1073741824 : i64 + %c0_i32 = arith.constant 0 : i32 + %c-1073741824_i64 = arith.constant -1073741824 : i64 + %c31_i32 = arith.constant 31 : i32 + %c-128_i32 = arith.constant -128 : i32 + %c127_i32 = arith.constant 127 : i32 + %101 = arith.extui %100 : i8 to i32 + %102 = arith.extsi %98 : i32 to i64 + %103 = arith.extsi %99 : i32 to i64 + %104 = arith.muli %102, %103 : i64 + %105 = arith.extui %100 : i8 to i64 + %106 = arith.shli %c1_i64, %105 : i64 + %107 = arith.shrui %106, %c1_i64 : i64 + %108 = arith.addi %104, %107 : i64 + %109 = arith.cmpi sge, %98, %c0_i32 : i32 + %110 = arith.select %109, %c1073741824_i64, %c-1073741824_i64 : i64 + %111 = arith.addi %110, %108 : i64 + %112 = arith.cmpi sgt, %101, %c31_i32 : i32 + %113 = arith.select %112, %111, %108 : i64 + %114 = arith.shrsi %113, %105 : i64 + %115 = arith.trunci %114 : i64 to i32 + %116 = arith.addi %115, %c-128_i32 : i32 + %117 = arith.cmpi slt, %116, %c-128_i32 : i32 + %118 = arith.select %117, %c-128_i32, %116 : i32 + %119 = arith.cmpi sgt, %116, %c127_i32 : i32 + %120 = arith.select %119, %c127_i32, %118 : i32 + %121 = arith.trunci %120 : i32 to i8 + secret.yield %121 : i8 + } -> (!secret.secret) + + // CHECK: secret.cast + // CHECK: return + func.return %1 : !secret.secret }