Skip to content
This repository has been archived by the owner on Jan 30, 2025. It is now read-only.

Move bottom up fuser declaration to header file #23

Merged
merged 6 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ td_library(
"include/mlir-tcp/Dialect/IR/TcpEnums.td",
"include/mlir-tcp/Dialect/IR/TcpOps.td",
"include/mlir-tcp/Dialect/IR/TcpTypes.td",
"include/mlir-tcp/Dialect/IR/TcpOpsCruiseInternal.td",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unintentional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this. Fixed

],
includes = ["include"],
deps = [
Expand Down Expand Up @@ -134,13 +135,15 @@ cc_library(
name = "TcpDialectPasses",
srcs = [
"lib/Dialect/Transforms/FuseTcpOpsPass.cpp",
"lib/Dialect/Transforms/FusionPatterns.cpp",
"lib/Dialect/Transforms/IsolateGroupOpsPass.cpp",
"lib/Dialect/Transforms/PassDetail.h",
"lib/Dialect/Transforms/Passes.cpp",
"lib/Dialect/Transforms/VerifyTcpBackendContractPass.cpp",
],
hdrs = [
"include/mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h",
"include/mlir-tcp/Dialect/Transforms/FusionPatterns.h",
"include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h",
"include/mlir-tcp/Dialect/Transforms/Passes.h",
"include/mlir-tcp/Dialect/Transforms/VerifyTcpBackendContractPass.h",
Expand Down Expand Up @@ -192,12 +195,15 @@ cc_library(
"lib/Conversion/TorchToTcp/PopulatePatterns.h",
"lib/Conversion/TorchToTcp/TcpCustomOp.cpp",
"lib/Conversion/TorchToTcp/TorchToTcp.cpp",
"lib/Conversion/TorchToTcp/TorchToTcpCruiseInternal.cpp",
"lib/Conversion/TorchToTcp/CruiseInternalPatterns.cpp",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please revert

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

"lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp",
"lib/Conversion/TorchToTcp/Utils.cpp",
"lib/Conversion/TorchToTcp/Utils.h",
],
hdrs = [
"include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h",
"include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCruiseInternal.h",
"include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h",
],
strip_include_prefix = "include",
Expand Down
33 changes: 33 additions & 0 deletions include/mlir-tcp/Dialect/Transforms/FusionPatterns.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#pragma once

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::tcp {

class GenericBottomUpFuser : public RewritePattern {
public:
using CanFuseFuncType = std::function<bool(Operation *, Operation *)>;
using PostProcessingFuncType = std::function<void(Operation *, PatternRewriter &rewriter)>;

GenericBottomUpFuser(MLIRContext *context, CanFuseFuncType canFuseCallback, PostProcessingFuncType postFuncCallback=nullptr)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
canFuse(canFuseCallback),
postFunc(postFuncCallback) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override;

private:
CanFuseFuncType canFuse;
PostProcessingFuncType postFunc;
};
} // namespace mlir::tcp
79 changes: 1 addition & 78 deletions lib/Dialect/Transforms/FuseTcpOpsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,95 +8,18 @@
//===----------------------------------------------------------------------===//

#include "mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h"
#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"
#include "mlir-tcp/Dialect/Transforms/FusionPatterns.h"
#include "mlir-tcp/Dialect/Transforms/Passes.h"

#include "./PassDetail.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace mlir::tcp {

namespace {

class GenericBottomUpFuser : public RewritePattern {
public:
using CanFuseFuncType = std::function<bool(Operation *, Operation *)>;

GenericBottomUpFuser(MLIRContext *context, CanFuseFuncType cf)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
canFuse(cf) {}

LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Operation *use = op;
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp()) {
Operation *def = operand.getDefiningOp();
if (canFuse(def, use)) {

// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(def->getParentOp())) {
continue;
}

// We only support fusing def ops that have exactly one use, for now.
if (!def->hasOneUse()) {
continue;
}

// Fuse the def and use ops into a group.

// * If both the ops have the same parent region, they must be part
// of the top-level func. So, we need to create a new group.
// * The only other case is when the def op is part of the top-level
// func and the use is already inside a group.
if (def->getParentRegion() == use->getParentRegion()) {
auto groupOp =
rewriter.create<GroupOp>(use->getLoc(), use->getResultTypes());
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);
for (unsigned num = 0; num < use->getNumResults(); ++num) {
rewriter.replaceAllUsesWith(use->getResult(num),
groupOp->getResult(num));
}
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp =
rewriter.create<YieldOp>(use->getLoc(), use->getResults());
use->moveBefore(yieldOp);
operand.getDefiningOp()->moveBefore(use);
}
} else if (auto groupOp = dyn_cast<GroupOp>(use->getParentOp())) {
def->moveBefore(use);
} else {
llvm_unreachable("Unhandled case during fusion");
}
}
}
}
return success();
}

private:
CanFuseFuncType canFuse;
};

class TcpFuseElementwiseOpsPass
: public TcpFuseElementwiseOpsBase<TcpFuseElementwiseOpsPass> {
void runOnOperation() override {
Expand Down
80 changes: 80 additions & 0 deletions lib/Dialect/Transforms/FusionPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir-tcp/Dialect/Transforms/FusionPatterns.h"
#include "mlir-tcp/Dialect/IR/TcpDialect.h"
#include "mlir-tcp/Dialect/IR/TcpOps.h"

namespace mlir::tcp {
LogicalResult GenericBottomUpFuser::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
Operation *use = op;
bool isChanged = false;
for (auto operand : op->getOperands()) {
if (operand.getDefiningOp()) {
Operation *def = operand.getDefiningOp();
if (canFuse(def, use)) {

// Currently we are only fusing ops at the top-level.
// This is to avoid recursing inside a group and ending up with
// nested groups that contain the same ops.
// Since we are iterating bottom up in a block, we only need to check
// if the def op has a func parent.
//
// TODO: Remove this restriction to allow fusing in nested regions.
if (!isa<func::FuncOp>(def->getParentOp())) {
continue;
}

// We only support fusing def ops that have exactly one use, for now.
if (!def->hasOneUse()) {
continue;
}

// Fuse the def and use ops into a group.

// * If both the ops have the same parent region, they must be part
// of the top-level func. So, we need to create a new group.
// * The only other case is when the def op is part of the top-level
// func and the use is already inside a group.
isChanged = true;
if (def->getParentRegion() == use->getParentRegion()) {
auto groupOp =
rewriter.create<GroupOp>(use->getLoc(), use->getResultTypes());
if(postFunc) {
postFunc(groupOp, rewriter);
}
Block *groupBlock = new Block();
groupOp.getBody().push_back(groupBlock);
for (unsigned num = 0; num < use->getNumResults(); ++num) {
rewriter.replaceAllUsesWith(use->getResult(num),
groupOp->getResult(num));
}
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(groupBlock);
auto yieldOp =
rewriter.create<YieldOp>(use->getLoc(), use->getResults());
use->moveBefore(yieldOp);
operand.getDefiningOp()->moveBefore(use);
}
} else if (auto groupOp = dyn_cast<GroupOp>(use->getParentOp())) {
def->moveBefore(use);
} else {
llvm_unreachable("Unhandled case during fusion");
}
}
}
}
return isChanged ? success() : failure();
}
} // namespace mlir::tcp
Loading