[MLIR] Implement emulation of static indexing subbyte type vector stores #115922
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir
Author: lialan (lialan)
Changes
This patch enables unaligned, statically indexed stores of vectors with sub-byte element types. To ensure atomicity, the boundary bytes at the front and at the back are handled separately with atomic stores, avoiding thread contention, while the bytes (if any) that are free from race conditions use non-atomic stores. To illustrate the mechanism, consider the example of storing vector<7xi2> into memref<3x7xi2>[1, 0]: the linearized bits being overwritten are [14, 28), i.e. the last 2 bits of byte no.2, all of byte no.3, and the first 4 bits of byte no.4.
This patch expands the store into two atomic partial stores (for the first and the third byte) and a single non-atomic full-byte store (for the second byte).
Patch is 21.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115922.diff 2 Files Affected:
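To make the expansion concrete, here is a rough sketch of the IR this produces for the vector<7xi2> example (illustrative only; SSA names such as %alloc, %c1, %c2, %front_mask and %front_slice are made up here, and the exact output is spelled out in the CHECK lines of the test updates below):

// leading partial byte: read-modify-write inside an atomic region
%front = memref.generic_atomic_rmw %alloc[%c1] : memref<6xi8> {
^bb0(%byte: i8):
  %old    = vector.from_elements %byte : vector<1xi8>
  %oldv   = vector.bitcast %old : vector<1xi8> to vector<4xi2>
  // keep the i2 elements not covered by the store, overwrite the rest
  %sel    = arith.select %front_mask, %front_slice, %oldv : vector<4xi1>, vector<4xi2>
  %packed = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
  %res    = vector.extract %packed[0] : i8 from vector<1xi8>
  memref.atomic_yield %res : i8
}
// fully covered middle byte: plain (non-atomic) store
%mid      = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
%mid_byte = vector.bitcast %mid : vector<4xi2> to vector<1xi8>
vector.store %mid_byte, %alloc[%c2] : memref<6xi8>, vector<1xi8>
// trailing partial byte: a second memref.generic_atomic_rmw, analogous to the first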
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7578aadee23a6e..031eb2d9c443b8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -143,19 +143,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
/// emitting `vector.extract_strided_slice`.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
- VectorType extractType, Value source,
- int64_t frontOffset,
+ Value source, int64_t frontOffset,
int64_t subvecSize) {
- auto vectorType = cast<VectorType>(source.getType());
- assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
- "expected 1-D source and destination types");
- (void)vectorType;
+ auto vectorType = llvm::cast<VectorType>(source.getType());
+ assert(vectorType.getRank() == 1 && "expected 1-D source types");
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
auto strides = rewriter.getI64ArrayAttr({1});
+
+ auto resultVectorType =
+ VectorType::get({subvecSize}, vectorType.getElementType());
return rewriter
- .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
- sizes, strides)
+ .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+ offsets, sizes, strides)
->getResult(0);
}
@@ -164,12 +164,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
/// `vector.insert_strided_slice`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
Value src, Value dest, int64_t offset) {
- auto srcType = cast<VectorType>(src.getType());
- auto destType = cast<VectorType>(dest.getType());
+ [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+ [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
"expected source and dest to be vector type");
- (void)srcType;
- (void)destType;
auto offsets = rewriter.getI64ArrayAttr({offset});
auto strides = rewriter.getI64ArrayAttr({1});
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -236,6 +234,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
newLoad);
}
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+ Value memref, Value index, Value value) {
+ auto originType = dyn_cast<VectorType>(value.getType());
+ auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
+ auto scale = memrefElemType.getIntOrFloatBitWidth() /
+ originType.getElementType().getIntOrFloatBitWidth();
+ auto storeType =
+ VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+ auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+ rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// atomically store a subbyte-sized value to memory, with a mask.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+ Value emulatedMemref, Value emulatedIndex,
+ TypedValue<VectorType> value, Value mask,
+ int64_t scale) {
+ auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+ loc, emulatedMemref, ValueRange{emulatedIndex});
+ OpBuilder builder =
+ OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+ Value origValue = atomicOp.getCurrentValue();
+
+ // i8 -> vector type <1xi8> then <1xi8> -> <scale x i.>
+ auto oneVectorType = VectorType::get({1}, origValue.getType());
+ auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+ ValueRange{origValue});
+ auto vectorBitCast =
+ builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+ auto select =
+ builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+ auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+ auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+ builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+ return atomicOp;
+}
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+ Location loc, TypedValue<VectorType> vector,
+ int64_t sliceOffset, int64_t sliceNumElements,
+ int64_t byteOffset) {
+ auto vectorElementType = vector.getType().getElementType();
+ assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+ "vector element must be a valid sub-byte type");
+ auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+ auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+ loc, VectorType::get({scale}, vectorElementType),
+ rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+ auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+ sliceOffset, sliceNumElements);
+ auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+ emptyByteVector, byteOffset);
+ return inserted;
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -251,7 +306,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
- Type oldElementType = op.getValueToStore().getType().getElementType();
+ auto valueToStore = op.getValueToStore();
+ Type oldElementType = valueToStore.getType().getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -275,15 +331,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
// vector<4xi8>
- auto origElements = op.getValueToStore().getType().getNumElements();
- if (origElements % scale != 0)
- return failure();
+ auto origElements = valueToStore.getType().getNumElements();
+ bool isUnalignedEmulation = origElements % scale != 0;
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
+ memref::LinearizedMemRefInfo linearizedInfo;
+ std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
@@ -291,14 +347,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
stridedMetadata.getConstifiedMixedStrides(),
getAsOpFoldResult(adaptor.getIndices()));
- auto numElements = origElements / scale;
- auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements, newElementType),
- op.getValueToStore());
+ auto foldedIntraVectorOffset =
+ isUnalignedEmulation
+ ? getConstantIntValue(linearizedInfo.intraDataOffset)
+ : 0;
+
+ if (!foldedIntraVectorOffset) {
+ // unimplemented case for dynamic front padding size
+ return failure();
+ }
+
+ Value emulatedMemref = adaptor.getBase();
+ // the index into the target memref we are storing to
+ Value currentDestIndex =
+ getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+ auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+ // the index into the source vector we are currently processing
+ auto currentSourceIndex = 0;
+
+ // 1. atomic store for the first byte
+ auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+ if (frontAtomicStoreElem != 0) {
+ auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+ if (*foldedIntraVectorOffset + origElements < scale) {
+ std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+ origElements, true);
+ frontAtomicStoreElem = origElements;
+ } else {
+ std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+ *foldedIntraVectorOffset, true);
+ }
+ auto frontMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+ currentSourceIndex = scale - (*foldedIntraVectorOffset);
+ auto value = extractSliceIntoByte(
+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+ frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+ atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+ scale);
+
+ currentDestIndex = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+ }
+
+ if (currentSourceIndex >= origElements) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ // 2. non-atomic store
+ int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+ int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+ if (nonAtomicStoreSize != 0) {
+ auto nonAtomicStorePart = staticallyExtractSubvector(
+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+ currentSourceIndex, numNonAtomicElements);
+
+ nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ nonAtomicStorePart);
+
+ currentSourceIndex += numNonAtomicElements;
+ currentDestIndex = rewriter.create<arith::AddIOp>(
+ loc, rewriter.getIndexType(), currentDestIndex,
+ rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+ }
+
+ // 3. atomic store for the last byte
+ auto remainingElements = origElements - currentSourceIndex;
+ if (remainingElements != 0) {
+ auto atomicStorePart = extractSliceIntoByte(
+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+ currentSourceIndex, remainingElements, 0);
+
+ // back mask
+ auto maskValues = llvm::SmallVector<bool>(scale, 0);
+ std::fill_n(maskValues.begin(), remainingElements, 1);
+ auto backMask = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+ atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+ cast<TypedValue<VectorType>>(atomicStorePart),
+ backMask.getResult(), scale);
+ }
- rewriter.replaceOpWithNewOp<vector::StoreOp>(
- op, bitCast.getResult(), adaptor.getBase(),
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ rewriter.eraseOp(op);
return success();
}
};
@@ -496,9 +632,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
return success();
@@ -652,9 +787,8 @@ struct ConvertVectorMaskedLoad final
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
@@ -732,9 +866,8 @@ struct ConvertVectorTransferRead final
linearizedInfo.intraDataOffset,
origElements);
} else if (isUnalignedEmulation) {
- result =
- staticallyExtractSubvector(rewriter, loc, op.getType(), result,
- *foldedIntraVectorOffset, origElements);
+ result = staticallyExtractSubvector(
+ rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..3b1f5b2e160fe0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,138 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
// CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
// CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
// CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+ return
+}
+
+// in this example, emit 2 atomic stores, with the first storing 1 element and the second storing 2 elements.
+// CHECK: func @vector_store_i2_const(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// atomic store of the first byte
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// atomic store of the second byte
+// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
+
+// -----
+
+func.func @vector_store_i8_2(%arg0: vector<7xi2>) {
+ %0 = memref.alloc() : memref<3x7xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+ return
+}
+
+// in this example, emit 2 atomic stores and 1 non-atomic store
+
+// CHECK: func @vector_store_i8_2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// first atomic store
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// non atomic store part
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// second atomic store
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+ %0 = memref.alloc() : memref<4x1xi2>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+ return
+}
+
+// in this example, only emit 1 ato...
[truncated]
What would happen if the start of the store is aligned? Would it still generate atomic stores?
I haven't looked into the details yet, but we only need an atomic store if we cannot guarantee that there are no competing stores. We cannot have this information at this level, but a caller might. I think it would be better to allow a caller to indicate that the atomic stores are not needed, with atomic stores as the default.
It will operate as usual; the atomic ops and the extra handling only happen when needed.
It feels like in such a non-competing case we could do strength reduction and use a masked store instead of an atomic one for the unaligned parts (the first and last byte, if unaligned).
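A minimal sketch of that idea for one boundary byte, assuming the caller has guaranteed there are no concurrent writers (names such as %alloc, %idx, %mask and %new are illustrative): since a byte-granular mask cannot select individual sub-byte elements, the non-atomic variant boils down to a plain read-modify-write of the byte rather than a literal masked store:

%old  = vector.load %alloc[%idx] : memref<6xi8>, vector<1xi8>
%oldv = vector.bitcast %old : vector<1xi8> to vector<4xi2>
// overwrite only the elements selected by the mask, keep the rest of the byte
%sel  = arith.select %mask, %new, %oldv : vector<4xi1>, vector<4xi2>
%byte = vector.bitcast %sel : vector<4xi2> to vector<1xi8>
vector.store %byte, %alloc[%idx] : memref<6xi8>, vector<1xi8>

This trades the memref.generic_atomic_rmw region for two plain memory operations, which is only safe when no other thread can touch the same byte concurrently.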
You can't always tell from program analysis that the atomic isn't needed; it depends on the caller's context, for example if the caller knows that, even dynamically, the stores never overlap. It is hard to recover that information at this late stage. So I think it would be good to have an option where atomics are avoided (because they are expensive) and let the caller opt in/out?
Thanks @lialan for the patch. Here is the first round of review comments. I haven't finished reviewing the three steps yet, but the idea looks good to me. I mostly want to clarify my understanding of the atomicStore cases, and I'll review the rest later.
A global comment: can we declare something like using VectorValue = TypedValue<VectorType>; at file scope, so it can be used everywhere in the file?
Hi @lialan, I'd appreciate a bit more high-level explanation.
This patch expands the store into two (maybe atomic) rmw partial stores (for the first and the third byte), and a non-atomic single full byte store (for the second byte).
Could you define what "atomic" and "non-atomic" mean in the context of this patch? Also, it's not clear to me what "new" is being introduced. Pseudo-code comparing "what we have today" vs "what this PR enables" would be appreciated.
Nice work! I left some comments, please take a look.
Getting there 😅
if (origElements % scale != 0)
  return failure();
auto origElements = valueToStore.getType().getNumElements();
bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
Please double check, but I think that this definition of "aligned" is only covering point 1. from
llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Lines 1072 to 1079 in 5a90168
/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
/// means that:
/// 1. The `dstType` element type is a multiple of the
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
/// is not supported). Let this multiple be `N`.
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
/// not supported).
If you agree, I'll try to propose something - we should use consistent definition throughout this file.
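(Concretely, for the stores handled in this patch: e.g. vector<3xi2> stored into an i8-backed memref satisfies condition 1, with N = 4, but not condition 2, since 3 is not a multiple of 4, which is exactly the "unaligned" situation being emulated here.)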
Yes please! That will make subsequent changes in the file easier, and I can do the update.
OK, I think I have something that would help unify things and clarify the terminology a bit. But I'm worried that I discovered an inconsistency/bug, so I need to double-check.
Let me get back to you on Monday, I've run out of cycles for the week :(
- [mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) #123526
- [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N) #123527
- [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N) #123528
- [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N) #123529
You should be able to re-use isSubByteVecFittable from #123529.
There are 4 patches to make reviewing easier, but they are quite short and should be easy to rebase on top 🤞🏻
// Shortcut: conditions when subbyte emulated store at the front is not
// needed:
// 1. The source vector size (in bits) is a multiple of byte size.
// 2. The address of the store is aligned to the emulated width boundary.
//
// For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
if (isAlignedEmulation && *foldedNumFrontPadElems == 0) {
  auto numElements = origElements / numSrcElemsPerDest;
  auto bitCast = rewriter.create<vector::BitCastOp>(
      loc, VectorType::get(numElements, newElementType),
      op.getValueToStore());
  rewriter.replaceOpWithNewOp<vector::StoreOp>(
      op, bitCast.getResult(), memrefBase,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  return success();
}
From what I can tell, this block is not tested. Something like this will exercise it:
// Aligned store, hence full stores.
func.func @vector_store_i2_one_full_store_multiple_bytes(%arg0: vector<32xi2>) {
%alloc = memref.alloc() : memref<8xi8>
%0 = vector.bitcast %arg0 : vector<32xi2> to vector<8xi8>
%c0 = arith.constant 0 : index
vector.store %0, %alloc[%c0] : memref<8xi8>, vector<8xi8>
return
}
Feel free to re-use it. Also, I don't really like expressions like "shortcut". Instead, IMO, this is a special "case". To demonstrate what I have in mind:
bool emulationRequiresPartialStores = !isAlignedEmulation || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
  // Basic case, storing full bytes.
  // Your code here.
}
// Complex case, emulation requires partial stores.
Something along those lines 😅
This block is actually copied from the original code, so it handles the aligned cases. And I feel like there are already test cases for those, for example: https://github.com/lialan/llvm-project/blob/lialan/atomic_stores/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir#L420-L424
Does that sound good enough for test coverage?
It looks really cool! Just a few comments
rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
emptyVector, linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
rewriter, loc, cast<VectorValue>(passthru), emptyVector,
I'm not sure if enforcing all these castings to TypedValue is worth it. Wouldn't checking the type within an assert be more efficient for the Release build, and cleaner? We may also end up checking the type too many times, depending on when the type check happens in TypedValue...
I think you are right, this is a slight optimization we can do. I have updated the code to use assert, except in those functions that specifically need VectorValue.
This patch enables unaligned, statically indexed storing of vectors with sub emulation width element types. To illustrate the mechanism, consider the example of storing vector<7xi2> into memref<3x7xi2>[1, 0]. In this case the linearized indices of the bits being overwritten are [14, 28), which cover:
* the last 2 bits of byte no.2
* byte no.3
* the first 4 bits of byte no.4
Because memory accesses are in bytes, byte no.2 and no.4 in the above example are only modified partially. In a multi-threaded scenario, to avoid data contention, these two bytes must be handled atomically.
Hi Alan, while I'd love to resolve #123630 before your change lands, it's a nice-to-have rather than a blocker. Just thought I'd mention it, as that issue is progressing a bit slowly 😅 I'll be OOO over the next week. If Diego approves in the meantime (as the other active reviewer), please don't wait for me and feel free to land it. I'll rebase my changes in due time. Thank you for pushing this forward 🙏🏻
@dcaballe Diego, can you take another look at this?
LGTM. I guess we can land the renaming later.
Updates `emulatedVectorLoad` that was introduced in llvm#115922. Specifically, ATM `emulatedVectorLoad` mixes "emulated type" and "container type". This only became clear after llvm#123526 in which the concepts of "emulated" and "container" types were introduced. This is an NFC change and simply updates the variable naming.
Updates `emulatedVectorLoad` that was introduced in #115922. Specifically, ATM `emulatedVectorLoad` mixes "emulated type" and "container type". This only became clear after #123526 in which the concepts of "emulated" and "container" types were introduced. This is an NFC change and simply updates the variable naming.
This patch is a follow-up of the previous one, #115922. It adds an option to turn on emitting a non-atomic rmw code sequence instead of atomic rmw.
This patch enables unaligned, statically indexed storing of vectors with sub emulation width element types.
To illustrate the mechanism, consider the example of storing vector<7xi2> into memref<3x7xi2>[1, 0].
In this case the linearized indices of the bits being overwritten are [14, 28), which are:
* the last 2 bits of byte no.2
* byte no.3
* the first 4 bits of byte no.4
Because memory accesses are in bytes, byte no.2 and no.4 in the above example are only modified partially.
In a multi-threaded scenario, in order to avoid data contention, these two bytes must be handled atomically.
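For reference, the index arithmetic behind the example: the index [1, 0] into memref<3x7xi2> linearizes to element offset 1 * 7 + 0 = 7, i.e. bit offset 7 * 2 = 14, and the 7 stored i2 elements occupy another 14 bits, giving the bit range [14, 28). With the 1-based byte numbering used above, that range touches bytes no.2, no.3 and no.4, of which only byte no.3 is covered entirely.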