This repository has been archived by the owner on Jan 30, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
Move bottom up fuser declaration to header file #23
Merged
Merged
Changes from 1 commit
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ td_library( | |
"include/mlir-tcp/Dialect/IR/TcpEnums.td", | ||
"include/mlir-tcp/Dialect/IR/TcpOps.td", | ||
"include/mlir-tcp/Dialect/IR/TcpTypes.td", | ||
"include/mlir-tcp/Dialect/IR/TcpOpsCruiseInternal.td", | ||
], | ||
includes = ["include"], | ||
deps = [ | ||
|
@@ -134,13 +135,15 @@ cc_library( | |
name = "TcpDialectPasses", | ||
srcs = [ | ||
"lib/Dialect/Transforms/FuseTcpOpsPass.cpp", | ||
"lib/Dialect/Transforms/FusionPatterns.cpp", | ||
"lib/Dialect/Transforms/IsolateGroupOpsPass.cpp", | ||
"lib/Dialect/Transforms/PassDetail.h", | ||
"lib/Dialect/Transforms/Passes.cpp", | ||
"lib/Dialect/Transforms/VerifyTcpBackendContractPass.cpp", | ||
], | ||
hdrs = [ | ||
"include/mlir-tcp/Dialect/Transforms/FuseTcpOpsPass.h", | ||
"include/mlir-tcp/Dialect/Transforms/FusionPatterns.h", | ||
"include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h", | ||
"include/mlir-tcp/Dialect/Transforms/Passes.h", | ||
"include/mlir-tcp/Dialect/Transforms/VerifyTcpBackendContractPass.h", | ||
|
@@ -192,12 +195,15 @@ cc_library( | |
"lib/Conversion/TorchToTcp/PopulatePatterns.h", | ||
"lib/Conversion/TorchToTcp/TcpCustomOp.cpp", | ||
"lib/Conversion/TorchToTcp/TorchToTcp.cpp", | ||
"lib/Conversion/TorchToTcp/TorchToTcpCruiseInternal.cpp", | ||
"lib/Conversion/TorchToTcp/CruiseInternalPatterns.cpp", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please revert There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
"lib/Conversion/TorchToTcp/TorchToTcpCustomOp.cpp", | ||
"lib/Conversion/TorchToTcp/Utils.cpp", | ||
"lib/Conversion/TorchToTcp/Utils.h", | ||
], | ||
hdrs = [ | ||
"include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h", | ||
"include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCruiseInternal.h", | ||
"include/mlir-tcp/Conversion/TorchToTcp/TorchToTcpCustomOp.h", | ||
], | ||
strip_include_prefix = "include", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#pragma once | ||
|
||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::tcp { | ||
|
||
class GenericBottomUpFuser : public RewritePattern { | ||
public: | ||
using CanFuseFuncType = std::function<bool(Operation *, Operation *)>; | ||
using PostProcessingFuncType = std::function<void(Operation *, PatternRewriter &rewriter)>; | ||
|
||
GenericBottomUpFuser(MLIRContext *context, CanFuseFuncType canFuseCallback, PostProcessingFuncType postFuncCallback=nullptr) | ||
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), | ||
canFuse(canFuseCallback), | ||
postFunc(postFuncCallback) {} | ||
|
||
LogicalResult matchAndRewrite(Operation *op, | ||
PatternRewriter &rewriter) const override; | ||
|
||
private: | ||
CanFuseFuncType canFuse; | ||
PostProcessingFuncType postFunc; | ||
}; | ||
} // namespace mlir::tcp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
//===------------------------------------------------------------*- C++ -*-===// | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// Also available under a BSD-style license. See LICENSE. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/OpDefinition.h" | ||
#include "mlir-tcp/Dialect/Transforms/FusionPatterns.h" | ||
#include "mlir-tcp/Dialect/IR/TcpDialect.h" | ||
#include "mlir-tcp/Dialect/IR/TcpOps.h" | ||
|
||
namespace mlir::tcp { | ||
LogicalResult GenericBottomUpFuser::matchAndRewrite(Operation *op, | ||
PatternRewriter &rewriter) const { | ||
Operation *use = op; | ||
bool isChanged = false; | ||
for (auto operand : op->getOperands()) { | ||
if (operand.getDefiningOp()) { | ||
Operation *def = operand.getDefiningOp(); | ||
if (canFuse(def, use)) { | ||
|
||
// Currently we are only fusing ops at the top-level. | ||
// This is to avoid recursing inside a group and ending up with | ||
// nested groups that contain the same ops. | ||
// Since we are iterating bottom up in a block, we only need to check | ||
// if the def op has a func parent. | ||
// | ||
// TODO: Remove this restriction to allow fusing in nested regions. | ||
if (!isa<func::FuncOp>(def->getParentOp())) { | ||
continue; | ||
} | ||
|
||
// We only support fusing def ops that have exactly one use, for now. | ||
if (!def->hasOneUse()) { | ||
continue; | ||
} | ||
|
||
// Fuse the def and use ops into a group. | ||
|
||
// * If both the ops have the same parent region, they must be part | ||
// of the top-level func. So, we need to create a new group. | ||
// * The only other case is when the def op is part of the top-level | ||
// func and the use is already inside a group. | ||
isChanged = true; | ||
if (def->getParentRegion() == use->getParentRegion()) { | ||
auto groupOp = | ||
rewriter.create<GroupOp>(use->getLoc(), use->getResultTypes()); | ||
if(postFunc) { | ||
postFunc(groupOp, rewriter); | ||
} | ||
Block *groupBlock = new Block(); | ||
groupOp.getBody().push_back(groupBlock); | ||
for (unsigned num = 0; num < use->getNumResults(); ++num) { | ||
rewriter.replaceAllUsesWith(use->getResult(num), | ||
groupOp->getResult(num)); | ||
} | ||
{ | ||
OpBuilder::InsertionGuard guard(rewriter); | ||
rewriter.setInsertionPointToStart(groupBlock); | ||
auto yieldOp = | ||
rewriter.create<YieldOp>(use->getLoc(), use->getResults()); | ||
use->moveBefore(yieldOp); | ||
operand.getDefiningOp()->moveBefore(use); | ||
} | ||
} else if (auto groupOp = dyn_cast<GroupOp>(use->getParentOp())) { | ||
def->moveBefore(use); | ||
} else { | ||
llvm_unreachable("Unhandled case during fusion"); | ||
} | ||
} | ||
} | ||
} | ||
return isChanged ? success() : failure(); | ||
} | ||
} // namespace mlir::tcp |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unintentional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for catching this. Fixed