Reuse input buffer in lowering to krnl (#2939)
* first step

Signed-off-by: chentong319 <chentong@us.ibm.com>

* cpu

Signed-off-by: chentong319 <chentong@us.ibm.com>

* options

Signed-off-by: chentong319 <chentong@us.ibm.com>

* unify

Signed-off-by: chentong319 <chentong@us.ibm.com>

* simd

Signed-off-by: chentong319 <chentong@us.ibm.com>

* comments

Signed-off-by: chentong319 <chentong@us.ibm.com>

* lit test

Signed-off-by: chentong319 <chentong@us.ibm.com>

* fix test

Signed-off-by: chentong319 <chentong@us.ibm.com>

* format

Signed-off-by: chentong319 <chentong@us.ibm.com>

* response

Signed-off-by: chentong319 <chentong@us.ibm.com>

---------

Signed-off-by: chentong319 <chentong@us.ibm.com>
chentong319 authored Sep 13, 2024
1 parent 02f45b0 commit 97d497f
Showing 4 changed files with 109 additions and 9 deletions.
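
In short: with the new --enable-krnl-buffer-reuse option (off by default), the ONNX-to-Krnl lowering of elementwise ops checks whether an operand buffer has exactly one use and the same static shape as the result; if so, the result is written into that operand's buffer instead of a newly allocated memref. The helpers isBufferReusable, whichBufferToReuse, and allocOrReuse added to Elementwise.cpp implement the check, and the new lit test verifies that an Add/Sqrt/Sqrt chain lowers with no memref.alloc at all.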
9 changes: 9 additions & 0 deletions src/Compiler/CompilerOptions.cpp
@@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both
std::vector<std::string> functionsToDecompose; // common for both
std::string opsForCall; // common for both
bool disableKrnlOpFusion; // common for both
bool enableKrnlBufferReuse; // common for both
bool disableMemRefPrefetch; // common for both
EmissionTargetType emissionTarget; // onnx-mlir only
bool invokeOnnxVersionConverter; // onnx-mlir only
@@ -212,6 +213,14 @@ static llvm::cl::opt<bool, true> disableKrnlOpFusionOpt(
llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> enableKrnlBufferReuseOpt(
"enable-krnl-buffer-reuse",
llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass "
"(default=false)\n"
"Set to 'true' if you want to enable buffer reuse."),
llvm::cl::location(enableKrnlBufferReuse), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirCommonOptions));

static llvm::cl::opt<bool, true> disableMemRefPrefetchOpt(
"disable-memref-prefetch",
llvm::cl::desc("disable generation of memref.prefetch (default=false)\n"
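
For reference, the new flag is passed like the other options above; the invocation below mirrors the RUN line of the lit test added at the end of this commit (the input file name model.mlir is a placeholder):

    onnx-mlir-opt --disable-krnl-op-fusion=true --enable-krnl-buffer-reuse=true --shape-inference --convert-onnx-to-krnl model.mlir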
1 change: 1 addition & 0 deletions src/Compiler/CompilerOptions.hpp
@@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both
extern std::vector<std::string> functionsToDecompose; // common for both
extern std::string opsForCall; // common for both
extern bool disableKrnlOpFusion; // common for both
extern bool enableKrnlBufferReuse; // common for both
extern bool disableMemRefPrefetch; // common for both
extern EmissionTargetType emissionTarget; // onnx-mlir only
extern bool invokeOnnxVersionConverter; // onnx-mlir only
97 changes: 88 additions & 9 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -28,6 +28,82 @@ using namespace mlir;

namespace onnx_mlir {

// Check whether the input value x can be reused as the output buffer.
bool isBufferReusable(Value x, MemRefType outputType) {
if (!x.hasOneUse())
return false;

Type xType = x.getType();
auto inputType = dyn_cast<ShapedType>(xType);
if (!inputType)
return false;
// Currently, only static shapes can be reused.
// ToFix: use DimAnalysis to handle dynamic shape.
if (!hasStaticShape(inputType))
return false;
if (!hasStaticShape(outputType))
return false;

// Currently, reuse requires the shapes to be identical.
// ToFix: If the shape is not the same, memref.cast can be used.
if (getRank(inputType) != getRank(outputType))
return false;
for (int64_t i = 0; i < getRank(inputType); i++) {
if (inputType.getShape()[i] != outputType.getShape()[i])
return false;
}

// ToFix: SIMD padding is not checked.
// We did not record whether the memref is padded or not.
// The padding could be recorded on the memref as an attribute, or it may not
// be needed at all.
return true;
}

// Traverse the operands to find a candidate for buffer reuse.
// Return -1 if no candidate is found.
int whichBufferToReuse(ValueRange values, MemRefType outputType) {
for (size_t i = 0; i < values.size(); i++) {
if (isBufferReusable(values[i], outputType))
return i;
}
return -1;
}

// Allocate a memref (as before) if no input buffer can be reused.
// The default VL = 0 is used for non-SIMD allocation.
Value allocOrReuse(MemRefBuilder &create, Operation *op,
ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
int64_t alignment, int64_t VL = 0);

Value allocOrReuse(MemRefBuilder &create, Operation *op,
ValueRange generatedOperands, MemRefType outputMemRefType, DimsExprRef dims,
int64_t alignment, int64_t VL) {

int indexToReuse = -1;
// By default, enableKrnlBufferReuse is false. Simply allocate a memref.
if (enableKrnlBufferReuse) {
// Be sure to use op->getOperands() to check the number of uses:
// after buffer reuse, the number of uses of the transformed values in
// generatedOperands will have increased.
indexToReuse = whichBufferToReuse(op->getOperands(), outputMemRefType);
}

if (indexToReuse != -1) {
LLVM_DEBUG({
int size = getSizeInBytes(outputMemRefType);
llvm::dbgs() << " malloc_size " << size << "\n";
op->dump();
});
return generatedOperands[indexToReuse];
} else {
if (VL == 0)
return create.alignedAlloc(outputMemRefType, dims, alignment);
else
return create.alignedAllocWithSimdPadding(
outputMemRefType, dims, VL, alignment);
}
}

// =============================================================================

/// Emit post-processing for variadic element-wise ops.
@@ -1323,14 +1399,14 @@ static LogicalResult getPartiallyFlattenedSimdCode(
IndexExprScope allocScope(create.vec, shapeHelper->getScope());
DimsExpr outputDims;
getIndexExprList<SymbolIndexExpr>(shapeHelper->getOutputDims(), outputDims);
// Alloc memory with padding for SIMD.
// Reuse a buffer from the inputs, or allocate memory with padding for SIMD.
// For the moment, it's ok to go here; if we truly have partial flattening of
// the simd code, then we only do it with static memref sizes that are
// multiples of VL * unrollVL, so there should be no padding anyway. This
// will change if we do partial flattening with non-multiple of VL *
// unrollVL.
Value alloc = create.mem.alignedAllocWithSimdPadding(
outputMemRefType, outputDims, VL, alignment);
Value alloc = allocOrReuse(
create.mem, op, operands, outputMemRefType, outputDims, alignment, VL);
// Create flat inputs in the last innerDinNum dims.
llvm::SmallVector<Value, 4> flatOperands;
for (Value oper : operands) {
@@ -1975,8 +2051,9 @@ struct ONNXElementwiseUnaryOpLowering
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation for the result of this operation.
Value alloc = create.mem.alignedAlloc(
outputMemRefType, shapeHelper.getOutputDims(), alignment);
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
shapeHelper.getOutputDims(), alignment);

// Only create krnl.iterate if one of the operands is not scalar tensor.
if (!isScalar) {
@@ -2156,8 +2233,9 @@ struct ONNXElementwiseBinaryOpLowering
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation and deallocation for the result of this operation.
Value alloc = create.mem.alignedAlloc(
outputMemRefType, shapeHelper.getOutputDims(), alignment);
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
shapeHelper.getOutputDims(), alignment);

// Only create krnl.iterate if one of the operands is not scalar tensor.
if (!isScalar) {
@@ -2331,8 +2409,9 @@ struct ONNXElementwiseVariadicOpLowering
outputMemRefType = opFusionHelper.getOutputType(outputMemRefType);

// Insert an allocation and deallocation for the result of this operation.
Value alloc = create.mem.alignedAlloc(
outputMemRefType, shapeHelper.getOutputDims(), alignment);
Value alloc = allocOrReuse(create.mem, op, operands, outputMemRefType,
shapeHelper.getOutputDims(), alignment);

// Only create krnl.iterate if one of the operands is not scalar tensor.
if (!isScalar) {
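
The reuse decision above is compact; below is a minimal, standalone sketch of the same policy, not part of this commit, with illustrative names (OperandInfo, isReusable, whichToReuse) and plain C++ containers in place of MLIR types. A buffer is reusable only if it has a single use and its static shape matches the output exactly, and the first such operand wins.

// Editorial sketch (not part of this commit): the reuse policy restated with
// plain C++ types; names are illustrative only.
#include <cstdint>
#include <cstdio>
#include <vector>

struct OperandInfo {
  int numUses;                 // stand-in for Value::hasOneUse()
  std::vector<int64_t> shape;  // dims; a negative dim means dynamic
};

// An operand is reusable when it has exactly one use and its static shape
// matches the output shape exactly (same rank, same dims, nothing dynamic).
static bool isReusable(const OperandInfo &in, const std::vector<int64_t> &out) {
  if (in.numUses != 1)
    return false;
  for (int64_t d : in.shape)
    if (d < 0)
      return false;  // dynamic shape: not handled yet (ToFix in the real code)
  for (int64_t d : out)
    if (d < 0)
      return false;
  return in.shape == out;
}

// Return the index of the first reusable operand, or -1 so the caller falls
// back to a fresh allocation, mirroring whichBufferToReuse above.
static int whichToReuse(
    const std::vector<OperandInfo> &ins, const std::vector<int64_t> &out) {
  for (size_t i = 0; i < ins.size(); ++i)
    if (isReusable(ins[i], out))
      return static_cast<int>(i);
  return -1;
}

int main() {
  // First operand has two uses, second has one and a matching shape.
  std::vector<OperandInfo> ins = {{2, {1024}}, {1, {1024}}};
  std::printf("reuse operand: %d\n", whichToReuse(ins, {1024}));  // prints 1
  return 0;
}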
11 changes: 11 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/onnx_lowering_reuse.mlir
@@ -0,0 +1,11 @@
// RUN: onnx-mlir-opt --disable-krnl-op-fusion=true --enable-krnl-buffer-reuse=true --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s

// -----
func.func @test_reuse(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>) -> tensor<1024xf32> {
%0 = "onnx.Add"(%arg0, %arg1) : (tensor<1024xf32>, tensor<1024xf32>) -> tensor<1024xf32>
%1 = "onnx.Sqrt"(%0) : (tensor<1024xf32>) -> tensor<1024xf32>
%2 = "onnx.Sqrt"(%1) : (tensor<1024xf32>) -> tensor<1024xf32>
return %2 : tensor<1024xf32>
}
// CHECK-LABEL: func.func @test_reuse
// CHECK-NOT: memref.alloc
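With op fusion disabled and buffer reuse enabled, the Add is expected to write its result into one of the argument buffers and each Sqrt to run in place on its input, so no memref.alloc should survive the lowering; hence the single CHECK-NOT above.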
