Skip to content

Commit

Permalink
[export-dot] Use timing modeling in DOT legacy mode (#34)
Browse files Browse the repository at this point in the history
This commit makes the `export-dot` tool leverage the new timing modeling
infrastructure (the same as for buffer placement) when producing DOTs
compatible with legacy Dynamatic. Instead of component delays and
latencies being (probably incorrectly) hardcoded inside the DOTPrinter,
they are now fetched from a timing database whose path is given as an
argument to the tool.

Indirectly, this allows us to make meaningful comparisons between the
output of the new buffer placement pass and the legacy one, which now
both get their component timing models fetched from the same source.

Fixes #18.
  • Loading branch information
Lucas Ramirez authored Sep 7, 2023
1 parent b71d83c commit b8a982d
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 112 deletions.
33 changes: 26 additions & 7 deletions include/dynamatic/Support/DOTPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
//
//===----------------------------------------------------------------------===//

#ifndef DYNAMATIC_SUPPORT_DOTPRINTER_H
#define DYNAMATIC_SUPPORT_DOTPRINTER_H

#include "circt/Dialect/Handshake/HandshakeOps.h"
#include "dynamatic/Support/LLVM.h"
#include "dynamatic/Support/TimingModels.h"
#include "mlir/Support/IndentedOstream.h"
#include <map>
#include <set>
Expand All @@ -23,24 +27,27 @@ struct EdgeInfo;

/// Implements the logic to convert Handshake-level IR to a DOT. The only public
/// method of this class, printDOT, converts an MLIR module containing a single
/// Handshake function into an equivalent DOT graph. In legacy mode, the
/// resulting DOT can be used with legacy Dynamatic.
/// Handshake function into an equivalent DOT graph printed on stdout. In legacy
/// mode, the resulting DOT can be used with legacy Dynamatic.
class DOTPrinter {
public:
/// Constructs a DOTPrinter whose printing behavior is controlled by a couple
/// flags.
DOTPrinter(bool legacy, bool debug);
/// flags, plus a pointer to a timing database that must be valid in legacy
/// mode (when building Dynamatic++ in debug mode, the constructor will assert
/// if the `legacy` flag is true and the timing database is nullptr).
DOTPrinter(bool legacy, bool debug, TimingDatabase *timingDB = nullptr);

/// Prints Handshake-level IR to standard output.
LogicalResult printDOT(mlir::ModuleOp mod);

private:
/// Whether to export a legacy-compatible DOT.
bool legacy;

/// Whether to pretty-print the exported DOT (pretty-print if false).
bool debug;

/// Timing models for dataflow components (required in legacy mode, can safely
/// be nullptr when not in legacy mode).
TimingDatabase *timingDB;
/// The stream to output to.
mlir::raw_indented_ostream os;

Expand Down Expand Up @@ -82,6 +89,16 @@ class DOTPrinter {
/// Returns the name of the node representing the operation.
std::string getNodeName(Operation *op);

/// Returns the content of the "delay" attribute associated to every graph
/// node in legacy mode. Requires that `timingDB` points to a valid memory
/// location.
std::string getNodeDelayAttr(Operation *op);

/// Returns the content of the "latency" attribute associated to every graph
/// node in legacy mode. Requires that `timingDB` points to a valid memory
/// location.
std::string getNodeLatencyAttr(Operation *op);

/// Computes all data attributes of an operation for use in legacy Dynamatic
/// and prints them to the output stream; it is the responsibility of the
/// caller of this method to insert an opening bracket before the call and a
Expand Down Expand Up @@ -175,4 +192,6 @@ struct EdgeInfo {
void print(mlir::raw_indented_ostream &os);
};

} // namespace dynamatic
} // namespace dynamatic

#endif // DYNAMATIC_SUPPORT_DOTPRINTER_H
142 changes: 38 additions & 104 deletions lib/Support/DOTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include <iomanip>
#include <string>

using namespace circt;
using namespace circt::handshake;
Expand Down Expand Up @@ -53,50 +54,6 @@ static std::unordered_map<std::string, std::string> arithNameToOpName{
{"arith.trunci", "trunc_op"}, {"arith.shrsi", "ashr_op"},
{"arith.shli", "shl_op"}, {"arith.select", "select_op"}};

/// Delay information for arith.addi and arith.subi operations.
static const std::string DELAY_ADD_SUB =
"2.287 1.397 1.400 1.409 100.000 100.000 100.000 100.000";
/// Delay information for arith.muli and arith.divi operations.
static const std::string DELAY_MUL_DIV =
"2.287 1.397 1.400 1.409 100.000 100.000 100.000 100.000";
/// Delay information for arith.subf and arith.mulf operations.
static const std::string DELAY_SUBF_MULF =
"0.000 0.000 1.400 1.411 100.000 100.000 100.000 100.000";
/// Delay information for arith.andi, arith.ori, and arith.xori operations.
static const std::string DELAY_LOGIC_OP =
"1.397 1.397 1.400 1.409 100.000 100.000 100.000 100.000";
/// Delay information for arith.sitofp and arith.remsi operations.
static const std::string DELAY_SITOFP_REMSI =
"1.412 1.397 0.000 1.412 1.397 1.412 100.000 100.000";
/// Delay information for extension and truncation operations.
static const std::string DELAY_EXT_TRUNC =
"0.672 0.672 1.397 1.397 100.000 100.000 100.000 100.000";

/// Maps name of arithmetic operation to "delay" attribute.
static std::unordered_map<std::string, std::string> arithNameToDelay{
{"arith.subi", DELAY_ADD_SUB},
{"arith.addi", DELAY_ADD_SUB},
{"arith.muli", DELAY_MUL_DIV},
{"arith.addf", "0.000,0.000,0.000,100.000,100.000,100.000,100.000,100.000"},
{"arith.subf", DELAY_SUBF_MULF},
{"arith.mulf", DELAY_SUBF_MULF},
{"arith.divui", DELAY_MUL_DIV},
{"arith.divsi", DELAY_MUL_DIV},
{"arith.divf", "0.000 0.000 1.400 100.000 100.000 100.000 100.000 100.000"},
{"arith.andi", DELAY_LOGIC_OP},
{"arith.ori", DELAY_LOGIC_OP},
{"arith.xori", DELAY_LOGIC_OP},
{"arith.sitofp", DELAY_SITOFP_REMSI},
{"arith.remsi", DELAY_SITOFP_REMSI},
{"arith.sext", DELAY_EXT_TRUNC},
{"arith.extsi", DELAY_EXT_TRUNC},
{"arith.extui", DELAY_EXT_TRUNC},
{"arith.trunci", DELAY_EXT_TRUNC},
{"arith.shrsi", DELAY_EXT_TRUNC},
{"arith.shli", DELAY_EXT_TRUNC},
{"arith.select",
"1.397 1.397 1.412 2.061 100.000 100.000 100.000 100.000"}};

/// Maps name of integer comparison type to "op" attribute.
static std::unordered_map<arith::CmpIPredicate, std::string> cmpINameToOpName{
{arith::CmpIPredicate::eq, "icmp_eq_op"},
Expand Down Expand Up @@ -606,33 +563,22 @@ LogicalResult DOTPrinter::verifyDOT(handshake::FuncOp funcOp,
LogicalResult DOTPrinter::annotateNode(Operation *op) {
auto info =
llvm::TypeSwitch<Operation *, NodeInfo>(op)
.Case<handshake::MergeOp>([&](auto) {
auto info = NodeInfo("Merge");
info.stringAttr["delay"] =
"1.397 1.412 0.000 100.000 100.000 100.000 100.000 100.000";
return info;
})
.Case<handshake::MergeOp>([&](auto) { return NodeInfo("Merge"); })
.Case<handshake::MuxOp>([&](handshake::MuxOp op) {
auto info = NodeInfo("Mux");
info.stringAttr["in"] = getInputForMux(op);
info.stringAttr["delay"] =
"1.412 1.397 0.000 1.412 1.397 1.412 100.000 100.000";
return info;
})
.Case<handshake::ControlMergeOp>([&](handshake::ControlMergeOp op) {
auto info = NodeInfo("CntrlMerge");
info.stringAttr["out"] = getOutputForControlMerge(op);
info.stringAttr["delay"] =
"0.000 1.397 0.000 100.000 100.000 100.000 100.000 100.000";
return info;
})
.Case<handshake::ConditionalBranchOp>(
[&](handshake::ConditionalBranchOp op) {
auto info = NodeInfo("Branch");
info.stringAttr["in"] = getInputForCondBranch(op);
info.stringAttr["out"] = getOutputForCondBranch(op);
info.stringAttr["delay"] =
"0.000 1.409 1.411 1.412 1.400 1.412 100.000 100.000";
return info;
})
.Case<handshake::BufferOp>([&](handshake::BufferOp bufOp) {
Expand Down Expand Up @@ -669,9 +615,6 @@ LogicalResult DOTPrinter::annotateNode(Operation *op) {
info.stringAttr["in"] = getInputForLoadOp(op);
info.stringAttr["out"] = getOutputForLoadOp(op);
info.intAttr["portId"] = findMemoryPort(op.getAddressResult());
info.intAttr["latency"] = 2;
info.stringAttr["delay"] =
"1.412 1.409 0.000 100.000 100.000 100.000 100.000 100.000";
return info;
})
.Case<handshake::DynamaticStoreOp>(
Expand All @@ -681,16 +624,9 @@ LogicalResult DOTPrinter::annotateNode(Operation *op) {
info.stringAttr["in"] = getInputForStoreOp(op);
info.stringAttr["out"] = getOutputForStoreOp(op);
info.intAttr["portId"] = findMemoryPort(op.getAddressResult());
info.stringAttr["delay"] =
"0.672 1.397 1.400 1.409 100.000 100.000 100.000 100.000";
return info;
})
.Case<handshake::ForkOp>([&](auto) {
auto info = NodeInfo("Fork");
info.stringAttr["delay"] =
"0.000 0.100 0.100 100.000 100.000 100.000 100.000 100.000";
return info;
})
.Case<handshake::ForkOp>([&](auto) { return NodeInfo("Fork"); })
.Case<handshake::SourceOp>([&](auto) {
auto info = NodeInfo("Source");
info.stringAttr["out"] = getIOFromValues(op->getResults(), "out");
Expand Down Expand Up @@ -729,15 +665,11 @@ LogicalResult DOTPrinter::annotateNode(Operation *op) {
info.stringAttr["out"] = getIOFromValues(op->getResults(), "out");

info.stringAttr["value"] = stream.str();
info.stringAttr["delay"] =
"0.000 0.000 0.000 100.000 100.000 100.000 100.000 100.000";
return info;
})
.Case<handshake::DynamaticReturnOp>([&](auto) {
auto info = NodeInfo("Operator");
info.stringAttr["op"] = "ret_op";
info.stringAttr["delay"] =
"1.412 1.409 0.000 100.000 100.000 100.000 100.000 100.000";
return info;
})
.Case<handshake::EndOp>([&](handshake::EndOp op) {
Expand All @@ -751,16 +683,12 @@ LogicalResult DOTPrinter::annotateNode(Operation *op) {
for (auto [idx, res] : llvm::enumerate(funcOp.getResultTypes()))
stream << "out" << (idx + 1) << ":" << getWidth(res);
info.stringAttr["out"] = stream.str();

info.stringAttr["delay"] =
"1.397 0.000 1.397 1.409 100.000 100.000 100.000 100.000";
return info;
})
.Case<arith::SelectOp>([&](arith::SelectOp op) {
auto info = NodeInfo("Operator");
auto opName = op->getName().getStringRef().str();
info.stringAttr["op"] = arithNameToOpName[opName];
info.stringAttr["delay"] = arithNameToDelay[opName];
info.stringAttr["in"] = getInputForSelect(op);
return info;
})
Expand All @@ -772,41 +700,21 @@ LogicalResult DOTPrinter::annotateNode(Operation *op) {
auto info = NodeInfo("Operator");
auto opName = op->getName().getStringRef().str();
info.stringAttr["op"] = arithNameToOpName[opName];
info.stringAttr["delay"] = arithNameToDelay[opName];

// Set non-zero latencies
if (opName == "arith.divui" || opName == "arith.divsi")
info.intAttr["latency"] = 36;
else if (opName == "arith.muli")
info.intAttr["latency"] = 4;
else if (opName == "arith.fadd" || opName == "arith.fsub")
info.intAttr["latency"] = 10;
else if (opName == "arith.divf")
info.intAttr["latency"] = 30;
else if (opName == "arith.mulf")
info.intAttr["latency"] = 6;

return info;
})
.Case<arith::CmpIOp>([&](arith::CmpIOp op) {
auto info = NodeInfo("Operator");
info.stringAttr["op"] = cmpINameToOpName[op.getPredicate()];
info.stringAttr["delay"] =
"1.907 1.397 1.400 1.409 100.000 100.000 100.000 100.000";
return info;
})
.Case<arith::CmpFOp>([&](arith::CmpFOp op) {
auto info = NodeInfo("Operator");
info.stringAttr["op"] = cmpFNameToOpName[op.getPredicate()];
info.intAttr["latency"] = 2;
info.stringAttr["latency"] =
"1.895 1.397 1.406 1.411 100.000 100.000 100.000 100.000";
return info;
})
.Case<arith::IndexCastOp>([&](auto) {
auto info = NodeInfo("Operator");
info.stringAttr["op"] = "zext_op";
info.stringAttr["delay"] = DELAY_EXT_TRUNC;
return info;
})
.Default([&](auto) { return NodeInfo(""); });
Expand Down Expand Up @@ -834,9 +742,9 @@ LogicalResult DOTPrinter::annotateNode(Operation *op) {
}

// Add default latency for operators if not specified
if (info.intAttr.find("latency") == info.intAttr.end() &&
info.type == "Operator")
info.intAttr["latency"] = 0;
info.stringAttr["delay"] = getNodeDelayAttr(op);
if (info.type == "Operator")
info.stringAttr["latency"] = getNodeLatencyAttr(op);

// II is 1 for all operators
if (info.type == "Operator")
Expand Down Expand Up @@ -933,6 +841,31 @@ LogicalResult DOTPrinter::annotateArgumentEdge(handshake::FuncOp funcOp,
return success();
}

std::string DOTPrinter::getNodeDelayAttr(Operation *op) {
const TimingModel *model = timingDB->getModel(op);
if (!model)
return "0.000 0.000 0.000 0.000 0.000 0.000 0.000 0.000";

double dataDelay;
if (failed(model->dataDelay.getCeilMetric(op, dataDelay)))
dataDelay = 0.0;

std::stringstream stream;
stream << std::fixed << std::setprecision(3) << dataDelay << " "
<< model->validDelay << " " << model->readyDelay << " "
<< model->validToReady << " " << model->condToValid << " "
<< model->condToReady << " " << model->validToCond << " "
<< model->validToData;
return stream.str();
}

std::string DOTPrinter::getNodeLatencyAttr(Operation *op) {
double latency;
if (failed(timingDB->getLatency(op, latency)))
return "0";
return std::to_string(static_cast<unsigned>(latency));
}

// ============================================================================
// Printing
// ============================================================================
Expand Down Expand Up @@ -973,10 +906,9 @@ static std::string getPrettyPrintedNodeLabel(Operation *op) {
})
.Case<handshake::ControlMergeOp>([&](auto) { return "cmerge"; })
.Case<handshake::ConditionalBranchOp>([&](auto) { return "cbranch"; })
.Case<handshake::BufferOp>([&](auto op) {
std::string n = "buffer ";
n += stringifyEnum(op.getBufferType());
return n;
.Case<handshake::BufferOp>([&](handshake::BufferOp bufOp) {
return stringifyEnum(bufOp.getBufferType()).str() + " [" +
std::to_string(bufOp.getNumSlots()) + "]";
})
.Case<handshake::BranchOp>([&](auto) { return "branch"; })
// handshake operations (dynamatic)
Expand Down Expand Up @@ -1065,8 +997,10 @@ static std::string getPrettyPrintedNodeLabel(Operation *op) {
});
}

DOTPrinter::DOTPrinter(bool legacy, bool debug)
: legacy(legacy), debug(debug), os(llvm::outs()){};
DOTPrinter::DOTPrinter(bool legacy, bool debug, TimingDatabase *timingDB)
: legacy(legacy), debug(debug), timingDB(timingDB), os(llvm::outs()) {
assert(!legacy || timingDB && "timing database must exist in legacy mode");
};

LogicalResult DOTPrinter::printDOT(mlir::ModuleOp mod) {
// We support at most one function per module
Expand Down
28 changes: 27 additions & 1 deletion tools/export-dot/export-dot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "dynamatic/Support/DOTPrinter.h"
#include "dynamatic/Support/TimingModels.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand All @@ -23,13 +24,25 @@

using namespace llvm;
using namespace mlir;
using namespace dynamatic;

static cl::OptionCategory mainCategory("Application options");

static cl::opt<std::string> inputFileName(cl::Positional,
cl::desc("<input file>"),
cl::cat(mainCategory));

static cl::opt<std::string> timingDBFilepath(
"timing-models", cl::Optional,
cl::desc(
"Relative path to JSON-formatted file containing timing "
"models for dataflow components. The tool only tries to "
"read from this file if it is ran in legacy mode, where "
"timing annotations are given to all nodes in the graph. By default, "
"contains the relative path (from the project's top-level directory) "
"to the file defining the default timing models in Dynamatic++."),
cl::init("data/components.json"), cl::cat(mainCategory));

static cl::opt<bool> legacy(
"legacy", cl::Optional,
cl::desc("If true, the exported DOT file will be made compatible with "
Expand Down Expand Up @@ -82,6 +95,19 @@ int main(int argc, char **argv) {
if (!module)
return 1;

dynamatic::DOTPrinter printer(legacy, dotDebug);
if (legacy) {
// In legacy mode, read timing models for dataflow components from a
// JSON-formatted database
TimingDatabase timingDB(&context);
if (failed(TimingDatabase::readFromJSON(timingDBFilepath, timingDB))) {
llvm::errs() << "Failed to read timing database at \"" << timingDBFilepath
<< "\"\n";
return 1;
}
DOTPrinter printer(legacy, dotDebug, &timingDB);
return failed(printer.printDOT(*module));
}

DOTPrinter printer(legacy, dotDebug);
return failed(printer.printDOT(*module));
}

0 comments on commit b8a982d

Please sign in to comment.