-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Bitwidth] Bitwidth optimization pass #1
This commit introduces a bitwidth optimization pass similar to the one implemented by legacy Dynamatic. The pass iteratively optimizes the bitwidth of all values present in the handshake-level IR by iterating between a forward analysis step (which optimizes the bitwidth of results based on the bitwidth of operands) and a backward analysis step (which optimizes the bitwidth of operands based on the bitwidth of results) until the IR converges to its optimized form. The commit also introduces two helper passes that are pre-requisites to the bitwidth optimization pass itself. `HandshakeInitIndType` concretizes the width of all `IndexType`'s in the IR while `HandshakeInitCstWidth` modifies the width of all constants to be as narrow as possible.
- Loading branch information
Showing
13 changed files
with
1,370 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,4 +34,4 @@ | |
# Build-related directories | ||
build/ | ||
bin/ | ||
.cache/ | ||
.cache/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
//===- InitCstWidth.h - Reduce the constant bits width ----------*- C++ -*-===// | ||
// | ||
// This file declares the --init-cstwidth pass. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H | ||
#define DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H | ||
|
||
#include "dynamatic/Transforms/UtilsBitsUpdate.h" | ||
|
||
namespace dynamatic { | ||
|
||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createInitCstWidthPass(); | ||
|
||
} // namespace dynamatic | ||
|
||
#endif // DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
//===- InitIndexType.h - Transform IndexType to IntegerType -----*- C++ -*-===// | ||
// | ||
// This file declares the --init-indextype pass. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef DYNAMATIC_TRANSFORMS_INITINDTYPE_H | ||
#define DYNAMATIC_TRANSFORMS_INITINDTYPE_H | ||
|
||
#include "dynamatic/Transforms/UtilsBitsUpdate.h" | ||
|
||
namespace dynamatic { | ||
|
||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createInitIndTypePass(); | ||
|
||
} // namespace dynamatic | ||
|
||
#endif // DYNAMATIC_TRANSFORMS_INITINDTYPE_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
//===- BitsOptimize.h - Optimize bits widths --------------------*- C++ -*-===// | ||
// | ||
// This file declares the --optimize-bits pass. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H | ||
#define DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H | ||
|
||
#include "dynamatic/Transforms/UtilsBitsUpdate.h" | ||
|
||
namespace dynamatic { | ||
|
||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createOptimizeBitsPass(); | ||
|
||
} // namespace dynamatic | ||
|
||
#endif // DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
//===- UtilsBitsUpdate.h - Utils support bits optimization ------*- C++ -*-===// | ||
// | ||
// This file declares supports for --optimize-bits pass. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H | ||
#define DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H | ||
|
||
#include "circt/Dialect/Handshake/HandshakeOps.h" | ||
#include "dynamatic/Support/LLVM.h" | ||
#include <optional> | ||
|
||
using namespace mlir; | ||
using namespace circt; | ||
using namespace circt::handshake; | ||
using namespace dynamatic; | ||
|
||
const unsigned CPP_MAX_WIDTH = 64; | ||
const unsigned ADDRESS_WIDTH = 32; | ||
|
||
IntegerType getNewType(Value opVal, unsigned bitswidth, bool signless = false); | ||
|
||
IntegerType getNewType(Value opVal, unsigned bitswidth, | ||
IntegerType::SignednessSemantics ifSign); | ||
|
||
std::optional<Operation *> insertWidthMatchOp(Operation *newOp, int opInd, | ||
Type newType, MLIRContext *ctx); | ||
|
||
namespace dynamatic { | ||
namespace bitwidth { | ||
|
||
/// Construct the functions w.r.t. the operation name in the forward process. | ||
void constructForwardFuncMap( | ||
DenseMap<StringRef, | ||
std::function<unsigned(Operation::operand_range vecOperands)>> | ||
&mapOpNameWidth); | ||
|
||
/// Construct the functions w.r.t. the operation name in the backward process. | ||
void constructBackwardFuncMap( | ||
DenseMap<StringRef, | ||
std::function<unsigned(Operation::result_range vecResults)>> | ||
&mapOpNameWidth); | ||
|
||
/// Construct the functions w.r.t. the operation name in the validation process. | ||
void constructUpdateFuncMap( | ||
DenseMap<mlir::StringRef, | ||
std::function<unsigned(Operation::operand_range vecOperands)>> | ||
&mapOpNameWidth); | ||
|
||
/// For branch and conditional branch operations, propagate the bits width of | ||
/// the operands to the result. | ||
bool propType(Operation *Op); | ||
|
||
/// Insert width match operations (extension or truncation) for the operands and | ||
/// the results. | ||
void matchOpResWidth(Operation *op, MLIRContext *ctx, | ||
SmallVector<Operation *> &newMatchedOps); | ||
|
||
/// Replace the operation's operand with the its successor. | ||
void replaceWithPredecessor(Operation *op); | ||
|
||
/// Replace the operation's operand with the its successor, and set the operation's | ||
/// resultOp according to its successor's resultOp type. | ||
void replaceWithPredecessor(Operation *op, Type resType); | ||
|
||
// Validate the truncation and extension operation in case its operand and | ||
// result operand width are not consistent by reverting or deleting the | ||
// operations. | ||
void revertTruncOrExt(Operation *op, MLIRContext *ctx); | ||
|
||
/// Set the pass, match, and revert flags to choose the methods that validate | ||
/// the operations. | ||
static bool setPassFlag(Operation *op); | ||
|
||
static bool setMatchFlag(Operation *op); | ||
|
||
static bool setRevertFlag(Operation *op); | ||
|
||
/// Validate the operations after bits optimization to generate .mlir file. | ||
void validateOp(Operation *op, MLIRContext *ctx, | ||
SmallVector<Operation *> &newMatchedOps); | ||
} // namespace bitwidth | ||
} // namespace dynamatic | ||
|
||
#endif // DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
//===- InitCstWidth.cpp - Reduce the constant bits width --------*- C++ -*-===// | ||
// | ||
// This file contains the implementation of the init-cstwidth pass. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "dynamatic/Transforms/InitCstWidth.h" | ||
#include "circt/Dialect/Handshake/HandshakeOps.h" | ||
#include "dynamatic/Transforms/PassDetails.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/Dialect.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "mlir/IR/OperationSupport.h" | ||
#include "mlir/Support/IndentedOstream.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
||
#define DEBUG_TYPE "BITWIDTH" | ||
|
||
static LogicalResult initCstOpBitsWidth(handshake::FuncOp funcOp, | ||
MLIRContext *ctx) { | ||
OpBuilder builder(ctx); | ||
SmallVector<handshake::ConstantOp> cstOps; | ||
|
||
int savedBits = 0; | ||
|
||
for (auto op : | ||
llvm::make_early_inc_range(funcOp.getOps<handshake::ConstantOp>())) { | ||
unsigned cstBitWidth = CPP_MAX_WIDTH; | ||
IntegerType::SignednessSemantics ifSign = | ||
IntegerType::SignednessSemantics::Signless; | ||
// skip the bool value constant operation | ||
if (!isa<mlir::IntegerAttr>(op.getValue())) | ||
continue; | ||
|
||
// get the attribute value | ||
if (auto intAttr = dyn_cast<mlir::IntegerAttr>(op.getValue())) { | ||
if (int cstVal = intAttr.getValue().getZExtValue(); cstVal > 0) | ||
cstBitWidth = log2(cstVal) + 2; | ||
else if (int cstVal = intAttr.getValue().getZExtValue(); cstVal < 0) { | ||
cstBitWidth = log2(-cstVal) + 2; | ||
} else | ||
cstBitWidth = 2; | ||
} | ||
|
||
if (cstBitWidth < op.getResult().getType().getIntOrFloatBitWidth()) { | ||
// Get the new type of calculated bitwidth | ||
Type newType = getNewType(op.getResult(), cstBitWidth, ifSign); | ||
|
||
// Update the constant operator for both ValueAttr and result Type | ||
builder.setInsertionPointAfter(op); | ||
handshake::ConstantOp newCstOp = builder.create<handshake::ConstantOp>( | ||
op.getLoc(), newType, op.getValue(), op.getCtrl()); | ||
|
||
// Determine the proper representation of the constant value | ||
int intVal = op.getValue().cast<IntegerAttr>().getInt(); | ||
intVal = | ||
((1 << op.getValue().getType().getIntOrFloatBitWidth()) - 1 + intVal); | ||
newCstOp.setValueAttr(IntegerAttr::get(newType, intVal)); | ||
// save the original bb | ||
newCstOp->setAttr("bb", op->getAttr("bb")); | ||
|
||
// recursively replace the uses of the old constant operation with the new | ||
// one Value opVal = op.getResult(); | ||
savedBits += | ||
op.getResult().getType().getIntOrFloatBitWidth() - cstBitWidth; | ||
auto extOp = builder.create<mlir::arith::ExtSIOp>( | ||
newCstOp.getLoc(), op.getResult().getType(), newCstOp.getResult()); | ||
|
||
// replace the constant operation (default width) | ||
// with the extension of new constant operation (optimized width) | ||
op->replaceAllUsesWith(extOp); | ||
} | ||
} | ||
|
||
LLVM_DEBUG(llvm::dbgs() << "Number of saved bits is " << savedBits << "\n"); | ||
|
||
return success(); | ||
} | ||
|
||
struct HandshakeInitCstWidthPass | ||
: public HandshakeInitCstWidthBase<HandshakeInitCstWidthPass> { | ||
|
||
void runOnOperation() override { | ||
auto *ctx = &getContext(); | ||
|
||
ModuleOp m = getOperation(); | ||
for (auto funcOp : m.getOps<handshake::FuncOp>()) | ||
if (failed(initCstOpBitsWidth(funcOp, ctx))) | ||
return signalPassFailure(); | ||
}; | ||
}; | ||
|
||
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> | ||
dynamatic::createInitCstWidthPass() { | ||
return std::make_unique<HandshakeInitCstWidthPass>(); | ||
} |
Oops, something went wrong.