Skip to content

Commit

Permalink
Merge pull request #264 from asraa:tensor-multi-bits
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 585667188
  • Loading branch information
copybara-github committed Nov 27, 2023
2 parents 085e977 + 697da2b commit 60131aa
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 48 deletions.
5 changes: 5 additions & 0 deletions include/Transforms/YosysOptimizer/YosysOptimizer.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@ def YosysOptimizer : Pass<"yosys-optimizer"> {
let description = [{
This pass invokes Yosys to convert an arithmetic circuit to an optimized
boolean circuit that uses the arith and comb dialects.

Note that booleanization changes the function signature: multi-bit integers
are transformed to a tensor of booleans, for example, an `i8` is converted
to `tensor<8xi1>`.
}];

let dependentDialects = [
"mlir::arith::ArithDialect",
"mlir::heir::comb::CombDialect",
"mlir::tensor::TensorDialect"
];
}

Expand Down
3 changes: 3 additions & 0 deletions lib/Transforms/YosysOptimizer/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TransformUtils",
],
)
Expand Down Expand Up @@ -57,6 +58,7 @@ cc_test(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
],
)

Expand Down Expand Up @@ -84,6 +86,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)
46 changes: 24 additions & 22 deletions lib/Transforms/YosysOptimizer/LUTImporterTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@
#include "include/Dialect/Comb/IR/CombDialect.h"
#include "include/Dialect/Comb/IR/CombOps.h"
#include "lib/Transforms/YosysOptimizer/LUTImporter.h"
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/Support/Path.h" // from @llvm-project
#include "mlir/include/mlir//IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/Support/Path.h" // from @llvm-project
#include "mlir/include/mlir//IR/Location.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/OwningOpRef.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project

// Block clang-format from reordering
// clang-format off
Expand All @@ -40,7 +41,7 @@ class LUTImporterTestFixture : public Test {
protected:
void SetUp() override {
context.loadDialect<heir::comb::CombDialect, arith::ArithDialect,
func::FuncDialect>();
func::FuncDialect, tensor::TensorDialect>();
module_ = ModuleOp::create(UnknownLoc::get(&context));
Yosys::yosys_setup();
}
Expand Down Expand Up @@ -87,9 +88,9 @@ TEST_F(LUTImporterTestFixture, AddOneLUT3) {

auto funcType = func.getFunctionType();
EXPECT_EQ(funcType.getNumInputs(), 1);
EXPECT_EQ(funcType.getInput(0).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getInput(0)).getNumElements(), 8);
EXPECT_EQ(funcType.getNumResults(), 1);
EXPECT_EQ(funcType.getResult(0).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getResult(0)).getNumElements(), 8);

auto combOps = func.getOps<comb::TruthTableOp>().begin();
for (size_t i = 0; i < expectedLuts.size(); i++) {
Expand All @@ -104,9 +105,9 @@ TEST_F(LUTImporterTestFixture, AddOneLUT5) {

auto funcType = func.getFunctionType();
EXPECT_EQ(funcType.getNumInputs(), 1);
EXPECT_EQ(funcType.getInput(0).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getInput(0)).getNumElements(), 8);
EXPECT_EQ(funcType.getNumResults(), 1);
EXPECT_EQ(funcType.getResult(0).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getResult(0)).getNumElements(), 8);

auto combOps = func.getOps<comb::TruthTableOp>();
for (auto combOp : combOps) {
Expand All @@ -125,12 +126,13 @@ TEST_F(LUTImporterTestFixture, DoubleInput) {

auto funcType = func.getFunctionType();
EXPECT_EQ(funcType.getNumInputs(), 1);
EXPECT_EQ(funcType.getInput(0).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getInput(0)).getNumElements(), 8);
EXPECT_EQ(funcType.getNumResults(), 1);
EXPECT_EQ(funcType.getResult(0).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getResult(0)).getNumElements(), 8);

auto returnOp = *func.getOps<func::ReturnOp>().begin();
auto concatOp = returnOp.getOperands()[0].getDefiningOp<comb::ConcatOp>();
auto concatOp =
returnOp.getOperands()[0].getDefiningOp<tensor::FromElementsOp>();
ASSERT_TRUE(concatOp);
EXPECT_EQ(concatOp->getNumOperands(), 8);
arith::ConstantOp constOp =
Expand All @@ -147,10 +149,10 @@ TEST_F(LUTImporterTestFixture, MultipleInputs) {

auto funcType = func.getFunctionType();
EXPECT_EQ(funcType.getNumInputs(), 2);
EXPECT_EQ(funcType.getInput(0).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(funcType.getInput(1).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getInput(0)).getNumElements(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getInput(1)).getNumElements(), 8);
EXPECT_EQ(funcType.getNumResults(), 1);
EXPECT_EQ(funcType.getResult(0).getIntOrFloatBitWidth(), 8);
EXPECT_EQ(cast<RankedTensorType>(funcType.getResult(0)).getNumElements(), 8);
}

} // namespace
Expand Down
42 changes: 29 additions & 13 deletions lib/Transforms/YosysOptimizer/RTLILImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
#include "include/Dialect/Comb/IR/CombOps.h"
#include "kernel/rtlil.h" // from @at_clifford_yosys
#include "llvm/include/llvm/ADT/MapVector.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/FoldUtils.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/FoldUtils.h" // from @llvm-project

namespace mlir {
namespace heir {
Expand All @@ -25,6 +26,21 @@ using ::Yosys::RTLIL::Module;
using ::Yosys::RTLIL::SigSpec;
using ::Yosys::RTLIL::Wire;

namespace {

// getTypeForWire gets the MLIR type corresponding to the RTLIL wire. If the
// wire is an integer with multiple bits, then the MLIR type is a tensor of
// bits.
Type getTypeForWire(OpBuilder &b, Wire *wire) {
auto intTy = b.getI1Type();
if (wire->width == 1) {
return intTy;
}
return RankedTensorType::get({wire->width}, intTy);
}

} // namespace

llvm::SmallVector<std::string, 10> getTopologicalOrder(
std::stringstream &torderOutput) {
llvm::SmallVector<std::string, 10> cells;
Expand Down Expand Up @@ -75,8 +91,8 @@ Value RTLILImporter::getBit(
return retBitValues[bit.wire][offset];
}
auto argA = getWireValue(bit.wire);
auto extractOp =
b.createOrFold<comb::ExtractOp>(b.getI1Type(), argA, bit.offset);
auto extractOp = b.create<tensor::ExtractOp>(
argA, b.create<arith::ConstantIndexOp>(bit.offset).getResult());
return extractOp;
}

Expand Down Expand Up @@ -112,10 +128,10 @@ func::FuncOp RTLILImporter::importModule(
// The RTLIL module may also have intermediate wires that are neither inputs
// nor outputs.
if (wire->port_input) {
argTypes.push_back(builder.getIntegerType(wire->width));
argTypes.push_back(getTypeForWire(builder, wire));
wireArgs.push_back(wire);
} else if (wire->port_output) {
retTypes.push_back(builder.getIntegerType(wire->width));
retTypes.push_back(getTypeForWire(builder, wire));
wireRet.push_back(wire);
retBitValues[wire].resize(wire->width);
}
Expand Down Expand Up @@ -186,7 +202,7 @@ func::FuncOp RTLILImporter::importModule(
} else {
// We are in a multi-bit scenario.
assert(retBits.size() > 1);
auto concatOp = b.create<comb::ConcatOp>(retBits);
auto concatOp = b.create<tensor::FromElementsOp>(retBits);
returnValues.push_back(concatOp.getResult());
}
}
Expand Down
27 changes: 15 additions & 12 deletions lib/Transforms/YosysOptimizer/YosysOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,18 @@
#include "include/Target/Verilog/VerilogEmitter.h"
#include "lib/Transforms/YosysOptimizer/LUTImporter.h"
#include "lib/Transforms/YosysOptimizer/RTLILImporter.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/DialectRegistry.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project

// Block clang-format from reordering
// clang-format off
Expand Down Expand Up @@ -63,7 +64,8 @@ struct YosysOptimizer : public impl::YosysOptimizerBase<YosysOptimizer> {
: yosysFilesPath(yosysFilesPath), abcPath(abcPath), abcFast(abcFast) {}

void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<comb::CombDialect, mlir::arith::ArithDialect>();
registry.insert<comb::CombDialect, mlir::arith::ArithDialect,
mlir::tensor::TensorDialect>();
}

void runOnOperation() override;
Expand Down Expand Up @@ -118,6 +120,7 @@ void YosysOptimizer::runOnOperation() {
<< "Converted & optimized func via yosys. Input func:\n"
<< op << "\n\nOutput func:\n"
<< func << "\n");
op.setFunctionType(func.getFunctionType());
op.getBody().takeBody(func.getBody());

return WalkResult::advance();
Expand Down
2 changes: 1 addition & 1 deletion tests/yosys_optimizer/add_one.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ module {
%0 = arith.constant 1 : i8
// CHECK-NOT arith.addi
%1 = arith.addi %in, %0 : i8
// CHECK: comb.concat
// CHECK: tensor.from_elements
// CHECK-NEXT: return
return %1 : i8
}
Expand Down

0 comments on commit 60131aa

Please sign in to comment.