diff --git a/.gitignore b/.gitignore index e48b5799b..09ffade70 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,4 @@ # Build-related directories build/ bin/ -.cache/ \ No newline at end of file +.cache/ diff --git a/include/dynamatic/Transforms/InitCstWidth.h b/include/dynamatic/Transforms/InitCstWidth.h new file mode 100644 index 000000000..a94bea4ea --- /dev/null +++ b/include/dynamatic/Transforms/InitCstWidth.h @@ -0,0 +1,18 @@ +//===- InitCstWidth.h - Reduce the constant bits width ----------*- C++ -*-===// +// +// This file declares the --init-cstwidth pass. +// +//===----------------------------------------------------------------------===// + +#ifndef DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H +#define DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H + +#include "dynamatic/Transforms/UtilsBitsUpdate.h" + +namespace dynamatic { + +std::unique_ptr> createInitCstWidthPass(); + +} // namespace dynamatic + +#endif // DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H \ No newline at end of file diff --git a/include/dynamatic/Transforms/InitIndexType.h b/include/dynamatic/Transforms/InitIndexType.h new file mode 100644 index 000000000..1cdbf3afb --- /dev/null +++ b/include/dynamatic/Transforms/InitIndexType.h @@ -0,0 +1,18 @@ +//===- InitIndexType.h - Transform IndexType to IntegerType -----*- C++ -*-===// +// +// This file declares the --init-indextype pass. +// +//===----------------------------------------------------------------------===// + +#ifndef DYNAMATIC_TRANSFORMS_INITINDTYPE_H +#define DYNAMATIC_TRANSFORMS_INITINDTYPE_H + +#include "dynamatic/Transforms/UtilsBitsUpdate.h" + +namespace dynamatic { + +std::unique_ptr> createInitIndTypePass(); + +} // namespace dynamatic + +#endif // DYNAMATIC_TRANSFORMS_INITINDTYPE_H \ No newline at end of file diff --git a/include/dynamatic/Transforms/OptimizeBits.h b/include/dynamatic/Transforms/OptimizeBits.h new file mode 100644 index 000000000..04f064bca --- /dev/null +++ b/include/dynamatic/Transforms/OptimizeBits.h @@ -0,0 +1,18 @@ +//===- BitsOptimize.h - Optimize bits widths --------------------*- C++ -*-===// +// +// This file declares the --optimize-bits pass. +// +//===----------------------------------------------------------------------===// + +#ifndef DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H +#define DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H + +#include "dynamatic/Transforms/UtilsBitsUpdate.h" + +namespace dynamatic { + +std::unique_ptr> createOptimizeBitsPass(); + +} // namespace dynamatic + +#endif // DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H \ No newline at end of file diff --git a/include/dynamatic/Transforms/Passes.h b/include/dynamatic/Transforms/Passes.h index dcb987884..39f268a98 100644 --- a/include/dynamatic/Transforms/Passes.h +++ b/include/dynamatic/Transforms/Passes.h @@ -9,10 +9,13 @@ #include "dynamatic/Support/LLVM.h" #include "dynamatic/Transforms/AnalyzeMemoryAccesses.h" +#include "dynamatic/Transforms/OptimizeBits.h" #include "dynamatic/Transforms/ArithReduceArea.h" #include "dynamatic/Transforms/FlattenMemRefRowMajor.h" #include "dynamatic/Transforms/HandshakeInferBasicBlocks.h" #include "dynamatic/Transforms/HandshakePrepareForLegacy.h" +#include "dynamatic/Transforms/InitCstWidth.h" +#include "dynamatic/Transforms/InitIndexType.h" #include "dynamatic/Transforms/NameMemoryOps.h" #include "dynamatic/Transforms/PushConstants.h" #include "mlir/Pass/Pass.h" @@ -25,4 +28,4 @@ namespace dynamatic { } // namespace dynamatic -#endif // DYNAMATIC_TRANSFORMS_PASSES_H +#endif // DYNAMATIC_TRANSFORMS_PASSES_H \ No newline at end of file diff --git a/include/dynamatic/Transforms/Passes.td b/include/dynamatic/Transforms/Passes.td index 7ad3ad301..c4e937b1c 100644 --- a/include/dynamatic/Transforms/Passes.td +++ b/include/dynamatic/Transforms/Passes.td @@ -126,4 +126,32 @@ def PushConstants : Pass<"push-constants", "mlir::ModuleOp"> { let constructor = "dynamatic::createPushConstantsPass()"; } +def HandshakeOptimizeBits : Pass<"optimize-bits" , "mlir::ModuleOp"> { + let summary = "Optimize bits width that can be reduced."; + let description = [{ + This pass goes through all operations inside handshake levels, and optimize the bit width through the forward (from first operation to last), and backward(from last operation to first) loop process. + The loop continues until no more optimization can be done. + }]; + let constructor = "dynamatic::createOptimizeBitsPass()"; + let dependentDialects = ["mlir::arith::ArithDialect"]; +} + +def HandshakeInitIndType : Pass<"init-indtype", "mlir::ModuleOp"> { + let summary = "Initialize the index type of the module."; + let description = [{ + This pass change all the index type within operands and result operands to integer type with platform dependent bit width. + }]; + let constructor = "dynamatic::createInitIndTypePass()"; +} + +def HandshakeInitCstWidth : Pass<"init-cstwidth", "mlir::ModuleOp"> { + let summary = "Initialize the constant bit witth of the module."; + let description = [{ + This pass rewrites constant operation with the minimum required bit width according to the value of the constant. To ensure the consistency with the user of the constant operation, the pass inserts a extension operation if necessary. + }]; + let constructor = "dynamatic::createInitCstWidthPass()"; + let dependentDialects = ["mlir::arith::ArithDialect"]; + +} + #endif // DYNAMATIC_TRANSFORMS_PASSES_TD \ No newline at end of file diff --git a/include/dynamatic/Transforms/UtilsBitsUpdate.h b/include/dynamatic/Transforms/UtilsBitsUpdate.h new file mode 100644 index 000000000..fccb06b0c --- /dev/null +++ b/include/dynamatic/Transforms/UtilsBitsUpdate.h @@ -0,0 +1,86 @@ +//===- UtilsBitsUpdate.h - Utils support bits optimization ------*- C++ -*-===// +// +// This file declares supports for --optimize-bits pass. +// +//===----------------------------------------------------------------------===// + +#ifndef DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H +#define DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H + +#include "circt/Dialect/Handshake/HandshakeOps.h" +#include "dynamatic/Support/LLVM.h" +#include + +using namespace mlir; +using namespace circt; +using namespace circt::handshake; +using namespace dynamatic; + +const unsigned CPP_MAX_WIDTH = 64; +const unsigned ADDRESS_WIDTH = 32; + +IntegerType getNewType(Value opVal, unsigned bitswidth, bool signless = false); + +IntegerType getNewType(Value opVal, unsigned bitswidth, + IntegerType::SignednessSemantics ifSign); + +std::optional insertWidthMatchOp(Operation *newOp, int opInd, + Type newType, MLIRContext *ctx); + +namespace dynamatic { +namespace bitwidth { + +/// Construct the functions w.r.t. the operation name in the forward process. +void constructForwardFuncMap( + DenseMap> + &mapOpNameWidth); + +/// Construct the functions w.r.t. the operation name in the backward process. +void constructBackwardFuncMap( + DenseMap> + &mapOpNameWidth); + +/// Construct the functions w.r.t. the operation name in the validation process. +void constructUpdateFuncMap( + DenseMap> + &mapOpNameWidth); + +/// For branch and conditional branch operations, propagate the bits width of +/// the operands to the result. +bool propType(Operation *Op); + +/// Insert width match operations (extension or truncation) for the operands and +/// the results. +void matchOpResWidth(Operation *op, MLIRContext *ctx, + SmallVector &newMatchedOps); + +/// Replace the operation's operand with the its successor. +void replaceWithPredecessor(Operation *op); + +/// Replace the operation's operand with the its successor, and set the operation's +/// resultOp according to its successor's resultOp type. +void replaceWithPredecessor(Operation *op, Type resType); + +// Validate the truncation and extension operation in case its operand and +// result operand width are not consistent by reverting or deleting the +// operations. +void revertTruncOrExt(Operation *op, MLIRContext *ctx); + +/// Set the pass, match, and revert flags to choose the methods that validate +/// the operations. +static bool setPassFlag(Operation *op); + +static bool setMatchFlag(Operation *op); + +static bool setRevertFlag(Operation *op); + +/// Validate the operations after bits optimization to generate .mlir file. +void validateOp(Operation *op, MLIRContext *ctx, + SmallVector &newMatchedOps); +} // namespace bitwidth +} // namespace dynamatic + +#endif // DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H \ No newline at end of file diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index 54b391596..259dcca23 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -6,6 +6,10 @@ add_dynamatic_library(DynamaticTransforms HandshakeInferBasicBlocks.cpp NameMemoryOps.cpp PushConstants.cpp + UtilsBitsUpdate.cpp + OptimizeBits.cpp + InitIndexType.cpp + InitCstWidth.cpp DEPENDS DynamaticTransformsPassIncGen diff --git a/lib/Transforms/InitCstWidth.cpp b/lib/Transforms/InitCstWidth.cpp new file mode 100644 index 000000000..81e89f22a --- /dev/null +++ b/lib/Transforms/InitCstWidth.cpp @@ -0,0 +1,98 @@ +//===- InitCstWidth.cpp - Reduce the constant bits width --------*- C++ -*-===// +// +// This file contains the implementation of the init-cstwidth pass. +// +//===----------------------------------------------------------------------===// + +#include "dynamatic/Transforms/InitCstWidth.h" +#include "circt/Dialect/Handshake/HandshakeOps.h" +#include "dynamatic/Transforms/PassDetails.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/IndentedOstream.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "BITWIDTH" + +static LogicalResult initCstOpBitsWidth(handshake::FuncOp funcOp, + MLIRContext *ctx) { + OpBuilder builder(ctx); + SmallVector cstOps; + + int savedBits = 0; + + for (auto op : + llvm::make_early_inc_range(funcOp.getOps())) { + unsigned cstBitWidth = CPP_MAX_WIDTH; + IntegerType::SignednessSemantics ifSign = + IntegerType::SignednessSemantics::Signless; + // skip the bool value constant operation + if (!isa(op.getValue())) + continue; + + // get the attribute value + if (auto intAttr = dyn_cast(op.getValue())) { + if (int cstVal = intAttr.getValue().getZExtValue(); cstVal > 0) + cstBitWidth = log2(cstVal) + 2; + else if (int cstVal = intAttr.getValue().getZExtValue(); cstVal < 0) { + cstBitWidth = log2(-cstVal) + 2; + } else + cstBitWidth = 2; + } + + if (cstBitWidth < op.getResult().getType().getIntOrFloatBitWidth()) { + // Get the new type of calculated bitwidth + Type newType = getNewType(op.getResult(), cstBitWidth, ifSign); + + // Update the constant operator for both ValueAttr and result Type + builder.setInsertionPointAfter(op); + handshake::ConstantOp newCstOp = builder.create( + op.getLoc(), newType, op.getValue(), op.getCtrl()); + + // Determine the proper representation of the constant value + int intVal = op.getValue().cast().getInt(); + intVal = + ((1 << op.getValue().getType().getIntOrFloatBitWidth()) - 1 + intVal); + newCstOp.setValueAttr(IntegerAttr::get(newType, intVal)); + // save the original bb + newCstOp->setAttr("bb", op->getAttr("bb")); + + // recursively replace the uses of the old constant operation with the new + // one Value opVal = op.getResult(); + savedBits += + op.getResult().getType().getIntOrFloatBitWidth() - cstBitWidth; + auto extOp = builder.create( + newCstOp.getLoc(), op.getResult().getType(), newCstOp.getResult()); + + // replace the constant operation (default width) + // with the extension of new constant operation (optimized width) + op->replaceAllUsesWith(extOp); + } + } + + LLVM_DEBUG(llvm::dbgs() << "Number of saved bits is " << savedBits << "\n"); + + return success(); +} + +struct HandshakeInitCstWidthPass + : public HandshakeInitCstWidthBase { + + void runOnOperation() override { + auto *ctx = &getContext(); + + ModuleOp m = getOperation(); + for (auto funcOp : m.getOps()) + if (failed(initCstOpBitsWidth(funcOp, ctx))) + return signalPassFailure(); + }; +}; + +std::unique_ptr> +dynamatic::createInitCstWidthPass() { + return std::make_unique(); +} diff --git a/lib/Transforms/InitIndexType.cpp b/lib/Transforms/InitIndexType.cpp new file mode 100644 index 000000000..a1a3f4bde --- /dev/null +++ b/lib/Transforms/InitIndexType.cpp @@ -0,0 +1,117 @@ +//===- InitIndexType.cpp - Change IndexType to IntegerType ------*- C++ -*-===// +// +// This file contains the implementation of the init-indextype optimization +// pass. +// +//===----------------------------------------------------------------------===// + +#include "dynamatic/Transforms/InitIndexType.h" +#include "circt/Dialect/Handshake/HandshakeOps.h" +#include "dynamatic/Transforms/PassDetails.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Support/LogicalResult.h" + +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/IndentedOstream.h" +#include "mlir/Support/LogicalResult.h" + +using namespace circt; +using namespace circt::handshake; +using namespace mlir; +using namespace dynamatic; + +static const int INDEXWIDTH = IndexType::kInternalStorageBitWidth; + +// Adapt the Index type to the Integer type +static LogicalResult initIndexType(handshake::FuncOp funcOp, MLIRContext *ctx) { + OpBuilder builder(ctx); + SmallVector indexCastOps; + + for (auto &op : llvm::make_early_inc_range(funcOp.getOps())) { + // insert trunc|extsi operation for index_cast operation + if (isa(op)) { + indexCastOps.push_back(&op); + auto indexCastOp = dyn_cast(op); + bool isOpIndType = isa(indexCastOp.getOperand().getType()); + bool isResIndType = isa(indexCastOp.getResult().getType()); + + // if cast index to integer type + if (!isResIndType) { + if (indexCastOp.getResult().getType().getIntOrFloatBitWidth() == + INDEXWIDTH) + indexCastOp.getResult().replaceAllUsesWith(indexCastOp.getOperand()); + else { + auto newOp = + insertWidthMatchOp(&op, 0, op.getResult(0).getType(), ctx); + if (newOp.has_value()) + op.replaceAllUsesWith(newOp.value()); + } + } + + // if cast integer to index type + if (isResIndType && !isOpIndType) { + if (indexCastOp.getOperand().getType().getIntOrFloatBitWidth() == + INDEXWIDTH) + indexCastOp.getResult().replaceAllUsesWith(indexCastOp.getOperand()); + else { + builder.setInsertionPoint(&op); + auto extOp = builder.create( + op.getLoc(), IntegerType::get(ctx, INDEXWIDTH), + indexCastOp.getOperand()); + op.replaceAllUsesWith(extOp); + } + } + } + + // set type for other operations + else { + for (unsigned int i = 0; i < op.getNumOperands(); ++i) + if (auto operand = op.getOperand(i); + isa(operand.getType())) + operand.setType(IntegerType::get(ctx, INDEXWIDTH)); + + for (unsigned int i = 0; i < op.getNumResults(); ++i) + if (OpResult result = op.getResult(i); + isa(result.getType())) { + result.setType(IntegerType::get(ctx, INDEXWIDTH)); + + // For constant operation, change the value attribute to match the new + // type + if (isa(op)) { + handshake::ConstantOp cstOp = dyn_cast(op); + cstOp.setValueAttr(IntegerAttr::get( + IntegerType::get(ctx, INDEXWIDTH), + cstOp.getValue().cast().getInt())); + } + } + } + } + + for (auto op : indexCastOps) + op->erase(); + + return success(); +} + +namespace { + +struct HandshakeInitIndTypePass + : public HandshakeInitIndTypeBase { + + void runOnOperation() override { + auto *ctx = &getContext(); + + ModuleOp m = getOperation(); + for (auto funcOp : m.getOps()) + if (failed(initIndexType(funcOp, ctx))) + return signalPassFailure(); + }; +}; +}; // namespace + +std::unique_ptr> +dynamatic::createInitIndTypePass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Transforms/OptimizeBits.cpp b/lib/Transforms/OptimizeBits.cpp new file mode 100644 index 000000000..7b4129881 --- /dev/null +++ b/lib/Transforms/OptimizeBits.cpp @@ -0,0 +1,158 @@ +//===- BitsOptimize.cpp - Optimize bits width ------------------*- C++ -*-===// +// +// This file contains the implementation of the bits optimization pass. +// +//===----------------------------------------------------------------------===// + +#include "dynamatic/Transforms/OptimizeBits.h" +#include "circt/Dialect/Handshake/HandshakeOps.h" +#include "dynamatic/Transforms/PassDetails.h" +#include "dynamatic/Transforms/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/IndentedOstream.h" +#include "mlir/Support/LogicalResult.h" + +#define DEBUG_TYPE "BITWIDTH" + +using namespace circt; +using namespace circt::handshake; +using namespace mlir; +using namespace dynamatic; + +static LogicalResult rewriteBitsWidths(handshake::FuncOp funcOp, + MLIRContext *ctx) { + OpBuilder builder(ctx); + SmallVector vecOp; + + using forward_func = + std::function; + using backward_func = + std::function; + + SmallVector containerOps; + + bool changed = true; + int savedBits = 0; + + // Construct the functions w.r.t. the operation name for forward and backward + DenseMap forMapOpNameWidth; + bitwidth::constructForwardFuncMap(forMapOpNameWidth); + + DenseMap backMapOpNameWidth; + bitwidth::constructBackwardFuncMap(backMapOpNameWidth); + + while (changed) { + // init + changed = false; + containerOps.clear(); + + for (auto &op : funcOp.getOps()) + containerOps.push_back(&op); + + // Forward process + for (auto &op : containerOps) { + + if (isa(*op) || + bitwidth::propType(op)) + continue; + + if (isa(op) || isa(op)) { + bitwidth::replaceWithPredecessor(op); + // op->erase(); + continue; + } + + const auto opName = op->getName().getStringRef(); + unsigned int newWidth = 0, resInd = 0; + if (forMapOpNameWidth.find(opName) != forMapOpNameWidth.end()) + newWidth = forMapOpNameWidth[opName](op->getOperands()); + else + continue; + + if (isa(op)) + resInd = 1; // the second result is the one that needs to be updated + + // if the new type can be optimized, update the type + if (Type newOpResultType = + getNewType(op->getResult(resInd), newWidth, true); + newWidth < + op->getResult(resInd).getType().getIntOrFloatBitWidth()) { + changed = true; + savedBits += op->getResult(resInd).getType().getIntOrFloatBitWidth() - + newWidth; + op->getResult(resInd).setType(newOpResultType); + } + } + + // Backward Process + for (auto op : llvm::reverse(containerOps)) { + if (isa(*op)) + continue; + + if (isa(*op)) { + bitwidth::replaceWithPredecessor(op, op->getResult(0).getType()); + // op->erase(); + continue; + } + + const auto opName = op->getName().getStringRef(); + + unsigned int newWidth = 0; + // get the new bit width of the result operator + if (backMapOpNameWidth.find(opName) != backMapOpNameWidth.end()) + newWidth = backMapOpNameWidth[opName](op->getResults()); + else + continue; + + // if the new type can be optimized, update the type + if (Type newOpResultType = + getNewType(op->getOperand(0), newWidth, true); + newWidth < op->getOperand(0).getType().getIntOrFloatBitWidth()) { + changed = true; + for (unsigned i = 0; i < op->getNumOperands(); ++i) { + unsigned int origWidth = + op->getOperand(i).getType().getIntOrFloatBitWidth(); + if (newWidth < origWidth) { + savedBits += origWidth - newWidth; + op->getOperand(i).setType(newOpResultType); + } + } + } + } + } + + // Store new inserted truncation or extension operation during validation + SmallVector OpTruncExt; + for (auto &op : llvm::make_early_inc_range(funcOp.getOps())) + bitwidth::validateOp(&op, ctx, OpTruncExt); + + // Validate the new inserted operation + for (auto op : OpTruncExt) + bitwidth::revertTruncOrExt(op, ctx); + + LLVM_DEBUG(llvm::errs() << "Forward-Backward saved bits " << savedBits << "\n"); + + return success(); +} + +struct HandshakeOptimizeBitsPass + : public HandshakeOptimizeBitsBase { + + void runOnOperation() override { + auto *ctx = &getContext(); + + ModuleOp m = getOperation(); + + for (auto funcOp : m.getOps()) + if (failed(rewriteBitsWidths(funcOp, ctx))) + return signalPassFailure(); + }; +}; + +std::unique_ptr> +dynamatic::createOptimizeBitsPass() { + return std::make_unique(); +} \ No newline at end of file diff --git a/lib/Transforms/UtilsBitsUpdate.cpp b/lib/Transforms/UtilsBitsUpdate.cpp new file mode 100644 index 000000000..a91b0098d --- /dev/null +++ b/lib/Transforms/UtilsBitsUpdate.cpp @@ -0,0 +1,696 @@ +//===- UtilsBitsUpdate.cpp - Utils support bits optimization ----*- C++ -*-===// +// +// This file contains basic functions for type updates for --optimize-bits pass. +// +//===----------------------------------------------------------------------===// + +#include "dynamatic/Transforms/UtilsBitsUpdate.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Support/IndentedOstream.h" +#include "llvm/ADT/TypeSwitch.h" + +IntegerType getNewType(Value opVal, unsigned bitswidth, bool signless) { + IntegerType::SignednessSemantics ifSign = + IntegerType::SignednessSemantics::Signless; + if (!signless) + if (auto validType = dyn_cast(opVal.getType())) + ifSign = validType.getSignedness(); + + return IntegerType::get(opVal.getContext(), bitswidth, ifSign); +} + +IntegerType getNewType(Value opVal, unsigned bitswidth, + IntegerType::SignednessSemantics ifSign) { + return IntegerType::get(opVal.getContext(), bitswidth, ifSign); +} + +// specify which value to extend +std::optional insertWidthMatchOp(Operation *newOp, int opInd, + Type newType, MLIRContext *ctx) { + OpBuilder builder(ctx); + Value opVal = newOp->getOperand(opInd); + + if (!isa(opVal.getType())) + assert(false && "Only supported width matching for Integer/Index Type!"); + + unsigned int opWidth; + if (isa(opVal.getType())) + opWidth = IndexType::kInternalStorageBitWidth; + else + opWidth = opVal.getType().getIntOrFloatBitWidth(); + + if (isa(opVal.getType())) { + // insert Truncation operation to match the opresult width + if (opWidth > newType.getIntOrFloatBitWidth()) { + builder.setInsertionPoint(newOp); + auto truncOp = builder.create(newOp->getLoc(), + newType, opVal); + newOp->setOperand(opInd, truncOp.getResult()); + + return truncOp; + } + + // insert Extension operation to match the opresult width + if (opWidth < newType.getIntOrFloatBitWidth()) { + builder.setInsertionPoint(newOp); + auto extOp = + builder.create(newOp->getLoc(), newType, opVal); + newOp->setOperand(opInd, extOp.getResult()); + + return extOp; + } + } + return {}; +} + +namespace dynamatic::bitwidth { + +void constructForwardFuncMap( + DenseMap> + &mapOpNameWidth) { + + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { + return std::min( + CPP_MAX_WIDTH, + std::max(vecOperands[0].getType().getIntOrFloatBitWidth(), + vecOperands[1].getType().getIntOrFloatBitWidth()) + + 1); + }; + + mapOpNameWidth[mlir::arith::SubIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::MulIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { + return std::min(CPP_MAX_WIDTH, + vecOperands[0].getType().getIntOrFloatBitWidth() + + vecOperands[1].getType().getIntOrFloatBitWidth()); + }; + + mapOpNameWidth[mlir::arith::DivUIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { + return std::min(CPP_MAX_WIDTH, + vecOperands[0].getType().getIntOrFloatBitWidth() + 1); + }; + mapOpNameWidth[mlir::arith::DivSIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::DivUIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::AndIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { + return std::min( + CPP_MAX_WIDTH, + std::min(vecOperands[0].getType().getIntOrFloatBitWidth(), + vecOperands[1].getType().getIntOrFloatBitWidth())); + }; + + mapOpNameWidth[mlir::arith::OrIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { + return std::min( + CPP_MAX_WIDTH, + std::max(vecOperands[0].getType().getIntOrFloatBitWidth(), + vecOperands[1].getType().getIntOrFloatBitWidth())); + }; + + mapOpNameWidth[mlir::arith::XOrIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::OrIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::ShLIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { + int shift_bit = 0; + if (auto defOp = vecOperands[1].getDefiningOp(); + isa(defOp)) { + if (handshake::ConstantOp cstOp = + dyn_cast(defOp)) + if (auto IntAttr = cstOp.getValue().dyn_cast()) + shift_bit = IntAttr.getValue().getZExtValue(); + return std::min(CPP_MAX_WIDTH, + vecOperands[0].getType().getIntOrFloatBitWidth() + + shift_bit); + } + return CPP_MAX_WIDTH; + }; + + mapOpNameWidth[mlir::arith::ShRSIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { + int shiftBit = 0; + if (auto defOp = vecOperands[1].getDefiningOp(); + isa(defOp)) + if (handshake::ConstantOp cstOp = + dyn_cast(defOp)) + if (auto IntAttr = cstOp.getValue().dyn_cast()) + shiftBit = IntAttr.getValue().getZExtValue(); + + return std::min(CPP_MAX_WIDTH, + vecOperands[0].getType().getIntOrFloatBitWidth() - + shiftBit); + }; + + mapOpNameWidth[mlir::arith::ShRUIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::ShRSIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::CmpIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { return unsigned(1); }; + + mapOpNameWidth[mlir::arith::ExtSIOp::getOperationName()] = + [](Operation::operand_range vecOperands) { + return vecOperands[0].getType().getIntOrFloatBitWidth(); + }; + + mapOpNameWidth[mlir::arith::ExtUIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::ExtSIOp::getOperationName()]; + + mapOpNameWidth[StringRef("handshake.control_merge")] = + [](Operation::operand_range vecOperands) { + unsigned ind = 0; // record number of operators + + for (auto oprand : vecOperands) + ind++; + + unsigned indexWidth = 1; + if (ind > 1) + indexWidth = ceil(log2(ind)); + + return indexWidth; + }; +}; + +void constructBackwardFuncMap( + DenseMap> + &mapOpNameWidth) { + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()] = + [](Operation::result_range vecResults) { + return std::min(CPP_MAX_WIDTH, + vecResults[0].getType().getIntOrFloatBitWidth()); + }; + + mapOpNameWidth[mlir::arith::SubIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::MulIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::AndIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::OrIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::XOrIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()]; +} + +void constructUpdateFuncMap( + DenseMap>( + Operation::operand_range vecOperands, + Operation::result_range vecResults)>> + &mapOpNameWidth) { + + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + + unsigned int maxOpWidth = + std::max(vecOperands[0].getType().getIntOrFloatBitWidth(), + vecOperands[1].getType().getIntOrFloatBitWidth()); + + unsigned int width = std::min( + vecResults[0].getType().getIntOrFloatBitWidth(), maxOpWidth + 1); + + width = std::min(CPP_MAX_WIDTH, width); + widths.push_back({width, width}); // matched widths for operators + widths.push_back({width}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[mlir::arith::SubIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::AddIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::MulIOp::getOperationName()] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + + unsigned int maxOpWidth = + vecOperands[0].getType().getIntOrFloatBitWidth() + + vecOperands[1].getType().getIntOrFloatBitWidth(); + + unsigned int width = std::min( + vecResults[0].getType().getIntOrFloatBitWidth(), maxOpWidth); + + width = std::min(CPP_MAX_WIDTH, width); + + widths.push_back({width, width}); // matched widths for operators + widths.push_back({width}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[mlir::arith::DivSIOp::getOperationName()] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + + unsigned int maxOpWidth = + vecOperands[0].getType().getIntOrFloatBitWidth(); + + unsigned int width = std::min( + vecResults[0].getType().getIntOrFloatBitWidth(), maxOpWidth + 1); + + width = std::min(CPP_MAX_WIDTH, width); + + widths.push_back({width, width}); // matched widths for operators + widths.push_back({width}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[mlir::arith::DivUIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::DivSIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::DivSIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::DivSIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::ShLIOp::getOperationName()] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + unsigned shiftBit = 0; + if (auto defOp = vecOperands[1].getDefiningOp(); + isa(defOp)) + if (handshake::ConstantOp cstOp = + dyn_cast(defOp)) + if (auto IntAttr = cstOp.getValue().dyn_cast()) + shiftBit = IntAttr.getValue().getZExtValue(); + + unsigned int width = std::min( + std::min(CPP_MAX_WIDTH, + vecResults[0].getType().getIntOrFloatBitWidth()), + vecOperands[0].getType().getIntOrFloatBitWidth() + shiftBit); + + width = std::min(CPP_MAX_WIDTH, width); + widths.push_back({width, width}); // matched widths for operators + widths.push_back({width}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[mlir::arith::ShRSIOp::getOperationName()] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + unsigned shiftBit = 0; + if (auto defOp = vecOperands[1].getDefiningOp(); + isa(defOp)) + if (handshake::ConstantOp cstOp = + dyn_cast(defOp)) + if (auto IntAttr = cstOp.getValue().dyn_cast()) + shiftBit = IntAttr.getValue().getZExtValue(); + + unsigned int width = std::min( + std::min(CPP_MAX_WIDTH, + vecResults[0].getType().getIntOrFloatBitWidth()), + vecOperands[0].getType().getIntOrFloatBitWidth() - shiftBit); + + width = std::min(CPP_MAX_WIDTH, width); + widths.push_back({width, width}); // matched widths for operators + widths.push_back({width}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[mlir::arith::ShRUIOp::getOperationName()] = + mapOpNameWidth[mlir::arith::ShRSIOp::getOperationName()]; + + mapOpNameWidth[mlir::arith::CmpIOp::getOperationName()] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + + unsigned int maxOpWidth = + std::max(vecOperands[0].getType().getIntOrFloatBitWidth(), + vecOperands[1].getType().getIntOrFloatBitWidth()); + + unsigned int width = std::min(CPP_MAX_WIDTH, maxOpWidth); + + widths.push_back({width, width}); // matched widths for operators + widths.push_back({unsigned(1)}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[StringRef("handshake.mux")] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + unsigned maxOpWidth = 2; + + unsigned ind = 0; // record number of operators + + for (auto oprand : vecOperands) { + ind++; + if (ind == 0) + continue; // skip the width of the index + + if (!isa(oprand.getType())) + if (!isa(oprand.getType()) && + oprand.getType().getIntOrFloatBitWidth() > maxOpWidth) + maxOpWidth = oprand.getType().getIntOrFloatBitWidth(); + } + unsigned indexWidth = 2; + if (ind > 2) + indexWidth = log2(ind - 2) + 2; + + widths.push_back( + {indexWidth}); // the bit width for the mux index result; + + if (isa(vecResults[0].getType())) { + widths.push_back({}); + return widths; + } + + unsigned int width = + std::min(vecResults[0].getType().getIntOrFloatBitWidth(), + std::min(CPP_MAX_WIDTH, maxOpWidth)); + // 1st operand is the index; rest of (ind -1) operands set to width + std::vector opwidths(ind - 1, width); + + widths[0].insert(widths[0].end(), opwidths.begin(), + opwidths.end()); // matched widths for operators + widths.push_back({width}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[StringRef("handshake.merge")] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + unsigned maxOpWidth = 2; + + unsigned ind = 0; // record number of operators + + for (auto oprand : vecOperands) { + ind++; + if (!isa(vecOperands[0].getType())) + if (!isa(oprand.getType()) && + oprand.getType().getIntOrFloatBitWidth() > maxOpWidth) + maxOpWidth = oprand.getType().getIntOrFloatBitWidth(); + } + + if (isa(vecOperands[0].getType())) { + widths.push_back({}); + widths.push_back({}); + return widths; + } + + unsigned int width = + std::min(vecResults[0].getType().getIntOrFloatBitWidth(), + std::min(CPP_MAX_WIDTH, maxOpWidth)); + std::vector opwidths(ind, width); + + widths.push_back(opwidths); // matched widths for operators + widths.push_back({width}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[StringRef("handshake.constant")] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + // Do not set the width of the input + widths.push_back({}); + Operation *Op = vecResults[0].getDefiningOp(); + if (handshake::ConstantOp cstOp = dyn_cast(*Op)) + if (auto IntAttr = cstOp.getValueAttr().dyn_cast()) + if (auto IntType = dyn_cast(IntAttr.getType())) { + widths.push_back({IntType.getWidth()}); + return widths; + } + + widths.push_back({}); + return widths; + }; + + mapOpNameWidth[StringRef("handshake.control_merge")] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + unsigned maxOpWidth = 2; + + unsigned ind = 0; // record number of operators + + for (auto oprand : vecOperands) { + ind++; + if (!isa(oprand.getType())) + if (!isa(oprand.getType()) && + oprand.getType().getIntOrFloatBitWidth() > maxOpWidth) + maxOpWidth = oprand.getType().getIntOrFloatBitWidth(); + } + + unsigned indexWidth = 1; + if (ind > 1) + indexWidth = ceil(log2(ind)); + + if (isa(vecOperands[0].getType())) { + widths.push_back({}); + widths.push_back({0, indexWidth}); + return widths; + } + + unsigned int width = std::min(CPP_MAX_WIDTH, maxOpWidth); + std::vector opwidths(ind, width); + + widths.push_back(opwidths); // matched widths for operators + widths.push_back({indexWidth}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[mlir::arith::SelectOp::getOperationName()] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + + unsigned ind = 0, maxOpWidth = 2; + + for (auto oprand : vecOperands) { + ind++; + if (ind == 0) + continue; // skip the width of the index + + if (!isa(oprand.getType())) + if (!isa(oprand.getType()) && + oprand.getType().getIntOrFloatBitWidth() > maxOpWidth) + maxOpWidth = oprand.getType().getIntOrFloatBitWidth(); + } + + widths.push_back({1}); // bool like condition + if (isa(vecOperands[1].getType())) { + widths.push_back({}); + return widths; + } + + std::vector opwidths(ind - 1, maxOpWidth); + + widths[0].insert(widths[0].end(), opwidths.begin(), + opwidths.end()); // matched widths for operators + widths.push_back({maxOpWidth}); // matched widths for result + + return widths; + }; + + mapOpNameWidth[StringRef("handshake.d_return")] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + widths.push_back({ADDRESS_WIDTH}); + if (!isa(vecResults[0].getType())) + widths.push_back({ADDRESS_WIDTH}); + else + widths.push_back({}); + return widths; + }; + + mapOpNameWidth[StringRef("handshake.d_load")] = + [&](Operation::operand_range vecOperands, + Operation::result_range vecResults) { + std::vector> widths; + widths.push_back({CPP_MAX_WIDTH, ADDRESS_WIDTH}); + widths.push_back({CPP_MAX_WIDTH, ADDRESS_WIDTH}); + return widths; + }; + + mapOpNameWidth[StringRef("handshake.d_store")] = + mapOpNameWidth[StringRef("handshake.d_load")]; +}; + +static bool setPassFlag(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](Operation *op) { return true; }) + .Default([&](auto) { return false; }); +} + +static bool setMatchFlag(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](Operation *op) { return true; }) + .Default([&](auto) { return false; }); +} + +static bool setRevertFlag(Operation *op) { + return llvm::TypeSwitch(op) + .Case( + [](Operation *op) { return true; }) + .Default([&](auto) { return false; }); +} + +bool propType(Operation *op) { + + if (isa(*op)) { + for (auto resOp : op->getResults()) + resOp.setType(op->getOperand(1).getType()); + return true; + } + + if (isa(*op)) { + op->getResult(0).setType(op->getOperand(0).getType()); + return true; + } + return false; +} + +void replaceWithPredecessor(Operation *op) { + op->getResult(0).replaceAllUsesWith(op->getOperand(0)); +} + +void replaceWithPredecessor(Operation *op, Type resType) { + Operation *sucNode = op->getOperand(0).getDefiningOp(); + + // find the index of result in vec_results + for (auto Val : sucNode->getResults()) { + if (Val == op->getOperand(0)) { + Val.setType(resType); + break; + } + } + + op->getResult(0).replaceAllUsesWith(op->getOperand(0)); +} + +void revertTruncOrExt(Operation *op , MLIRContext *ctx) { + OpBuilder builder(ctx); + // if width(res) == width(opr) : delte the operand; + + if (op ->getResult(0).getType().getIntOrFloatBitWidth() == + op ->getOperand(0).getType().getIntOrFloatBitWidth()) { + + replaceWithPredecessor(op ); + op ->erase(); + return; + } + + // if for extension operation width(res) < width(opr), + // change it to truncation operation + if (isa(*op ) || isa(*op )) + if (op ->getResult(0).getType().getIntOrFloatBitWidth() < + op ->getOperand(0).getType().getIntOrFloatBitWidth()) { + + builder.setInsertionPoint(op); + Type newType = + getNewType(op->getResult(0), + op->getResult(0).getType().getIntOrFloatBitWidth(), false); + auto truncOp = builder.create( + op->getLoc(), newType, op->getOperand(0)); + op->getResult(0).replaceAllUsesWith(truncOp.getResult()); + op->erase(); + return; + } + + // if for truncation operation width(res) > width(opr), + // change it to extension operation + if (isa(*op)) + if (op->getResult(0).getType().getIntOrFloatBitWidth() > + op->getOperand(0).getType().getIntOrFloatBitWidth()) { + + builder.setInsertionPoint(op); + Type newType = + getNewType(op->getResult(0), + op->getResult(0).getType().getIntOrFloatBitWidth(), false); + auto truncOp = builder.create(op->getLoc(), newType, + op->getOperand(0)); + op->getResult(0).replaceAllUsesWith(truncOp.getResult()); + op->erase(); + } +} + +void matchOpResWidth(Operation *op, MLIRContext *ctx, + SmallVector &newMatchedOps) { + + DenseMap>( + Operation::operand_range vecOperands, + Operation::result_range vecResults)>> + mapOpNameWidth; + + constructUpdateFuncMap(mapOpNameWidth); + + std::vector> oprsWidth = + mapOpNameWidth[op->getName().getStringRef()](op->getOperands(), + op->getResults()); + + // make operator matched the width + for (size_t i = 0; i < oprsWidth[0].size(); ++i) { + if (auto Operand = op->getOperand(i); + !isa(Operand.getType()) && + Operand.getType().getIntOrFloatBitWidth() != oprsWidth[0][i]) { + auto insertOp = insertWidthMatchOp( + op, i, getNewType(Operand, oprsWidth[0][i], false), ctx); + if (insertOp.has_value()) + newMatchedOps.push_back(insertOp.value()); + } + } + // make result matched the width + for (size_t i = 0; i < oprsWidth[1].size(); ++i) { + if (auto OpRes = op->getResult(i); + oprsWidth[1][i] != 0 && + OpRes.getType().getIntOrFloatBitWidth() != oprsWidth[1][i]) { + Type newType = getNewType(OpRes, oprsWidth[1][i], false); + op->getResult(i).setType(newType); + } + } +} + +void validateOp(Operation *op, MLIRContext *ctx, + SmallVector &newMatchedOps) { + // the operations can be divided to three types to make it validated + // passType: branch, conditionalbranch + // c <= op(a,b): addi, subi, mux, etc. where both a,b,c needed to be verified + // need to be reverted or deleted : truncIOp, extIOp + bool pass = setPassFlag(op); + bool match = setMatchFlag(op); + bool revert = setRevertFlag(op); + + if (pass) + bool res = propType(op); + + if (match) + matchOpResWidth(op, ctx, newMatchedOps); + + if (revert) + revertTruncOrExt(op, ctx); +} +} // namespace dynamatic::update \ No newline at end of file diff --git a/test/Transforms/optimize-bitwidth.mlir b/test/Transforms/optimize-bitwidth.mlir new file mode 100644 index 000000000..475de7c51 --- /dev/null +++ b/test/Transforms/optimize-bitwidth.mlir @@ -0,0 +1,124 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// RUN: dynamatic-opt --optimize-bits %s --split-input-file | FileCheck %s + +// ----- + +// CHECK-LABEL: handshake.func @optimizeAdd( +// CHECK-SAME: %[[VAL_0:.*]]: none, ...) -> i32 attributes {argNames = ["arg0"], resNames = ["out0"]} { +// CHECK: %[[VAL_1:.*]] = merge %[[VAL_0]] : none +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 999 : i11} : i11 +// CHECK: %[[VAL_3:.*]] = arith.extsi %[[VAL_2]] : i11 to i32 +// CHECK: %[[VAL_4:.*]] = constant %[[VAL_1]] {value = -2 : i2} : i2 +// CHECK: %[[VAL_5:.*]] = arith.extsi %[[VAL_4]] : i2 to i32 +// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_2]] : i11 to i12 +// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_4]] : i2 to i12 +// CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_6]], %[[VAL_7]] : i12 +// CHECK: %[[VAL_9:.*]] = arith.extsi %[[VAL_8]] : i12 to i32 +// CHECK: %[[VAL_10:.*]] = d_return %[[VAL_9]] : i32 +// CHECK: end +// CHECK: } +handshake.func @optimizeAdd(%arg0: none) -> i32 { + %0 = merge %arg0 : none + %1 = constant %0 {value = 999 : i11} : i11 + %2 = arith.extsi %1 : i11 to i32 + %3 = constant %0 {value = 2 : i2}: i2 + %4 = arith.extsi %3 : i2 to i32 + %5 = arith.addi %2, %4 : i32 + %6 = d_return %5 : i32 + end +} + +// CHECK-LABEL: handshake.func @optimizeBackwardSub( +// CHECK-SAME: %[[VAL_0:.*]]: none, ...) -> i8 attributes {argNames = ["arg0"], resNames = ["out0"]} { +// CHECK: %[[VAL_1:.*]] = merge %[[VAL_0]] : none +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 999 : i11} : i11 +// CHECK: %[[VAL_3:.*]] = arith.extsi %[[VAL_2]] : i11 to i32 +// CHECK: %[[VAL_4:.*]] = constant %[[VAL_1]] {value = 5 : i4} : i4 +// CHECK: %[[VAL_5:.*]] = arith.extsi %[[VAL_4]] : i4 to i32 +// CHECK: %[[VAL_6:.*]] = arith.trunci %[[VAL_2]] : i11 to i8 +// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_4]] : i4 to i8 +// CHECK: %[[VAL_8:.*]] = arith.subi %[[VAL_6]], %[[VAL_7]] : i8 +// CHECK: end +// CHECK: } +handshake.func @optimizeBackwardSub(%arg0: none) -> i8 { + %0 = merge %arg0 : none + %1 = constant %0 {value = 999 : i11} : i11 + %2 = arith.extsi %1 : i11 to i32 + %3 = constant %0 {value = 5 : i4} : i4 + %4 = arith.extsi %3 : i4 to i32 + %5 = arith.subi %2, %4 : i32 + %6 = arith.trunci %5 : i32 to i8 + end +} + + +// CHECK-LABEL: handshake.func @optimizeBackwardMuL( +// CHECK-SAME: %[[VAL_0:.*]]: none, ...) -> i8 attributes {argNames = ["arg0"], resNames = ["out0"]} { +// CHECK: %[[VAL_1:.*]] = merge %[[VAL_0]] : none +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 999 : i11} : i11 +// CHECK: %[[VAL_3:.*]] = arith.extsi %[[VAL_2]] : i11 to i32 +// CHECK: %[[VAL_4:.*]] = constant %[[VAL_1]] {value = 20 : i6} : i6 +// CHECK: %[[VAL_5:.*]] = arith.extsi %[[VAL_4]] : i6 to i32 +// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_2]] : i11 to i16 +// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_4]] : i6 to i16 +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_6]], %[[VAL_7]] : i16 +// CHECK: end +// CHECK: } +handshake.func @optimizeBackwardMuL(%arg0: none) -> i8 { + %0 = merge %arg0 : none + %1 = constant %0 {value = 999 : i11} : i11 + %2 = arith.extsi %1 : i11 to i32 + %3 = constant %0 {value = 20 : i6} : i6 + %4 = arith.extsi %3 : i6 to i32 + %5 = arith.muli %2, %4 : i32 + %6 = arith.trunci %5 : i32 to i16 + end +} + +// CHECK-LABEL: handshake.func @optimizeBackwardDiV( +// CHECK-SAME: %[[VAL_0:.*]]: none, ...) -> i8 attributes {argNames = ["arg0"], resNames = ["out0"]} { +// CHECK: %[[VAL_1:.*]] = merge %[[VAL_0]] : none +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 999 : i11} : i11 +// CHECK: %[[VAL_3:.*]] = arith.extsi %[[VAL_2]] : i11 to i32 +// CHECK: %[[VAL_4:.*]] = constant %[[VAL_1]] {value = 20 : i6} : i6 +// CHECK: %[[VAL_5:.*]] = arith.extsi %[[VAL_4]] : i6 to i32 +// CHECK: %[[VAL_6:.*]] = arith.extsi %[[VAL_2]] : i11 to i12 +// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_4]] : i6 to i12 +// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_6]], %[[VAL_7]] : i12 +// CHECK: end +// CHECK: } + +handshake.func @optimizeBackwardDiV(%arg0: none) -> i8 { + %0 = merge %arg0 : none + %1 = constant %0 {value = 999 : i11} : i11 + %2 = arith.extsi %1 : i11 to i32 + %3 = constant %0 {value = 20 : i6} : i6 + %4 = arith.extsi %3 : i6 to i32 + %5 = arith.divsi %2, %4 : i32 + %6 = arith.trunci %5 : i32 to i16 + end +} + +// CHECK-LABEL: handshake.func @optimizeBackwardShR( +// CHECK-SAME: %[[VAL_0:.*]]: none, ...) -> i8 attributes {argNames = ["arg0"], resNames = ["out0"]} { +// CHECK: %[[VAL_1:.*]] = merge %[[VAL_0]] : none +// CHECK: %[[VAL_2:.*]] = constant %[[VAL_1]] {value = 999 : i11} : i11 +// CHECK: %[[VAL_3:.*]] = arith.extsi %[[VAL_2]] : i11 to i32 +// CHECK: %[[VAL_4:.*]] = constant %[[VAL_1]] {value = 2 : i4} : i4 +// CHECK: %[[VAL_5:.*]] = arith.extsi %[[VAL_4]] : i4 to i32 +// CHECK: %[[VAL_6:.*]] = arith.trunci %[[VAL_2]] : i11 to i8 +// CHECK: %[[VAL_7:.*]] = arith.extsi %[[VAL_4]] : i4 to i8 +// CHECK: %[[VAL_8:.*]] = arith.shrui %[[VAL_6]], %[[VAL_7]] : i8 +// CHECK: end +// CHECK: } + +handshake.func @optimizeBackwardShR(%arg0: none) -> i8 { + %0 = merge %arg0 : none + %1 = constant %0 {value = 999 : i11} : i11 + %2 = arith.extsi %1 : i11 to i32 + %3 = constant %0 {value = 2 : i4} : i4 + %4 = arith.extsi %3 : i4 to i32 + %5 = arith.shrui %2, %4 : i32 + %6 = arith.trunci %5 : i32 to i8 + end +} \ No newline at end of file