Skip to content

Commit

Permalink
[TorchToLinalg] perform rank0 elementwise computations outside linalg…
Browse files Browse the repository at this point in the history
… generic ops (llvm#3762)

This is motivated by the fact that shapes are stored as tensors in ONNX,
and IREE tries to perform tensor arithmetic on the device. This causes
unnecessary dispatches, and makes it harder for the compiler to reason
about shapes.

Here is a small snippet of torch-IR that is typical seen coming from
ONNX models:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> {
    %int0 = torch.constant.int 0
    %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
    %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64>
    %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64>
    %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
    %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int
    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool
    %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int
    %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int
    %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1>
    %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64>
    %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64>
    %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
    return %11 : !torch.vtensor<[],si64>
  }
}
```

Without the change in this PR, the result would be:

```mlir
#map = affine_map<() -> ()>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = tensor.empty() : tensor<i1>
    %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor<i1>) -> tensor<i1>
    %7 = tensor.empty() : tensor<i64>
    %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor<i64>) -> tensor<i64>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor<i1>, tensor<i64>, tensor<i64>) outs(%7 : tensor<i64>) {
    ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64):
      %11 = arith.select %in, %in_1, %in_2 : i64
      linalg.yield %11 : i64
    } -> tensor<i64>
    return %10 : tensor<i64>
  }
}
```

With the change in this PR, we would instead get:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<?x?x768xf32>, %arg1: tensor<?x?x768xf32>) -> tensor<i64> {
    %c0_i64 = arith.constant 0 : i64
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg1, %c0 : tensor<?x?x768xf32>
    %0 = arith.index_cast %dim : index to i64
    %1 = tensor.empty() : tensor<1xi64>
    %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor<i64>
    %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor<i64>) -> tensor<i64>
    %extracted = tensor.extract %2[] : tensor<i64>
    %3 = arith.cmpi eq, %extracted, %c0_i64 : i64
    %dim_0 = tensor.dim %arg0, %c0 : tensor<?x?x768xf32>
    %4 = arith.index_cast %dim_0 : index to i64
    %5 = arith.select %3, %4, %extracted : i64
    %6 = tensor.empty() : tensor<i64>
    %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor<i64>) -> tensor<i64>
    return %7 : tensor<i64>
  }
}
```

Some related issues for context:
1. <iree-org/iree#18677>
2. <iree-org/iree#18631>
  • Loading branch information
zjgarvey authored Oct 4, 2024
1 parent f08bfc4 commit 6e8c7be
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
19 changes: 19 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,25 @@ class ConvertElementwiseOp : public ConversionPattern {
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
bool isScalarOp = resultType.getShape().size() == 0;
if (isScalarOp) {
// for elementwise ops that are actually rank0 scalar computations,
// perform the payload outside a linalg generic op.
SmallVector<Value> payloadArgs;
for (auto t : tensorOperands) {
payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
}
Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
if (!scalarResult)
return rewriter.notifyMatchFailure(
op, "Failed to create payload for scalar elementwise op");
Value rank0Result =
createInitTensor(rewriter, loc, ValueRange{},
resultType.getElementType(), scalarResult);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
return success();
}
bool hadErrorCreatingPayload = false;
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, tensorOperands, resultType.getElementType(),
Expand Down
12 changes: 5 additions & 7 deletions test/Conversion/TorchToLinalg/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
// CHECK-LABEL: func.func @elementwise$unary(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
// CHECK: linalg.yield %[[TANH]] : f32
// CHECK: } -> tensor<f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
// CHECK: }
Expand Down

0 comments on commit 6e8c7be

Please sign in to comment.