Skip to content

Commit

Permalink
ci fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IanWood1 committed May 14, 2024
1 parent 35163b4 commit 4f1e782
Showing 1 changed file with 47 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand Down Expand Up @@ -149,58 +150,58 @@ struct FoldSuccessiveTensorInsertSliceOps
// cannot be fused because it there is no producer-consumer
// relationship between the two generics. This is because the indexing
// is not affine (index values come from a tensor).
struct GatherFusionPattern : public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp consumerOp,
struct GatherFusionPattern : public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
auto extractOps = consumerOp.getOps<tensor::ExtractOp>();
if (extractOps.empty()) {
return failure();
// Check if extractOp is inside a generic op
auto consumerOp =
dyn_cast_or_null<linalg::GenericOp>(extractOp->getParentOp());
if (!consumerOp) {
return rewriter.notifyMatchFailure(
extractOp, "expected extract op to be inside a generic op");
}

for (tensor::ExtractOp extractOp : extractOps) {
auto producerOp = dyn_cast_or_null<linalg::GenericOp>(
extractOp.getOperand(0).getDefiningOp());
if (!producerOp) {
return rewriter.notifyMatchFailure(
consumerOp, "expected extract operand to be a generic op");
}

// Check if the producerOp is fusible
if (producerOp.getNumDpsInputs() != 1 || !isElementwise(producerOp) ||
!isDequantizationLikeOp(producerOp)) {
return rewriter.notifyMatchFailure(producerOp,
"producer op is not fusible");
}
auto producerOp = dyn_cast_or_null<linalg::GenericOp>(
extractOp.getOperand(0).getDefiningOp());
if (!producerOp) {
return rewriter.notifyMatchFailure(
consumerOp, "expected extract operand to be a generic op");
}

// fuse by performing the dequantization-like operation after
// tensor.extract
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(extractOp);

// Create a new extract op that extracts from the original tensor
// (after the original extract). Clone the producerOp's body into the
// consumerOp, inline the cloned block (erases the block) after the new
// extract, and clean up.
auto newExtractOp = rewriter.create<tensor::ExtractOp>(
extractOp->getLoc(), producerOp->getOperand(0),
extractOp.getIndices());
rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(),
consumerOp.getRegion().begin());
Block &clonedBlock = consumerOp.getRegion().front();
auto producerTermOp = clonedBlock.getTerminator();

rewriter.inlineBlockBefore(
&clonedBlock, extractOp->getNextNode(),
{newExtractOp.getResult(), newExtractOp->getResult(0)});

// Replace the the all references to the original extract result with the
// result from the inlined producerOp.
extractOp->getResult(0).replaceAllUsesWith(producerTermOp->getOperand(0));
rewriter.eraseOp(producerTermOp);
rewriter.eraseOp(extractOp);
// Check if the producerOp is fusible
if (producerOp.getNumDpsInputs() != 1 || !isElementwise(producerOp) ||
!isDequantizationLikeOp(producerOp)) {
return rewriter.notifyMatchFailure(producerOp,
"producer op is not fusible");
}

// fuse by performing the dequantization-like operation after
// tensor.extract
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(extractOp);

// Create a new extract op that extracts from the original tensor
// (after the original extract). Clone the producerOp's body into the
// consumerOp, inline the cloned block (erases the block) after the new
// extract, and clean up.
auto newExtractOp = rewriter.create<tensor::ExtractOp>(
extractOp->getLoc(), producerOp->getOperand(0), extractOp.getIndices());
rewriter.cloneRegionBefore(producerOp.getRegion(), consumerOp.getRegion(),
consumerOp.getRegion().begin());
Block &clonedBlock = consumerOp.getRegion().front();
auto producerTermOp = clonedBlock.getTerminator();

rewriter.inlineBlockBefore(
&clonedBlock, extractOp->getNextNode(),
{newExtractOp.getResult(), newExtractOp->getResult(0)});

// Replace the the all references to the original extract result with the
// result from the inlined producerOp.
extractOp->getResult(0).replaceAllUsesWith(producerTermOp->getOperand(0));
rewriter.eraseOp(producerTermOp);
rewriter.eraseOp(extractOp);

return success();
}
};
Expand Down

0 comments on commit 4f1e782

Please sign in to comment.