diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp index 4ff656c18282..e8dbaae15f5d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/ResourceUsage.cpp @@ -394,6 +394,12 @@ class ValueResourceUsage : public AbstractResourceUsage { DFX::Resolution::REQUIRED); getState() ^= targetUsage.getState(); }) + .Case([&](IREE::Stream::AsyncBarrierOp op) { + auto &tiedUsage = solver.getElementFor( + *this, Position::forValue(op.getOperand(0)), + DFX::Resolution::REQUIRED); + getState() ^= tiedUsage.getState(); + }) .Case([&](IREE::Stream::AsyncTransferOp op) { removeAssumedBits(NOT_TRANSFER_WRITE); auto &sourceUsage = solver.getElementFor( @@ -716,6 +722,12 @@ class ValueResourceUsage : public AbstractResourceUsage { getState() ^= resultUsage.getState(); } }) + .Case([&](IREE::Stream::AsyncBarrierOp op) { + auto &resultUsage = solver.getElementFor( + *this, Position::forValue(op.getResult()), + DFX::Resolution::OPTIONAL); + getState() ^= resultUsage.getState(); + }) .Case([&](IREE::Stream::AsyncTransferOp op) { removeAssumedBits(NOT_TRANSFER_READ); auto &resultUsage = solver.getElementFor( diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir index 4fb2216faeec..df9e5480ef90 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/FlowToStream/test/tensor_ops.mlir @@ -138,6 +138,26 @@ util.func public @tensorSplat(%value: i8, %dim0: index) -> tensor { util.global private @device : !hal.device +// CHECK-LABEL: @tensorBarrierDispatch +// CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index) +util.func public @tensorBarrierDispatch(%input: tensor, %dim0: index) -> tensor { + %c0 = arith.constant 0 : index + %barrier = flow.tensor.barrier %input : tensor{%dim0} on #hal.device.affinity<@device> + %0 = flow.dispatch @ex::@entry[%c0](%barrier) : (tensor{%dim0}) -> tensor{%dim0} + + // CHECK: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[BARRIER:.+]] = stream.async.barrier %[[INPUT]] : !stream.resource<*>{%[[DIM0]]} -> !stream.resource<*> + // CHECK: %[[C0_2:.+]] = arith.constant 0 : index + // CHECK: %[[SIZE:.+]] = stream.tensor.sizeof on(#hal.device.affinity<@device>) tensor{%arg2} : index + // CHECK: %[[DISP:.+]] = stream.async.dispatch on(#hal.device.affinity<@device>) @ex::@entry[%[[C0]]](%[[BARRIER]][%[[C0_2]] to %[[DIM0]] for %[[DIM0]]]) + // CHECK: util.return %[[DISP]], %[[SIZE]] + util.return %0 : tensor +} + +// ----- + +util.global private @device : !hal.device + // CHECK-LABEL: @tensorTransfer // CHECK-SAME: (%[[INPUT:.+]]: !stream.resource<*>, %[[INPUT_SIZE:.+]]: index, %[[DIM0:.+]]: index) util.func public @tensorTransfer(%input: tensor, %dim0: index) -> tensor { diff --git a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp index 45122452d64b..051b6a151a35 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Conversion/PatternUtils.cpp @@ -64,7 +64,9 @@ ConvertedTensor transferTensorOperands( Value resource = convertedOperand[0]; Value resourceSize = convertedOperand[1]; auto affinityAttr = affinityAnalysis->lookupResourceAffinity(originalOperand); - if (affinityAttr != requiredAffinityAttr) { + bool isBarrier = resource.getDefiningOp() && + isa(resource.getDefiningOp()); + if (affinityAttr != requiredAffinityAttr && !isBarrier) { resource = builder.create( loc, resource.getType(), resource, resourceSize, resourceSize, affinityAttr, requiredAffinityAttr);