[MLIR] Implement emulation of static indexing subbyte type vector stores #115922

Merged
merged 9 commits into from
Jan 29, 2025

Conversation

lialan
Member

@lialan lialan commented Nov 12, 2024

This patch enables unaligned, statically indexed stores of vectors with sub-emulation-width element types.

To illustrate the mechanism, consider storing vector<7xi2> into memref<3x7xi2> at indices [1, 0].
In this case, the linearized range of bits being overwritten is [14, 28), which covers:

  • the last 2 bits of byte no. 2
  • byte no. 3
  • the first 4 bits of byte no. 4

Because memory is accessed in bytes, bytes no. 2 and no. 4 in the above example are only partially modified.
In a multi-threaded scenario, these two bytes must be handled atomically to avoid data races.
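
For illustration, here is a small standalone sketch (not code from the patch) of the index arithmetic above; the constants mirror the vector<7xi2> / memref<3x7xi2>[1, 0] example:

  // Standalone sketch of the linearized-bit computation described above.
  #include <cstdio>

  int main() {
    const int elemBits = 2;      // i2 element width
    const int rowElems = 7;      // innermost dimension of memref<3x7xi2>
    const int row = 1, col = 0;  // store indices [1, 0]
    const int numElems = 7;      // vector<7xi2>

    int startBit = (row * rowElems + col) * elemBits; // 14
    int endBit = startBit + numElems * elemBits;      // 28
    // Bits [14, 28) touch bytes no. 2, 3 and 4 (counting from 1); the first
    // and the last of them are only partially covered, so they need the
    // atomic read-modify-write treatment.
    bool frontPartial = startBit % 8 != 0;            // true
    bool backPartial = endBit % 8 != 0;               // true
    std::printf("bits [%d, %d), frontPartial=%d, backPartial=%d\n",
                startBit, endBit, frontPartial, backPartial);
    return 0;
  }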

@llvmbot
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

Changes

This patch enables unaligned, statically indexed stores of vectors with sub-byte element types.

To ensure atomicity, the boundary bytes at the front and at the back are handled with atomic stores to avoid thread contention, while the bytes (if any) that are free from race conditions use non-atomic stores.

To illustrate the mechanism, consider storing vector<7xi2> into memref<3x7xi2> at indices [1, 0]. In this case, the linearized range of bits being overwritten is [14, 28), which covers:

  • the last 2 bits of byte no. 2
  • byte no. 3
  • the first 4 bits of byte no. 4

This patch expands the store into two atomic partial stores (for the first and the third byte) and a single non-atomic full-byte store (for the second byte).
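
As a rough sketch (assumptions: a 1-D linearized layout and an i8 container type; the names below are not from the patch), the split into front, middle and back stores can be expressed as:

  // Sketch of the store decomposition: a partial front byte (atomic RMW), a
  // run of full bytes (plain non-atomic store), and a partial back byte
  // (atomic RMW). For bits [14, 28) this yields one full byte between two
  // partial ones.
  struct StorePlan {
    bool frontAtomic;   // first byte only partially covered
    int fullByteBegin;  // first byte written in full
    int fullByteCount;  // number of full bytes, written non-atomically
    bool backAtomic;    // last byte only partially covered
  };

  StorePlan planStore(int startBit, int endBit) {
    StorePlan plan;
    plan.frontAtomic = startBit % 8 != 0;
    plan.backAtomic = endBit % 8 != 0;
    plan.fullByteBegin = (startBit + 7) / 8; // round up to a byte boundary
    int fullByteEnd = endBit / 8;            // round down to a byte boundary
    plan.fullByteCount =
        fullByteEnd > plan.fullByteBegin ? fullByteEnd - plan.fullByteBegin : 0;
    return plan;
  }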


Patch is 21.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115922.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+166-33)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+135)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7578aadee23a6e..031eb2d9c443b8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -143,19 +143,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
-  auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  auto vectorType = llvm::cast<VectorType>(source.getType());
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
@@ -164,12 +164,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -236,6 +234,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                           Value memref, Value index, Value value) {
+  auto originType = dyn_cast<VectorType>(value.getType());
+  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
+  auto scale = memrefElemType.getIntOrFloatBitWidth() /
+               originType.getElementType().getIntOrFloatBitWidth();
+  auto storeType =
+      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// atomically store a subbyte-sized value to memory, with a mask.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+                         Value emulatedMemref, Value emulatedIndex,
+                         TypedValue<VectorType> value, Value mask,
+                         int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector type <1xi8> then <1xi8> -> <scale x i.>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+  return atomicOp;
+}
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -251,7 +306,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -275,15 +331,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -291,14 +347,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    Value emulatedMemref = adaptor.getBase();
+    // the index into the target memref we are storing to
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+    // the index into the source vector we are currently processing
+    auto currentSourceIndex = 0;
+
+    // 1. atomic store for the first byte
+    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    if (frontAtomicStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+      if (*foldedIntraVectorOffset + origElements < scale) {
+        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                    origElements, true);
+        frontAtomicStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    *foldedIntraVectorOffset, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                  scale);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. non-atomic store
+    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    if (nonAtomicStoreSize != 0) {
+      auto nonAtomicStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonAtomicElements);
+
+      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                     nonAtomicStorePart);
+
+      currentSourceIndex += numNonAtomicElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+    }
+
+    // 3. atomic store for the last byte
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto atomicStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // back mask
+      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(atomicStorePart),
+                  backMask.getResult(), scale);
+    }
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -496,9 +632,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -652,9 +787,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -732,9 +866,8 @@ struct ConvertVectorTransferRead final
                                            linearizedInfo.intraDataOffset,
                                            origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..3b1f5b2e160fe0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,138 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+    return
+}
+
+// in this example, emit 2 atomic stores, with the first storing 1 element and the second storing 2 elements.
+// CHECK: func @vector_store_i2_const(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// atomic store of the first byte
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// atomic store of the second byte
+// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
+
+// -----
+
+func.func @vector_store_i8_2(%arg0: vector<7xi2>) {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+    return
+}
+
+// in this example, emit 2 atomic stores and 1 non-atomic store
+
+// CHECK: func @vector_store_i8_2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// first atomic store
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// non atomic store part
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// second atomic store
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8    
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+    %0 = memref.alloc() : memref<4x1xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+    return
+}
+
+// in this example, only emit 1 ato...
[truncated]


Contributor

@MaheshRavishankar MaheshRavishankar left a comment


What would happen if the start of the store is aligned? Would it still generate atomic stores?

I haven't looked into the details yet, but we only need an atomic store if we cannot guarantee that there are no competing stores. We cannot have this information at this level, but a caller might. I think it might be better to allow a caller to indicate that the atomic stores are not needed. The default can be that you do the atomic stores.

@lialan lialan force-pushed the lialan/atomic_stores branch 2 times, most recently from bfdbed2 to edfe3d4 Compare November 12, 2024 22:18
@lialan
Member Author

lialan commented Nov 12, 2024

What would happen if the start of the store is aligned? Would it still generate atomic stores?

It will operate as usual; the atomic and extra handling only happen when needed.

I haven't looked into the details yet, but we only need an atomic store if we cannot guarantee that there are no competing stores. We cannot have this information at this level, but a caller might. I think it might be better to allow a caller to indicate that the atomic stores are not needed. The default can be that you do the atomic stores.

I feel like in such a non-competing case we can do strength reduction and use a masked store instead of an atomic one for the unaligned parts (the first and the last byte, if unaligned).

@MaheshRavishankar
Contributor

I feel like in such a non-competing case we can do strength reduction and use a masked store instead of an atomic one for the unaligned parts (the first and the last byte, if unaligned).

You can't always tell from program analysis that the atomic isn't needed; that depends on the caller's context, for example when the caller knows that, even dynamically, the stores never overlap. It is hard to recover that information at this late stage. So I think it would be good to have an option to avoid the atomics (since they are expensive) and let the caller opt in or out, with atomic stores as the default.
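
A minimal sketch of how such an opt-in could look (the names and plumbing below are hypothetical, not part of this patch; a follow-up PR referenced at the end of this thread later added an option along these lines):

  // Hypothetical sketch: a caller-provided flag that disables the atomic
  // read-modify-write path when the caller can guarantee that stores never
  // overlap. None of these names come from the actual patch.
  struct NarrowTypeEmulationOptions {
    bool disableAtomicRMW = false; // default: conservative, emit atomic RMW
  };

  // Inside the store-lowering pattern, the partial-byte path would reduce to:
  //   if (options.disableAtomicRMW)
  //     emitNonAtomicRMW(rewriter, loc, base, index, value, mask);
  //   else
  //     emitAtomicRMW(rewriter, loc, base, index, value, mask);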

@lialan lialan force-pushed the lialan/atomic_stores branch from edfe3d4 to 9b81a3f Compare November 14, 2024 17:11
@lialan lialan marked this pull request as draft November 14, 2024 17:12
@lialan lialan force-pushed the lialan/atomic_stores branch 2 times, most recently from aec91d0 to b46f01e Compare November 14, 2024 20:11
@lialan lialan marked this pull request as ready for review November 14, 2024 20:30
@lialan lialan force-pushed the lialan/atomic_stores branch from b46f01e to 5069134 Compare November 19, 2024 17:44
Contributor

@hanhanW hanhanW left a comment


Thanks @lialan for the patch. Here is the first round of review comments. I haven't finished reviewing the three steps yet, but the idea looks good to me. I mostly want to clarify my understanding of the atomicStore cases, and I'll review the rest later.


A global comment: can we declare something like `using VectorValue = TypedValue<VectorType>;` at file scope, so it can be used everywhere in the file?

Contributor

@banach-space banach-space left a comment


Hi @lialan , I'd appreciate a bit more high-level explanation.

This patch expands the store into two (maybe atomic) rmw partial stores (for the first and the third byte), and a non-atomic single full byte store (for the second byte).

Could you define what "atomic" and "non-atomic" mean in the context of this patch? Also, it's not clear to me what is new here. Pseudo-code contrasting "what we have today" with "what this PR enables" would be appreciated.

@lialan lialan force-pushed the lialan/atomic_stores branch 2 times, most recently from fe648aa to 72bdea0 Compare November 26, 2024 07:35
Contributor

@hanhanW hanhanW left a comment


Nice work! I left some comments, please take a look.

@lialan lialan requested a review from banach-space January 15, 2025 07:44
Contributor

@banach-space banach-space left a comment


Getting there 😅

if (origElements % scale != 0)
return failure();
auto origElements = valueToStore.getType().getNumElements();
bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
Contributor


Please double check, but I think that this definition of "aligned" is only covering point 1. from

/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
/// means that:
/// 1. The `dstType` element type is a multiple of the
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
/// is not supported). Let this multiple be `N`.
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
/// not supported).

If you agree, I'll try to propose something - we should use a consistent definition throughout this file.
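
For reference, the two conditions can be restated as a small check (a sketch only, not the actual helper from the patch series):

  // Sketch of the alignment rule quoted above; not the actual implementation.
  static bool isAlignedCase(int subByteBits, int dstBits, int numSrcElems) {
    if (dstBits % subByteBits != 0)
      return false;                 // condition 1: e.g. i3 into i8 is rejected
    int n = dstBits / subByteBits;  // e.g. i4 into i8 gives N = 2
    return numSrcElems % n == 0;    // condition 2: 2xi4 is OK, 3xi4 is not
  }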

Member Author


Yes please! That will make subsequent changes in the file easier, and I can do the update.

Contributor


OK, I think that I have something that would help unify things and clarify the terminology a bit. But I'm worried that I discovered an inconsistency/bug, so I need to double-check.

Let me get back to you on Monday, I've run out of cycles for the week :(

Contributor

@banach-space banach-space Jan 19, 2025


You should be able to re-use isSubByteVecFittable from #123529.

There are 4 patches to make reviewing easier, but they are quite short and should be easy to rebase on top 🤞🏻

Comment on lines 473 to 492
// Shortcut: conditions when subbyte emulated store at the front is not
// needed:
// 1. The source vector size (in bits) is a multiple of byte size.
// 2. The address of the store is aligned to the emulated width boundary.
//
// For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
if (isAlignedEmulation && *foldedNumFrontPadElems == 0) {
auto numElements = origElements / numSrcElemsPerDest;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, bitCast.getResult(), memrefBase,
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
return success();
}
Contributor


From what I can tell, this block is not tested. Something like this will exercise it:

  // Aligned store, hence full stores.
  
  func.func @vector_store_i2_one_full_store_multiple_bytes(%arg0: vector<32xi2>) {
    %alloc = memref.alloc() : memref<8xi8>
    %0 = vector.bitcast %arg0 : vector<32xi2> to vector<8xi8>
    %c0 = arith.constant 0 : index
    vector.store %0, %alloc[%c0] : memref<8xi8>, vector<8xi8>
    return
  }

Feel free to reuse it. Also, I don't really like expressions like "shortcut"; instead, IMO, this is a special "case". To demonstrate what I have in mind:

bool emulationRequiresPartialStores = !isAlignedEmulation || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
   // Basic case, storing full bytes.
   // Your code here.
}

// Complex case, emulation requires partial stores.

Something along those lines 😅

Member Author


This block is actually copied from the original code, so it handles the aligned cases. And I feel like there are already test cases for those, for example: https://github.com/lialan/llvm-project/blob/lialan/atomic_stores/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir#L420-L424

Does that sound good enough for test coverage?

@lialan lialan requested a review from banach-space January 17, 2025 05:46
Contributor

@dcaballe dcaballe left a comment


It looks really cool! Just a few comments

rewriter, loc, dyn_cast<TypedValue<VectorType>>(passthru),
emptyVector, linearizedInfo.intraDataOffset, origElements);
} else if (isUnalignedEmulation) {
rewriter, loc, cast<VectorValue>(passthru), emptyVector,
Contributor


I'm not sure if enforcing all these casts to TypedValue is worth it. Wouldn't checking the type within an assert be more efficient for the release build, and cleaner? We may also end up checking the type too many times, depending on when the type check happens in TypedValue...

Member Author


I think you are right; this is a slight optimization we can make. I have updated the code to use asserts, except in those functions that specifically need VectorValue.

lialan and others added 8 commits January 24, 2025 10:54
This patch enables unaligned, statically indexed stores of vectors with sub-emulation-width element types.

To illustrate the mechanism, consider storing vector<7xi2> into memref<3x7xi2> at indices [1, 0].
In this case, the linearized range of bits being overwritten is [14, 28), which covers:

* the last 2 bits of byte no. 2
* byte no. 3
* the first 4 bits of byte no. 4

Because memory is accessed in bytes, bytes no. 2 and no. 4 in the above example are only partially modified.
In a multi-threaded scenario, these two bytes must be handled atomically to avoid data races.
@lialan lialan force-pushed the lialan/atomic_stores branch from e6d83e9 to d458e3b Compare January 24, 2025 11:13
@lialan lialan requested a review from dcaballe January 24, 2025 11:14
@banach-space
Contributor

Hi Alan, while I’d love to resolve #123630 before your change lands, it’s a nice-to-have rather than a blocker. Just thought I’d mention it, as that issue is progressing a bit slowly 😅

I’ll be OOO over the next week. If Diego approves in the meantime (as the other active reviewer), please don’t wait for me—feel free to land it. I’ll rebase my changes in due time.

Thank you for pushing this forward 🙏🏻

@lialan
Member Author

lialan commented Jan 28, 2025

@dcaballe Diego can you take another look at this?

Contributor

@dcaballe dcaballe left a comment


LGTM. I guess we can land the renaming later.

@lialan lialan merged commit cdced8e into llvm:main Jan 29, 2025
8 checks passed
@lialan lialan deleted the lialan/atomic_stores branch January 29, 2025 04:28
banach-space added a commit to banach-space/llvm-project that referenced this pull request Feb 2, 2025
Updates `emulatedVectorLoad` that was introduced in llvm#115922.
Specifically, ATM `emulatedVectorLoad` mixes "emulated type" and
"container type". This only became clear after llvm#123526 in which the
concepts of "emulated" and "container" types were introduced.

This is an NFC change and simply updates the variable naming.
banach-space added a commit that referenced this pull request Feb 3, 2025
Updates `emulatedVectorLoad` that was introduced in #115922.
Specifically, ATM `emulatedVectorLoad` mixes "emulated type" and
"container type". This only became clear after #123526 in which the
concepts of "emulated" and "container" types were introduced.

This is an NFC change and simply updates the variable naming.
lialan added a commit that referenced this pull request Feb 6, 2025

This patch is a follow-up to the previous one, #115922. It adds an option to turn on emitting a non-atomic RMW code sequence instead of atomic RMW.
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…125415)

Updates `emulatedVectorLoad` that was introduced in llvm#115922.
Specifically, ATM `emulatedVectorLoad` mixes "emulated type" and
"container type". This only became clear after llvm#123526 in which the
concepts of "emulated" and "container" types were introduced.

This is an NFC change and simply updates the variable naming.
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…#124887)

This patch is a follow-up to the previous one, llvm#115922. It adds an option to turn on emitting a non-atomic RMW code sequence instead of atomic RMW.