Skip to content

Commit

Permalink
[Bitwidth] Bitwidth optimization pass #1
Browse files Browse the repository at this point in the history
This commit introduces a bitwidth optimization pass similar to the one implemented by legacy Dynamatic. The pass iteratively optimizes the bitwidth of all values present in the handshake-level IR by iterating between a forward analysis step (which optimizes the bitwidth of results based on the bitwidth of operands) and a backward analysis step (which optimizes the bitwidth of operands based on the bitwidth of results) until the IR converges to its optimized form.

The commit also introduces two helper passes that are pre-requisites to the bitwidth optimization pass itself. `HandshakeInitIndType` concretizes the width of all `IndexType`'s in the IR while `HandshakeInitCstWidth` modifies the width of all constants to be as narrow as possible.
  • Loading branch information
yuxwang99 authored May 30, 2023
1 parent 6805d61 commit abd3f1a
Show file tree
Hide file tree
Showing 13 changed files with 1,370 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
# Build-related directories
build/
bin/
.cache/
.cache/
18 changes: 18 additions & 0 deletions include/dynamatic/Transforms/InitCstWidth.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//===- InitCstWidth.h - Reduce the constant bits width ----------*- C++ -*-===//
//
// This file declares the --init-cstwidth pass.
//
//===----------------------------------------------------------------------===//

#ifndef DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H
#define DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H

#include "dynamatic/Transforms/UtilsBitsUpdate.h"

namespace dynamatic {

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createInitCstWidthPass();

} // namespace dynamatic

#endif // DYNAMATIC_TRANSFORMS_INITCSTWIDTH_H
18 changes: 18 additions & 0 deletions include/dynamatic/Transforms/InitIndexType.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//===- InitIndexType.h - Transform IndexType to IntegerType -----*- C++ -*-===//
//
// This file declares the --init-indextype pass.
//
//===----------------------------------------------------------------------===//

#ifndef DYNAMATIC_TRANSFORMS_INITINDTYPE_H
#define DYNAMATIC_TRANSFORMS_INITINDTYPE_H

#include "dynamatic/Transforms/UtilsBitsUpdate.h"

namespace dynamatic {

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createInitIndTypePass();

} // namespace dynamatic

#endif // DYNAMATIC_TRANSFORMS_INITINDTYPE_H
18 changes: 18 additions & 0 deletions include/dynamatic/Transforms/OptimizeBits.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//===- BitsOptimize.h - Optimize bits widths --------------------*- C++ -*-===//
//
// This file declares the --optimize-bits pass.
//
//===----------------------------------------------------------------------===//

#ifndef DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H
#define DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H

#include "dynamatic/Transforms/UtilsBitsUpdate.h"

namespace dynamatic {

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createOptimizeBitsPass();

} // namespace dynamatic

#endif // DYNAMATIC_TRANSFORMS_OPTIMIZEBITS_H
5 changes: 4 additions & 1 deletion include/dynamatic/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@

#include "dynamatic/Support/LLVM.h"
#include "dynamatic/Transforms/AnalyzeMemoryAccesses.h"
#include "dynamatic/Transforms/OptimizeBits.h"
#include "dynamatic/Transforms/ArithReduceArea.h"
#include "dynamatic/Transforms/FlattenMemRefRowMajor.h"
#include "dynamatic/Transforms/HandshakeInferBasicBlocks.h"
#include "dynamatic/Transforms/HandshakePrepareForLegacy.h"
#include "dynamatic/Transforms/InitCstWidth.h"
#include "dynamatic/Transforms/InitIndexType.h"
#include "dynamatic/Transforms/NameMemoryOps.h"
#include "dynamatic/Transforms/PushConstants.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -25,4 +28,4 @@ namespace dynamatic {

} // namespace dynamatic

#endif // DYNAMATIC_TRANSFORMS_PASSES_H
#endif // DYNAMATIC_TRANSFORMS_PASSES_H
28 changes: 28 additions & 0 deletions include/dynamatic/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,32 @@ def PushConstants : Pass<"push-constants", "mlir::ModuleOp"> {
let constructor = "dynamatic::createPushConstantsPass()";
}

def HandshakeOptimizeBits : Pass<"optimize-bits" , "mlir::ModuleOp"> {
let summary = "Optimize bits width that can be reduced.";
let description = [{
This pass goes through all operations inside handshake levels, and optimize the bit width through the forward (from first operation to last), and backward(from last operation to first) loop process.
The loop continues until no more optimization can be done.
}];
let constructor = "dynamatic::createOptimizeBitsPass()";
let dependentDialects = ["mlir::arith::ArithDialect"];
}

def HandshakeInitIndType : Pass<"init-indtype", "mlir::ModuleOp"> {
let summary = "Initialize the index type of the module.";
let description = [{
This pass change all the index type within operands and result operands to integer type with platform dependent bit width.
}];
let constructor = "dynamatic::createInitIndTypePass()";
}

def HandshakeInitCstWidth : Pass<"init-cstwidth", "mlir::ModuleOp"> {
let summary = "Initialize the constant bit witth of the module.";
let description = [{
This pass rewrites constant operation with the minimum required bit width according to the value of the constant. To ensure the consistency with the user of the constant operation, the pass inserts a extension operation if necessary.
}];
let constructor = "dynamatic::createInitCstWidthPass()";
let dependentDialects = ["mlir::arith::ArithDialect"];

}

#endif // DYNAMATIC_TRANSFORMS_PASSES_TD
86 changes: 86 additions & 0 deletions include/dynamatic/Transforms/UtilsBitsUpdate.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//===- UtilsBitsUpdate.h - Utils support bits optimization ------*- C++ -*-===//
//
// This file declares supports for --optimize-bits pass.
//
//===----------------------------------------------------------------------===//

#ifndef DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H
#define DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H

#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "dynamatic/Support/LLVM.h"
#include <optional>

using namespace mlir;
using namespace circt;
using namespace circt::handshake;
using namespace dynamatic;

const unsigned CPP_MAX_WIDTH = 64;
const unsigned ADDRESS_WIDTH = 32;

IntegerType getNewType(Value opVal, unsigned bitswidth, bool signless = false);

IntegerType getNewType(Value opVal, unsigned bitswidth,
IntegerType::SignednessSemantics ifSign);

std::optional<Operation *> insertWidthMatchOp(Operation *newOp, int opInd,
Type newType, MLIRContext *ctx);

namespace dynamatic {
namespace bitwidth {

/// Construct the functions w.r.t. the operation name in the forward process.
void constructForwardFuncMap(
DenseMap<StringRef,
std::function<unsigned(Operation::operand_range vecOperands)>>
&mapOpNameWidth);

/// Construct the functions w.r.t. the operation name in the backward process.
void constructBackwardFuncMap(
DenseMap<StringRef,
std::function<unsigned(Operation::result_range vecResults)>>
&mapOpNameWidth);

/// Construct the functions w.r.t. the operation name in the validation process.
void constructUpdateFuncMap(
DenseMap<mlir::StringRef,
std::function<unsigned(Operation::operand_range vecOperands)>>
&mapOpNameWidth);

/// For branch and conditional branch operations, propagate the bits width of
/// the operands to the result.
bool propType(Operation *Op);

/// Insert width match operations (extension or truncation) for the operands and
/// the results.
void matchOpResWidth(Operation *op, MLIRContext *ctx,
SmallVector<Operation *> &newMatchedOps);

/// Replace the operation's operand with the its successor.
void replaceWithPredecessor(Operation *op);

/// Replace the operation's operand with the its successor, and set the operation's
/// resultOp according to its successor's resultOp type.
void replaceWithPredecessor(Operation *op, Type resType);

// Validate the truncation and extension operation in case its operand and
// result operand width are not consistent by reverting or deleting the
// operations.
void revertTruncOrExt(Operation *op, MLIRContext *ctx);

/// Set the pass, match, and revert flags to choose the methods that validate
/// the operations.
static bool setPassFlag(Operation *op);

static bool setMatchFlag(Operation *op);

static bool setRevertFlag(Operation *op);

/// Validate the operations after bits optimization to generate .mlir file.
void validateOp(Operation *op, MLIRContext *ctx,
SmallVector<Operation *> &newMatchedOps);
} // namespace bitwidth
} // namespace dynamatic

#endif // DYNAMATIC_TRANSFORMS_UTILSBITSUPDATE_H
4 changes: 4 additions & 0 deletions lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ add_dynamatic_library(DynamaticTransforms
HandshakeInferBasicBlocks.cpp
NameMemoryOps.cpp
PushConstants.cpp
UtilsBitsUpdate.cpp
OptimizeBits.cpp
InitIndexType.cpp
InitCstWidth.cpp

DEPENDS
DynamaticTransformsPassIncGen
Expand Down
98 changes: 98 additions & 0 deletions lib/Transforms/InitCstWidth.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
//===- InitCstWidth.cpp - Reduce the constant bits width --------*- C++ -*-===//
//
// This file contains the implementation of the init-cstwidth pass.
//
//===----------------------------------------------------------------------===//

#include "dynamatic/Transforms/InitCstWidth.h"
#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "dynamatic/Transforms/PassDetails.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "BITWIDTH"

static LogicalResult initCstOpBitsWidth(handshake::FuncOp funcOp,
MLIRContext *ctx) {
OpBuilder builder(ctx);
SmallVector<handshake::ConstantOp> cstOps;

int savedBits = 0;

for (auto op :
llvm::make_early_inc_range(funcOp.getOps<handshake::ConstantOp>())) {
unsigned cstBitWidth = CPP_MAX_WIDTH;
IntegerType::SignednessSemantics ifSign =
IntegerType::SignednessSemantics::Signless;
// skip the bool value constant operation
if (!isa<mlir::IntegerAttr>(op.getValue()))
continue;

// get the attribute value
if (auto intAttr = dyn_cast<mlir::IntegerAttr>(op.getValue())) {
if (int cstVal = intAttr.getValue().getZExtValue(); cstVal > 0)
cstBitWidth = log2(cstVal) + 2;
else if (int cstVal = intAttr.getValue().getZExtValue(); cstVal < 0) {
cstBitWidth = log2(-cstVal) + 2;
} else
cstBitWidth = 2;
}

if (cstBitWidth < op.getResult().getType().getIntOrFloatBitWidth()) {
// Get the new type of calculated bitwidth
Type newType = getNewType(op.getResult(), cstBitWidth, ifSign);

// Update the constant operator for both ValueAttr and result Type
builder.setInsertionPointAfter(op);
handshake::ConstantOp newCstOp = builder.create<handshake::ConstantOp>(
op.getLoc(), newType, op.getValue(), op.getCtrl());

// Determine the proper representation of the constant value
int intVal = op.getValue().cast<IntegerAttr>().getInt();
intVal =
((1 << op.getValue().getType().getIntOrFloatBitWidth()) - 1 + intVal);
newCstOp.setValueAttr(IntegerAttr::get(newType, intVal));
// save the original bb
newCstOp->setAttr("bb", op->getAttr("bb"));

// recursively replace the uses of the old constant operation with the new
// one Value opVal = op.getResult();
savedBits +=
op.getResult().getType().getIntOrFloatBitWidth() - cstBitWidth;
auto extOp = builder.create<mlir::arith::ExtSIOp>(
newCstOp.getLoc(), op.getResult().getType(), newCstOp.getResult());

// replace the constant operation (default width)
// with the extension of new constant operation (optimized width)
op->replaceAllUsesWith(extOp);
}
}

LLVM_DEBUG(llvm::dbgs() << "Number of saved bits is " << savedBits << "\n");

return success();
}

struct HandshakeInitCstWidthPass
: public HandshakeInitCstWidthBase<HandshakeInitCstWidthPass> {

void runOnOperation() override {
auto *ctx = &getContext();

ModuleOp m = getOperation();
for (auto funcOp : m.getOps<handshake::FuncOp>())
if (failed(initCstOpBitsWidth(funcOp, ctx)))
return signalPassFailure();
};
};

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
dynamatic::createInitCstWidthPass() {
return std::make_unique<HandshakeInitCstWidthPass>();
}
Loading

0 comments on commit abd3f1a

Please sign in to comment.