[mlir][Interfaces] DestinationStyleOpInterface: Rename `hasTensor/BufferSemantics` (#77574)

Rename interface functions as follows:
* `hasTensorSemantics` -> `hasPureTensorSemantics`
* `hasBufferSemantics` -> `hasPureBufferSemantics`

These two functions return "true" if the op has tensor (respectively, buffer)
operands but no buffer (respectively, tensor) operands.

Also drop the "ranked" part from the interface, i.e., do not distinguish
between ranked/unranked types.

The new function names describe the functions more accurately. They also
align their semantics with the notion of "tensor semantics" in the
bufferization framework. (An op is supposed to be bufferized if it has
tensor operands; whether it also has memref operands does not matter.)

This change is in preparation for #75273, which adds
`BufferizableOpInterface::hasTensorSemantics`. By renaming the functions
in the `DestinationStyleOpInterface`, we can avoid name clashes between
the two interfaces.
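For illustration, here is a minimal sketch (not part of this commit) of what a typical call-site guard looks like after the rename. `MyPattern` is a hypothetical rewrite pattern; only the renamed interface query is taken from this patch.

// Hypothetical pattern illustrating the renamed query. Before this commit the
// guard would have been written as `op.hasTensorSemantics()`.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

struct MyPattern : OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp op,
                                PatternRewriter &rewriter) const override {
    // "Pure" tensor semantics: at least one tensor operand and no memref
    // operands. Ops that mix tensors and memrefs are rejected here.
    if (!op.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(op, "expected pure tensor semantics");
    // ... actual rewrite logic omitted in this sketch ...
    return failure();
  }
};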
matthias-springer authored Jan 12, 2024
1 parent 1aacdfe commit 0a8e3dd
Showing 25 changed files with 85 additions and 83 deletions.
50 changes: 25 additions & 25 deletions mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td
@@ -17,24 +17,24 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
as initial tensor values for the results of the operation or the init
buffers to which the results of the op will be written.

Init operands must be ranked tensors or ranked memrefs. Input operands can
have any type. All non-init operands are DPS inputs.
Init operands must be tensors or memrefs. Input operands can have any type.
All non-init operands are DPS inputs.

The init operands of this op are specified by the MutableOperandRange that
the `getDpsInitsMutable` interface methods returns. This implies that the
init operands must be a consecutive range of operands.

If the op has "tensor semantics", then the input operands are either ranked
tensors or other non-tensor/memref types ("scalars"). The init operands are
ranked tensors and every tensor init is tied to a corresponding tensor
OpResult in a 1-to-1 fashion. The i-th init tensor is tied to the i-th
OpResult. The op may not have any additional OpResults. Init operands and
their tied OpResults have the same type. Dynamic dimension sizes also match
at runtime.
Each tensor init operand is tied to a corresponding tensor OpResult in a
1-to-1 fashion. The i-th init tensor is tied to the i-th OpResult. The op
may not have any additional OpResults. Init operands and their tied
OpResults have the same type. Dynamic dimension sizes also match at runtime.

If the op has "buffer semantics", then the input operands are either ranked
memrefs or other non-tensor/memref types ("scalar" types). Furthermore, the
init operands are ranked memrefs and the op has no results.
Note: This implies that a destination style op without any tensor inits must
not have any OpResults.

An op has "pure tensor semantics" if it has at least one tensor operand and
no buffer (memref) operands. It has "pure buffer semantics" if it has at
least one buffer (memref) operand and no tensor operands.

Destination-passing style abstraction makes certain transformations easier.
For example, tiling implementation can extract/insert slices from/into the
@@ -148,7 +148,8 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
/// neither a MemRef nor a tensor value.
bool isScalar(::mlir::OpOperand *opOperand) {
assert(opOperand->getOwner() == $_op && "invalid operand");
return !::llvm::isa<MemRefType, TensorType>(opOperand->get().getType());
return !::llvm::isa<BaseMemRefType, TensorType>(
opOperand->get().getType());
}

/// Return the OpResult that is tied to the given OpOperand.
@@ -169,37 +170,36 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
return $_op.getDpsInitOperand(opResult.getResultNumber());
}

/// Return whether the op has buffer semantics. That is the case if the op
/// has no ranked tensor operands and at least one memref operand.
bool hasBufferSemantics() {
/// Return whether the op has pure buffer semantics. That is the case if the
/// op has no tensor operands and at least one memref operand.
bool hasPureBufferSemantics() {
// No tensors.
auto isTensor = [](Value v){
return ::llvm::isa<::mlir::RankedTensorType>(v.getType());
return ::llvm::isa<::mlir::TensorType>(v.getType());
};
if (::llvm::any_of($_op->getOperands(), isTensor))
return false;
// At least one memref.
auto isMemref = [](Value v){
return ::llvm::isa<::mlir::MemRefType>(v.getType());
return ::llvm::isa<::mlir::BaseMemRefType>(v.getType());
};
return llvm::any_of($_op->getOperands(), isMemref);
}

/// Return whether the op has tensor semantics. That is the case if the op
/// has no memref operands and at least one ranked tensor operand.
bool hasTensorSemantics() {
/// Return whether the op has pure tensor semantics. That is the case if the
/// op has no memref operands and at least one tensor operand.
bool hasPureTensorSemantics() {
// No memrefs.
auto isMemref = [](Value v){
return ::llvm::isa<::mlir::MemRefType>(v.getType());
return ::llvm::isa<::mlir::BaseMemRefType>(v.getType());
};
if (::llvm::any_of($_op->getOperands(), isMemref))
return false;
// At least one tensor.
auto isTensor = [](Value v){
return ::llvm::isa<::mlir::RankedTensorType>(v.getType());
return ::llvm::isa<::mlir::TensorType>(v.getType());
};
return llvm::any_of($_op->getOperands(), isTensor);
}
return llvm::any_of($_op->getOperands(), isTensor); }
}];

let verify = [{ return detail::verifyDestinationStyleOpInterface($_op); }];
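To make the tied-operand rules documented above concrete, a small hypothetical sketch follows. It only uses interface methods that appear in this file; `dpsOp` is assumed to be any op implementing the interface.

// Inspect a DPS op through the interface. Under pure tensor semantics the
// i-th init operand is tied to the i-th OpResult and the types match; under
// pure buffer semantics the op has no results at all.
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include <cassert>

void inspect(mlir::DestinationStyleOpInterface dpsOp) {
  if (dpsOp.hasPureTensorSemantics()) {
    for (mlir::OpResult result : dpsOp->getResults()) {
      mlir::OpOperand *init =
          dpsOp.getDpsInitOperand(result.getResultNumber());
      assert(init->get().getType() == result.getType() &&
             "tied init and result must have the same type");
    }
  }
  if (dpsOp.hasPureBufferSemantics())
    assert(dpsOp->getNumResults() == 0 &&
           "an op with pure buffer semantics has no results");
}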
10 changes: 5 additions & 5 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -550,7 +550,7 @@ struct EraseSelfCopy : OpRewritePattern<CopyOp> {
PatternRewriter &rewriter) const override {
if (copyOp.getInputs() != copyOp.getOutputs())
return rewriter.notifyMatchFailure(copyOp, "not a self copy");
if (copyOp.hasBufferSemantics())
if (copyOp.hasPureBufferSemantics())
rewriter.eraseOp(copyOp);
else
rewriter.replaceOp(copyOp, copyOp.getInputs());
@@ -1112,7 +1112,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
return failure();

// In the buffer case, we need to check exact buffer equality.
if (genericOp.hasBufferSemantics()) {
if (genericOp.hasPureBufferSemantics()) {
if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
genericOp.getDpsInputOperand(0)->get() ==
genericOp.getDpsInitOperand(0)->get()) {
@@ -1123,7 +1123,7 @@ struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
}

// Mixed semantics is not supported yet.
if (!genericOp.hasTensorSemantics())
if (!genericOp.hasPureTensorSemantics())
return failure();

// Get the argument number of the returned values. That is the operand
@@ -2257,7 +2257,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {

LogicalResult matchAndRewrite(LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
if (!linalgOp.hasTensorSemantics())
if (!linalgOp.hasPureTensorSemantics())
return failure();

// Maps must be projected permutations.
@@ -2376,7 +2376,7 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));

SmallVector<Type, 4> resultTypes;
if (hasTensorSemantics())
if (hasPureTensorSemantics())
resultTypes.push_back(tiledOperands[1].getType());
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
@@ -68,7 +68,7 @@ struct BubbleUpExtractSliceOpPattern
"expected single output of linalg op");
}

if (!linalgOp.hasTensorSemantics()) {
if (!linalgOp.hasPureTensorSemantics()) {
return rewriter.notifyMatchFailure(sliceOp,
"expected tensor of linalg op");
}
@@ -32,13 +32,13 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
rewriter.setInsertionPoint(op);

// Nothing to do. This op is already bufferized.
if (op.hasBufferSemantics())
if (op.hasPureBufferSemantics())
return success();

// Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
// basis.
if (!op.hasTensorSemantics())
return op->emitError() << "op does not have tensor semantics";
if (!op.hasPureTensorSemantics())
return op->emitError() << "op does not have pure tensor semantics";

// New input operands for the cloned op.
SmallVector<Value> newInputBuffers;
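As a hypothetical illustration of the guard above: an op whose operands include both tensors and memrefs satisfies neither `hasPureBufferSemantics` nor `hasPureTensorSemantics`, so it falls through to the new error rather than being bufferized. A condensed sketch of the same control flow (helper name invented here):

// Hypothetical helper mirroring the guard above.
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

static mlir::LogicalResult
checkSemanticsForBufferization(mlir::DestinationStyleOpInterface op) {
  if (op.hasPureBufferSemantics())    // no tensor operands: already bufferized
    return mlir::success();
  if (!op.hasPureTensorSemantics())   // mixed tensor/memref operands: reject
    return op->emitError() << "op does not have pure tensor semantics";
  return mlir::success();             // only tensors (and scalars): bufferize
}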
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp
@@ -57,7 +57,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
// Mixed and buffer sematics aren't supported.
if (!genericOp.hasTensorSemantics())
if (!genericOp.hasPureTensorSemantics())
return failure();

// Only support ops generating one output for now.
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -258,7 +258,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
// TODO: this could be generalized to handle `linalg.generic` with buffer
// operands too but requires allocation for intermediates. Punt on this for
// now.
if (!genericOp.hasTensorSemantics()) {
if (!genericOp.hasPureTensorSemantics()) {
return rewriter.notifyMatchFailure(
genericOp, "only operations with tensor semantics are handled");
}
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -83,7 +83,7 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
if (!genericOp.hasPureTensorSemantics())
return failure();
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
return failure();
16 changes: 8 additions & 8 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -105,7 +105,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
// Consumer can have mixed semantics, just check operand itself has tensor
// type. Producer must have full tensor semantics to avoid potential
// aliasing between producer and consumer memrefs.
if (!producer.hasTensorSemantics() ||
if (!producer.hasPureTensorSemantics() ||
!isa<RankedTensorType>(fusedOperand->get().getType()))
return false;

@@ -530,7 +530,7 @@ static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp,
// permutations.
// - The fused tensor is not a scalar.
// - All the loops are parallel loops.
return genericOp.hasTensorSemantics() &&
return genericOp.hasPureTensorSemantics() &&
llvm::all_of(genericOp.getIndexingMaps().getValue(),
[](Attribute attr) {
return cast<AffineMapAttr>(attr)
@@ -1124,7 +1124,7 @@ static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
ArrayRef<ReassociationIndices> reassociation) {
// Some basic checks for this fusion to be valid.
if (!genericOp.hasTensorSemantics() || genericOp.getNumDpsInits() != 1)
if (!genericOp.hasPureTensorSemantics() || genericOp.getNumDpsInits() != 1)
return {};

if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
@@ -1476,7 +1476,7 @@ Operation *createCollapsedOp(LinalgType op,
outputOperands.push_back(newOutput);
// If the op has "buffer semantics", then the init operands are ranked
// memrefs and the op has no results.
if (!op.hasBufferSemantics())
if (!op.hasPureBufferSemantics())
resultTypes.push_back(newOutput.getType());
}

@@ -1521,8 +1521,8 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
}))
return failure();

bool hasBufferSemantics = op.hasBufferSemantics();
if (hasBufferSemantics &&
bool hasPureBufferSemantics = op.hasPureBufferSemantics();
if (hasPureBufferSemantics &&
!llvm::all_of(op->getOperands(), [&](Value operand) -> bool {
MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType());
if (!memRefToCollapse)
@@ -1705,7 +1705,7 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
if (!genericOp.hasPureTensorSemantics())
return failure();
for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
Operation *def = opOperand->get().getDefiningOp();
@@ -1857,7 +1857,7 @@ struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
if (!genericOp.hasPureTensorSemantics())
return failure();
bool fillFound = false;
Block &payload = genericOp.getRegion().front();
@@ -183,7 +183,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults
dedupedOutpts;
// If the op doesn't have tensor semantics or outputs should not be removed,
// keep all the outputs as preserved.
if (!genericOp.hasTensorSemantics() || !removeOutputs) {
if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
origToNewPos[en.index()] = newOutputOperands.size();
newOutputOperands.push_back(en.value().get());
@@ -317,7 +317,7 @@ struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override {

// If the op doesnt have tensor semantics, preserve the outputs as is.
if (!genericOp.hasTensorSemantics())
if (!genericOp.hasPureTensorSemantics())
return failure();

bool hasRemovedCycles = false;
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -59,7 +59,7 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
ValueRange outputs = linalgOp.getDpsInits();
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
SmallVector<utils::IteratorType> iterators = linalgOp.getIteratorTypesArray();
SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
SmallVector<Type> resultTypes = linalgOp.hasPureTensorSemantics()
? TypeRange(ValueRange(outputs))
: TypeRange{};

@@ -35,7 +35,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
if (!genericOp.hasPureTensorSemantics())
return failure();

SmallVector<size_t> scalarOperands;
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -128,7 +128,7 @@ template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
ArrayRef<Value> allIvs,
LinalgOp linalgOp) {
assert(linalgOp.hasBufferSemantics() &&
assert(linalgOp.hasPureBufferSemantics() &&
"expected linalg op with buffer semantics");
SmallVector<Value> indexedValues;
indexedValues.reserve(linalgOp->getNumOperands());
@@ -218,7 +218,7 @@ static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,

// The flattened loopToOperandRangesMaps is expected to be an invertible
// permutation map (which is asserted in the inverse calculation).
assert(linalgOp.hasBufferSemantics() &&
assert(linalgOp.hasPureBufferSemantics() &&
"expected linalg op with buffer semantics");

auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
@@ -264,7 +264,7 @@ class LinalgRewritePattern : public RewritePattern {
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!isa<LinalgOp>(op) || !linalgOp.hasBufferSemantics()) {
if (!isa<LinalgOp>(op) || !linalgOp.hasPureBufferSemantics()) {
return rewriter.notifyMatchFailure(
op, "expected linalg op with buffer semantics");
}
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp
@@ -39,7 +39,7 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
Location loc = operation->getLoc();
auto linalgOp = dyn_cast<LinalgOp>(operation);
// Exit out on the memref version of this operation.
if (!linalgOp || !linalgOp.hasTensorSemantics())
if (!linalgOp || !linalgOp.hasPureTensorSemantics())
return failure();

auto result = operation->getResult(0);
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -168,7 +168,7 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
}

// TODO: there are cases where we may still want to pad to larger sizes.
if (!opToPad.hasTensorSemantics())
if (!opToPad.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(opToPad,
"expected operation on tensors");

@@ -265,7 +265,7 @@ mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
assert(options.copyBackOp == LinalgPaddingOptions::CopyBackOp::None &&
"invalid options");

if (!linalgOp.hasTensorSemantics())
if (!linalgOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
linalgOp, "only applies to Linalg ops with tensor semantics");

8 changes: 5 additions & 3 deletions mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -164,7 +164,8 @@ struct LinalgOpInstancePromotionOptions {
LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
LinalgOp linalgOp, const LinalgPromotionOptions &options)
: subViews(), alignment(options.alignment) {
assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
assert(linalgOp.hasPureBufferSemantics() &&
"revisit usage of shaped operand");
auto vUseFullTileBuffers =
options.useFullTileBuffers.value_or(llvm::SmallBitVector());
vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
@@ -346,7 +347,8 @@ promoteSubViews(ImplicitLocOpBuilder &b,
static FailureOr<LinalgOp>
promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
LinalgOpInstancePromotionOptions options, DataLayout &layout) {
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
assert(op.hasPureBufferSemantics() &&
"expected linalg op with buffer semantics");

// 1. Promote the specified views and use them in the new op.
auto promotedBuffersAndViews = promoteSubViews(b, options, layout);
@@ -400,7 +402,7 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
LinalgPromotionOptions options) {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
// Transformation applies to buffers only.
if (!linalgOp || !linalgOp.hasBufferSemantics())
if (!linalgOp || !linalgOp.hasPureBufferSemantics())
return failure();
// Check that at least one of the requested operands is indeed a subview.
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -212,7 +212,7 @@ struct LinalgOpTilingInterface
Location loc,
ValueRange ivs) const {
auto linalgOp = cast<LinalgOp>(op);
if (!linalgOp.hasBufferSemantics())
if (!linalgOp.hasPureBufferSemantics())
return op->emitOpError("expected operation to have buffer semantics");

SmallVector<Value> indexedValues;
@@ -256,7 +256,7 @@ struct LinalgOpPartialReductionInterface
auto linalgOp = cast<LinalgOp>(op);
OpBuilder::InsertionGuard guard(b);

if (linalgOp.hasBufferSemantics())
if (linalgOp.hasPureBufferSemantics())
return op->emitOpError("expected operation to have tensor semantics");
// Insert the new parallel dimension based on the index of the reduction
// loops. This could be controlled by user for more flexibility.